PythonでSVMをやってみる

Python機械学習」って、流行語過ぎるので、乗っかる。irisの分類をSVMでやるっていう、Rの例題でよく見るヤツ。

RでSVM

## irisデータの読み込み
data(iris)
## 半分をトレーニングセットにするために、ランダムに選択する
train_ids <- sample(nrow(iris), nrow(iris)*0.5)
## トレーニングセットの作成
iris.train <- iris[train_ids,]
## 残り半分をテストセットに
iris.test  <- iris[-train_ids,]
### SVM実行
library(kernlab)
iris.svm <- ksvm(Species~., data=iris.train)
svm.predict <- predict(iris.svm, iris.test)
### 結果表示
table(svm.predict, iris.test$Species)

ところで、irisのsetosa,versicolor,virginicaって、こんな感じらしい。正直、花の形で区別がつけられる気がしない。っていうか、そもそもirisがアヤメなのかショウブなのかカキツバタなのか、区別がついてない。

PythonSVM

Python SVMとかで検索するとscikit-learnを使えと皆さんおっしゃるので、そうする。scikit-learnにdatasetsとして、irisも含まれているようだ。

from sklearn import svm, datasets
iris = datasets.load_iris()

Rではsample()を使って、トレーニングセットとテストセットを分割したけれど、scikit-learnには、ソレ専用のメソッド(train_test_split)が用意されてた。Rと同じように、トレーニングセットとテストセットを半分ずつで分割するように、test_size=0.5をオプションとして指定する。

from sklearn.cross_validation import train_test_split
iris_data_train, iris_data_test, iris_target_train, iris_target_test = train_test_split(iris.data, iris.target, test_size=0.5)

識別器を作って、トレーニングを実施する。とりあえず、kernelその他オプションは全部デフォルトのお任せ仕様だと、こんな感じ。

iris_predict = svm.SVC().fit(iris_data_train, iris_target_train).predict(iris_data_test)

このiris_predict(推定結果)と、iris_target_test(正解)を比較すればいい。ちょうど、Rでtable(svm.predict, iris.test$Species) のように表示するのは、scikit-learnではconfusion_matrixというメソッドで定義されている。

from sklearn.metrics import confusion_matrix, accuracy_score
cm = confusion_matrix(iris_target_test, iris_predict)

confusion_matrix()は、
sklearn.metrics.confusion_matrix — scikit-learn 0.19.2 documentation
によると、第一引数が真値で、第二引数が識別器による判別値になっているようだ。

というワケで、全体では、こんな感じ。

from sklearn import svm, datasets
from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix
import numpy as np

iris = datasets.load_iris()
iris_data_train, iris_data_test, iris_target_train, iris_target_test = \ train_test_split(iris.data, iris.target, test_size=0.5)
iris_predict = svm.SVC().fit(iris_data_train, iris_target_train).predict(iris_data_test)
cm = confusion_matrix(iris_target_test, iris_predict)

print(cm)
confusion_matrixのheatmap表示

Confusion matrix — scikit-learn 0.19.2 documentationによると、heatmap表示もできる。まずheatmapを他のデータと比較可能にするために、confusion_matrixを正規化する。

こんな式は、思いつかない。

cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

そして、confusion_matrixを表示するための関数を定義する。

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
    ''' confusion_matrixをheatmap表示する関数
    Keyword arguments:
        cm -- confusion_matrix
        title -- 図の表題
        cmap -- 使用するカラーマップ
        
    '''
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(iris.target_names))
    plt.xticks(tick_marks, iris.target_names, rotation=45)
    plt.yticks(tick_marks, iris.target_names)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

コレって、iris.target_namesとか埋め込まれているので、要注意だ。他のデータセットだったら、ちゃんとソレっぽく修正しないとダメだ。

というワケで、以下のようにして完成。

from sklearn.cross_validation import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

iris = datasets.load_iris()
iris_data_train, iris_data_test, iris_target_train, iris_target_test = train_test_split(iris.data, iris.target, test_size=0.5)

iris_predict = svm.SVC().fit(iris_data_train, iris_target_train).predict(iris_data_test)

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(iris.target_names))
        plt.xticks(tick_marks, iris.target_names, rotation=45)
        plt.yticks(tick_marks, iris.target_names)
        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

cm = confusion_matrix(iris_target_test, iris_predict)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print('Normalized confusion matrix')
print(cm_normalized)
plt.figure()
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
plt.show()

さて、irisっていうのは、アヤメなのかショウブなのかカキツバタなのか、その辺は謎のままですね。