多クラス分類(Multi-Class Classification)

ロジスティック回帰では、データを2つのクラスに分類する方法を学びました。しかし、実社会ではサンプルが3つ以上のクラスに分けられる問題も多くあります。

ここからのレクチャーでは、こうした問題に対応出来る、多クラス分類の方法を学びます。

1.) Iris(アヤメ)データの紹介
2.) ロジスティック回帰を使った多クラス分類の紹介
3.) データの準備
4.) データの可視化
5.) scikit-learnを使った多クラス分類
6.) K近傍法(K Nearest Neighbors)の紹介
7.) scikit-learnを使ったK近傍法
8.) まとめ

Step 1: Iris(アヤメ)のデータ

機械学習のサンプルデータとして非常によく使われるデータセットがあります。 それが、Iris(アヤメ)のデータ です。

このデータセットは、イギリスの統計学者ロナルド・フィッシャーによって、1936年に紹介されました。

3種類のアヤメについて、それぞれ50サンプルのデータがあります。それぞれ、Iris setosa、Iris virginica、Iris versicolorという名前がついています。全部で150のデータになっています。4つの特徴量が計測されていて、これが説明変数になります。4つのデータは、花びら(petals)と萼片(sepals)の長さと幅です。

花びら(petals)と萼片(sepals)

In [6]:
# Iris Setosa
from IPython.display import Image
url = 'http://upload.wikimedia.org/wikipedia/commons/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg'
Image(url,width=300, height=300)
Out[6]:
In [7]:
# Iris Versicolor
from IPython.display import Image
url = 'http://upload.wikimedia.org/wikipedia/commons/4/41/Iris_versicolor_3.jpg'
Image(url,width=300, height=300)
Out[7]:
In [8]:
# Iris Virginica
from IPython.display import Image
url = 'http://upload.wikimedia.org/wikipedia/commons/9/9f/Iris_virginica.jpg'
Image(url,width=300, height=300)
Out[8]:

データの概要をまとめておきましょう。

3つのクラスがあります。

Iris-setosa (n=50)
Iris-versicolor (n=50)
Iris-virginica (n=50)

説明変数は4つです。

萼片(sepal)の長さ(cm)
萼片(sepal)の幅(cm)
花びら(petal)の長さ(cm)
花びら(petal)の幅(cm)

Step 2: 多クラス分類の紹介

最も基本的な多クラス分類の考え方は、「1対その他(one vs all, one vs rest)」というものです。 複数のクラスを、「注目するクラス」と「その他のすべて」に分けて、この2クラスについて、ロジスティック回帰の手法を使います。

どのクラスに分類されるかは、回帰の結果もっとも大きな値が割り振られたクラスなります。

後半では、K近傍法という別の方法を紹介します。

In [3]:
# 英語になりますが、Andrew Ng先生の動画は、イメージを掴むのによいかもしれません。
from IPython.display import YouTubeVideo
YouTubeVideo("Zj403m-fjqg")
Out[3]:

Step 3: データの準備

In [6]:
import numpy as np
import pandas as pd
from pandas import Series,DataFrame

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

%matplotlib inline

サンプルデータを読み込みます。scikit-learnに付属しています。

In [7]:
from sklearn import linear_model
from sklearn.datasets import load_iris

# データの読み込み
iris = load_iris()

# 説明変数をXに
X = iris.data

#目的変数をYに
Y = iris.target

# 説明文です。
print(iris.DESCR)
Iris Plants Database

Notes
-----
Data Set Characteristics:
    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
    :Summary Statistics:
    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20  0.76     0.9565  (high!)
    ============== ==== ==== ======= ===== ====================
    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

This is a copy of UCI ML iris datasets.
http://archive.ics.uci.edu/ml/datasets/Iris

The famous Iris database, first used by Sir R.A Fisher

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

References
----------
   - Fisher,R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...

In [8]:
X
Out[8]:
array([[ 5.1,  3.5,  1.4,  0.2],
       [ 4.9,  3. ,  1.4,  0.2],
       [ 4.7,  3.2,  1.3,  0.2],
       [ 4.6,  3.1,  1.5,  0.2],
       [ 5. ,  3.6,  1.4,  0.2],
       [ 5.4,  3.9,  1.7,  0.4],
       [ 4.6,  3.4,  1.4,  0.3],
       [ 5. ,  3.4,  1.5,  0.2],
       [ 4.4,  2.9,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5.4,  3.7,  1.5,  0.2],
       [ 4.8,  3.4,  1.6,  0.2],
       [ 4.8,  3. ,  1.4,  0.1],
       [ 4.3,  3. ,  1.1,  0.1],
       [ 5.8,  4. ,  1.2,  0.2],
       [ 5.7,  4.4,  1.5,  0.4],
       [ 5.4,  3.9,  1.3,  0.4],
       [ 5.1,  3.5,  1.4,  0.3],
       [ 5.7,  3.8,  1.7,  0.3],
       [ 5.1,  3.8,  1.5,  0.3],
       [ 5.4,  3.4,  1.7,  0.2],
       [ 5.1,  3.7,  1.5,  0.4],
       [ 4.6,  3.6,  1. ,  0.2],
       [ 5.1,  3.3,  1.7,  0.5],
       [ 4.8,  3.4,  1.9,  0.2],
       [ 5. ,  3. ,  1.6,  0.2],
       [ 5. ,  3.4,  1.6,  0.4],
       [ 5.2,  3.5,  1.5,  0.2],
       [ 5.2,  3.4,  1.4,  0.2],
       [ 4.7,  3.2,  1.6,  0.2],
       [ 4.8,  3.1,  1.6,  0.2],
       [ 5.4,  3.4,  1.5,  0.4],
       [ 5.2,  4.1,  1.5,  0.1],
       [ 5.5,  4.2,  1.4,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 5. ,  3.2,  1.2,  0.2],
       [ 5.5,  3.5,  1.3,  0.2],
       [ 4.9,  3.1,  1.5,  0.1],
       [ 4.4,  3. ,  1.3,  0.2],
       [ 5.1,  3.4,  1.5,  0.2],
       [ 5. ,  3.5,  1.3,  0.3],
       [ 4.5,  2.3,  1.3,  0.3],
       [ 4.4,  3.2,  1.3,  0.2],
       [ 5. ,  3.5,  1.6,  0.6],
       [ 5.1,  3.8,  1.9,  0.4],
       [ 4.8,  3. ,  1.4,  0.3],
       [ 5.1,  3.8,  1.6,  0.2],
       [ 4.6,  3.2,  1.4,  0.2],
       [ 5.3,  3.7,  1.5,  0.2],
       [ 5. ,  3.3,  1.4,  0.2],
       [ 7. ,  3.2,  4.7,  1.4],
       [ 6.4,  3.2,  4.5,  1.5],
       [ 6.9,  3.1,  4.9,  1.5],
       [ 5.5,  2.3,  4. ,  1.3],
       [ 6.5,  2.8,  4.6,  1.5],
       [ 5.7,  2.8,  4.5,  1.3],
       [ 6.3,  3.3,  4.7,  1.6],
       [ 4.9,  2.4,  3.3,  1. ],
       [ 6.6,  2.9,  4.6,  1.3],
       [ 5.2,  2.7,  3.9,  1.4],
       [ 5. ,  2. ,  3.5,  1. ],
       [ 5.9,  3. ,  4.2,  1.5],
       [ 6. ,  2.2,  4. ,  1. ],
       [ 6.1,  2.9,  4.7,  1.4],
       [ 5.6,  2.9,  3.6,  1.3],
       [ 6.7,  3.1,  4.4,  1.4],
       [ 5.6,  3. ,  4.5,  1.5],
       [ 5.8,  2.7,  4.1,  1. ],
       [ 6.2,  2.2,  4.5,  1.5],
       [ 5.6,  2.5,  3.9,  1.1],
       [ 5.9,  3.2,  4.8,  1.8],
       [ 6.1,  2.8,  4. ,  1.3],
       [ 6.3,  2.5,  4.9,  1.5],
       [ 6.1,  2.8,  4.7,  1.2],
       [ 6.4,  2.9,  4.3,  1.3],
       [ 6.6,  3. ,  4.4,  1.4],
       [ 6.8,  2.8,  4.8,  1.4],
       [ 6.7,  3. ,  5. ,  1.7],
       [ 6. ,  2.9,  4.5,  1.5],
       [ 5.7,  2.6,  3.5,  1. ],
       [ 5.5,  2.4,  3.8,  1.1],
       [ 5.5,  2.4,  3.7,  1. ],
       [ 5.8,  2.7,  3.9,  1.2],
       [ 6. ,  2.7,  5.1,  1.6],
       [ 5.4,  3. ,  4.5,  1.5],
       [ 6. ,  3.4,  4.5,  1.6],
       [ 6.7,  3.1,  4.7,  1.5],
       [ 6.3,  2.3,  4.4,  1.3],
       [ 5.6,  3. ,  4.1,  1.3],
       [ 5.5,  2.5,  4. ,  1.3],
       [ 5.5,  2.6,  4.4,  1.2],
       [ 6.1,  3. ,  4.6,  1.4],
       [ 5.8,  2.6,  4. ,  1.2],
       [ 5. ,  2.3,  3.3,  1. ],
       [ 5.6,  2.7,  4.2,  1.3],
       [ 5.7,  3. ,  4.2,  1.2],
       [ 5.7,  2.9,  4.2,  1.3],
       [ 6.2,  2.9,  4.3,  1.3],
       [ 5.1,  2.5,  3. ,  1.1],
       [ 5.7,  2.8,  4.1,  1.3],
       [ 6.3,  3.3,  6. ,  2.5],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 7.1,  3. ,  5.9,  2.1],
       [ 6.3,  2.9,  5.6,  1.8],
       [ 6.5,  3. ,  5.8,  2.2],
       [ 7.6,  3. ,  6.6,  2.1],
       [ 4.9,  2.5,  4.5,  1.7],
       [ 7.3,  2.9,  6.3,  1.8],
       [ 6.7,  2.5,  5.8,  1.8],
       [ 7.2,  3.6,  6.1,  2.5],
       [ 6.5,  3.2,  5.1,  2. ],
       [ 6.4,  2.7,  5.3,  1.9],
       [ 6.8,  3. ,  5.5,  2.1],
       [ 5.7,  2.5,  5. ,  2. ],
       [ 5.8,  2.8,  5.1,  2.4],
       [ 6.4,  3.2,  5.3,  2.3],
       [ 6.5,  3. ,  5.5,  1.8],
       [ 7.7,  3.8,  6.7,  2.2],
       [ 7.7,  2.6,  6.9,  2.3],
       [ 6. ,  2.2,  5. ,  1.5],
       [ 6.9,  3.2,  5.7,  2.3],
       [ 5.6,  2.8,  4.9,  2. ],
       [ 7.7,  2.8,  6.7,  2. ],
       [ 6.3,  2.7,  4.9,  1.8],
       [ 6.7,  3.3,  5.7,  2.1],
       [ 7.2,  3.2,  6. ,  1.8],
       [ 6.2,  2.8,  4.8,  1.8],
       [ 6.1,  3. ,  4.9,  1.8],
       [ 6.4,  2.8,  5.6,  2.1],
       [ 7.2,  3. ,  5.8,  1.6],
       [ 7.4,  2.8,  6.1,  1.9],
       [ 7.9,  3.8,  6.4,  2. ],
       [ 6.4,  2.8,  5.6,  2.2],
       [ 6.3,  2.8,  5.1,  1.5],
       [ 6.1,  2.6,  5.6,  1.4],
       [ 7.7,  3. ,  6.1,  2.3],
       [ 6.3,  3.4,  5.6,  2.4],
       [ 6.4,  3.1,  5.5,  1.8],
       [ 6. ,  3. ,  4.8,  1.8],
       [ 6.9,  3.1,  5.4,  2.1],
       [ 6.7,  3.1,  5.6,  2.4],
       [ 6.9,  3.1,  5.1,  2.3],
       [ 5.8,  2.7,  5.1,  1.9],
       [ 6.8,  3.2,  5.9,  2.3],
       [ 6.7,  3.3,  5.7,  2.5],
       [ 6.7,  3. ,  5.2,  2.3],
       [ 6.3,  2.5,  5. ,  1.9],
       [ 6.5,  3. ,  5.2,  2. ],
       [ 6.2,  3.4,  5.4,  2.3],
       [ 5.9,  3. ,  5.1,  1.8]])
In [9]:
Y
Out[9]:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

pandas.DataFrameにしておきましょう。

In [10]:
iris_data = DataFrame(X,columns=['Sepal Length','Sepal Width','Petal Length','Petal Width'])

iris_target = DataFrame(Y,columns=['Species'])

クラスが0,1,2の数字なので、文字列の名前を付けておきましょう。

In [11]:
def flower(num):
    ''' 数字を受け取って、対応する名前を返します。'''
    if num == 0:
        return 'Setosa'
    elif num == 1:
        return 'Veriscolour'
    else:
        return 'Virginica'

iris_target['Species'] = iris_target['Species'].apply(flower)
In [12]:
iris_target.head()
Out[12]:
Species
0 Setosa
1 Setosa
2 Setosa
3 Setosa
4 Setosa
In [13]:
# まとめておきましょう。
iris = pd.concat([iris_data,iris_target],axis=1)

iris.head()
Out[13]:
Sepal Length Sepal Width Petal Length Petal Width Species
0 5.1 3.5 1.4 0.2 Setosa
1 4.9 3.0 1.4 0.2 Setosa
2 4.7 3.2 1.3 0.2 Setosa
3 4.6 3.1 1.5 0.2 Setosa
4 5.0 3.6 1.4 0.2 Setosa

Step 4: データの可視化

pairplotを使えば、簡単に全体像を把握できます。

In [14]:
sns.pairplot(iris,hue='Species',size=2)
Out[14]:
<seaborn.axisgrid.PairGrid at 0x10bf1b160>

全体像がよくわかります。

特徴量でアヤメの種類を予測できそうです。特に、Setosaは最も特徴的な花のようです。

次に、花びらの長さに注目して、ヒストグラムを描いてみましょう。

In [15]:
plt.figure(figsize=(12,4))
sns.countplot('Petal Length',data=iris,hue='Species')
Out[15]:
<matplotlib.axes._subplots.AxesSubplot at 0x10d83bb38>

その他の特徴量についても、可視化してみてください。

1対その他の方法論で、ロジスティック回帰を使った多クラス分類の挑戦してみましょう。


Step 5: scikit-learnを使った多クラス分類

すでに説明変数Xと、目的変数Yが用意されているので、これを使って解析を進めて行きます。

データを学習用とテストように分けておきましょう。全体の40%がテストデータになるようにします。

In [16]:
from sklearn.linear_model import LogisticRegression
from sklearn.cross_validation import train_test_split

logreg = LogisticRegression()

# データを分割します。テストが全体の40%になるようにします。
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4,random_state=3)

# データを使って学習します。
logreg.fit(X_train, Y_train)
Out[16]:
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr',
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0)

テストデータを使って、モデルの精度を確認してみましょう。

In [17]:
# 精度を計算するのに便利なツールです。
from sklearn import metrics

# テストデータを予測します。
Y_pred = logreg.predict(X_test)

# 精度を計算してみましょう。
print(metrics.accuracy_score(Y_test,Y_pred))
0.933333333333

93%と高い精度が得られました。random_stateを指定すれば、再現性がある結果を得ることができます。

次に、K近傍法に進んで行きましょう。

Step 6: K近傍法

K近傍法は英語で、k-nearest neighborなので、kNNと略されることもありますが、極めてシンプルな方法論です。

学習のプロセスは、単純に学習データを保持するだけです。新しいサンプルが、どちらのクラスに属するかを予測するときにだけ、すこし計算をします。

与えられたサンプルのk個の隣接する学習データのクラスを使って、このサンプルのクラスを予測します。 イメージをうまく説明した図がこちら。

In [18]:
Image('http://bdewilde.github.io/assets/images/2012-10-26-knn-concept.png',width=400, height=300)
Out[18]:

★が新しいサンプルです。これを中心に、既存のサンプルのクラスを見ていきます。K=3ではAが1つ、Bが2つなので、分類されるクラスは、Bです。K=6とすると、A4つ、B2つなので、Aと判別されます。

Kの選び方によっては、同数になってしまうことがあるので注意が必要です。(アルゴリズムの中で、これを解決する方法論が実装されていることが多いです。)

Step 7: scikit-learnを使ったkNN

Irisデータを使って、実際のPythonコードを見ていきましょう。

In [19]:
# K近傍法
from sklearn.neighbors import KNeighborsClassifier

# k=6からはじめてみます。
# インスタンスを作ります。
knn = KNeighborsClassifier(n_neighbors = 6)

# 学習します。
knn.fit(X_train,Y_train)

# テストデータを予測します。
Y_pred = knn.predict(X_test)

# 精度を調べてみましょう。
print(metrics.accuracy_score(Y_test,Y_pred))
0.95

95%の精度が得られました。k=1にするとどうなるでしょうか?もっとも近いサンプルのクラスを予測値とする方法です。

In [20]:
# 今度は、1です。
knn = KNeighborsClassifier(n_neighbors = 1)

knn.fit(X_train,Y_train)

Y_pred = knn.predict(X_test)

print(metrics.accuracy_score(Y_test,Y_pred))
0.966666666667

kを変化させるとどうなるでしょうか?

In [21]:
# kを変化させてみましょう。
k_range = range(1, 90)

accuracy = []

# 先ほどの計算を繰り返して見ましょう。
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, Y_train)
    Y_pred = knn.predict(X_test)
    accuracy.append(metrics.accuracy_score(Y_test, Y_pred))

結果をプロットします。

In [23]:
plt.plot(k_range, accuracy)
plt.xlabel('K for kNN')
plt.ylabel('Testing Accuracy')
Out[23]:
<matplotlib.text.Text at 0x10e5a8f98>

学習用のデータとテスト用のデータを分けるやり方を変えると、これらの結果がどうなるか、検討してみるのも面白いかもしれません。

Step 8: まとめ

ロジスティック回帰とk近傍法を使った多クラス分類について学びました。

英語になりますが、参考資料をいくつかあげておきます。

1.) Wikipedia on Multiclass Classification

2.) MIT Lecture Slides on MultiClass Classification

3.) Sci Kit Learn Documentation

4.) DataRobot on Classification Techniques