【Python/scikit-learn】決定木を可視化する:Graphvizとpydotplusの使い方
Pythonで機械学習モデルの一つである決定木を可視化する方法を解説します。決定木は分類モデルの構造を理解しやすくしますが、実際にツリー構造として可視化することでさらに詳細な理解ができます。今回はGraphvizとpydotplusを使って決定木のモデルをPNG画像として出力します。
必要なライブラリのインストール
Pythonのコードを実行する前に、いくつかのライブラリをインストールする必要があります。
Graphvizのインストール
pydotplusが内部でGraphvizを使用するためまずGraphvizをインストールします。Linuxでのインストールコマンドは以下の通りです。
$ sudo apt-get install graphviz
Macではbrew、WindowsではGraphvizの公式サイトからインストーラをダウンロードしてインストールすることができます。
Python用ライブラリのインストール
Python用に必要なpydotplusと決定木の生成に使うscikit-learnもインストールします。
$ pip install pydotplus scikit-learn
サンプルコード
インストールが完了したら、次にコードを書いて実行します。今回は、Irisデータセットを用いて、scikit-learnのDecisionTreeClassifierで決定木を構築し、pydotplusを用いてモデルを画像として出力します。
from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.tree import DecisionTreeClassifier from pydotplus import graph_from_dot_data from sklearn.tree import export_graphviz iris = load_iris() X, y = iris.data, iris.target X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.3, random_state=123) tree = DecisionTreeClassifier(max_depth=3, random_state=123) tree.fit(X_train, y_train) dot_data = export_graphviz(tree, filled=True, rounded=True, class_names=['Setosa', 'Versicolor', 'Virginica'], feature_names=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'], out_file=None) # 決定木のプロットを出力 graph = graph_from_dot_data(dot_data) graph.write_png('tree.png')
コード解説
- データセットのロードと分割:load_irisでIrisデータセットをロードし、train_test_splitを用いてデータを訓練データとテストデータに分割します。
- モデルの学習:DecisionTreeClassifierを使い、決定木モデルを生成します。max_depth=3で深さ3に制限し、過学習を防ぎます。
- 可視化用データの生成:export_graphviz関数を用いて、決定木モデルの可視化用のデータを生成します。filled=Trueはノードの色分けを行い、class_namesとfeature_namesで分類クラスと特徴量の名前を指定します。
- 画像ファイルの出力:graph_from_dot_data関数でdot_dataを読み込み、PNGファイルとして出力します。生成されるtree.pngを開くことで、決定木の構造を確認できます。
出力結果
上記プログラムを実行すると以下の画像が保存されます。
決定木の構造
この決定木は、Irisデータセットを基に、Petal Length(花弁の長さ)、Petal Width(花弁の幅)といった特徴量を使って、Setosa、Versicolor、Virginicaの3つのクラスに分類するモデルです。木の各ノードには、以下の情報が表示されています:
条件(例: Petal Length <= 2.45)
特定の特徴量に基づいてデータを分岐します。True(条件を満たす)とFalse(条件を満たさない)の2つの枝が続きます。
Gini指数(gini)
不純度の指標で、値が0に近いほど純度が高く、ノード内のサンプルがほぼ単一クラスに分類されています。
サンプル数(samples)
ノードに属するサンプルの数を示します。
クラスの分布(value)
各クラスのサンプル数を示します。例: [32, 0, 0]は32サンプルがSetosaであることを意味します。
予測クラス(class)
ノードで予測されるクラスを示しています。サンプルの大部分を占めるクラスが表示されます。
出力結果の解釈
ルートノード
一番上のノード(Petal Length <= 2.45)では、105個のサンプルを基に条件が設定され、左の枝はTrue、右の枝はFalseのサンプルに分岐します。このノードでは、gini=0.663で、3つのクラスが混在していますが、Versicolorが多い状態です。
左の枝(True)
Petal Length <= 2.45の条件を満たすサンプル(左)は、Setosaクラスのみから成り、giniが0のため純度が高い状態です。ここでSetosaと分類されます。
右の枝(False)
Petal Length > 2.45の条件を満たさないサンプル(右)は、さらにPetal Width <= 1.75の条件で分岐します。このノードもVersicolorが優勢ですが、まだ他のクラスも含まれています(gini=0.495)。
葉ノード
最終的に、下部のノードでサンプルがクラスごとにほぼ分類され、giniが0または小さい値(高純度)になっています。例えば、右側のclass = Virginicaのノードでは、29サンプル全てがVirginicaで、giniが0のため純粋な状態です。
まとめ
Graphvizとpydotplusを用いることで、Pythonでも簡単に決定木モデルを視覚化できることが分かりました。モデルの解釈や、特徴量の重要度の理解に役立つので、ぜひ活用してみてください!
スポンサーリンク