【Python/scikit-learn】サポートベクタマシン(SVM)を使ったシンプルな分類モデルの作成と可視化
サポートベクタマシン(SVM)は、分類や回帰、異常検知などに使用される強力な機械学習アルゴリズムです。SVMは、データを異なるクラスに分類するために、データ間の「マージン」を最大化する境界線(ハイパープレーン)を見つけます。このアルゴリズムは、非線形データに対してもカーネルトリックと呼ばれる手法を使ってデータを高次元にマッピングすることで効果的に動作します。
Scikit-learnではSVMを簡単に実装できるライブラリを提供しています。以下はScikit-learnでSVMを使用して2クラス分類を行う例です。Scikit-learnに内蔵されている “`make_blobs“` 関数を使ってデータを生成し分類します。
ソースコード
# 必要なライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
# 2クラスのデータセットを生成
X, y = datasets.make_blobs(n_samples=100, centers=2, random_state=6)
# 訓練データとテストデータに分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# SVMモデルの作成(線形カーネル)
svm_model = SVC(kernel='linear')
# モデルの訓練
svm_model.fit(X_train, y_train)
# テストデータを使って予測
y_pred = svm_model.predict(X_test)
# モデルの精度を計算
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy * 100:.2f}%")
# 境界線を描画
def plot_svm_decision_boundary(model, X, y):
# 軸の範囲を取得
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
np.arange(y_min, y_max, 0.01))
# 各点のクラスを予測
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# 境界線の描画
plt.contourf(xx, yy, Z, alpha=0.8)
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o')
plt.show()
# 訓練データとSVMによる境界線を描画
plot_svm_decision_boundary(svm_model, X_test, y_test)
ソースコードの解説
- make_blobs 関数で2クラスのデータを生成しています。このデータを、2次元平面上で散布された2つのクラスに分けます。
- train_test_splitを使用して、データセットを訓練用とテスト用に分割します。
- SVCクラスを使用してSVMモデルを作成します。この例では kernel=’linear’ として線形SVMを使用していますが、 kernel=’rbf’(RBFカーネル)などの非線形カーネルも使用可能です。
- モデルを訓練し、テストデータで予測を行います。
- 予測結果の精度を accuracy_scoreで計算します。
- 最後に、SVMが生成したクラス境界をプロットして、データとともに視覚化しています。
make_blobs 関数については以下の記事に基本的な使い方を紹介しています。
出力結果の解説
出力結果はこちら
背景の色分け
- 背景の紫色の部分は、一方のクラスに分類された領域を示しています。
- 背景の黄色の部分は、もう一方のクラスに分類された領域を表しています。
- この2色の境界線は、SVMがデータを分けるために見つけたハイパープレーン(最適な分類境界線)です。この線は、データ間のマージンを最大化するように調整されています。
データ点
- 紫色の丸は一方のクラスに属するデータポイントを示しています。
- 黄色の丸は、もう一方のクラスに属するデータポイントを表しています。
- データポイントの散らばり方から、2つのクラスがある程度はっきりと分かれていることが確認できますが、完全に線形分離できるわけではないことが見て取れます。
分類境界線
- クラス間の境界線(ハイパープレーン)は、データの分布をもとにSVMが計算したもので、この線を基準にどちらのクラスに属するかが決まります。この場合、線はやや傾いており、クラスの間に大きなマージンを確保しています。
SVMは、線形および非線形の分類に強力なツールです。Scikit-learnを使用すると、非常に少ないコードでSVMを実装してデータを分類することが可能です。カーネルの選択により、さまざまな形状のデータに対して柔軟に適応できるため、実際の問題に対しても非常に効果的です。
スポンサーリンク
