読者です 読者をやめる 読者になる 読者になる

六本木で働くデータサイエンティストのブログ

元祖「銀座で働くデータサイエンティスト」です / 道玄坂→銀座→東京→六本木

パッケージユーザーのための機械学習(1):決定木

R Python 機械学習 サンプルデータで試す機械学習シリーズ

(※はてなフォトライフの不具合で正しくない順番で画像が表示されている可能性があります)


だいぶ前に「糞コードで頑張る機械学習シリーズ」と言うのを始めようとしたんですが、パーセプトロンPythonで実装した次にMatlabで書いたSMO-SVMコードをPythonに移植しようと思っているうちに時間が過ぎ。。。


あまつさえ転職したら、今の現場にはライブラリ皆無でほぼ全ての機械学習のコードをPython / Java / C++のどれでも書ける化け物^H^H「教授」がいてそんなこと僕がやる必要性は完全になくなってしまったのでした(笑)。


ということで、カテゴリ名はそのまま*1ながら方向性を変えて、僕のようなパッケージやライブラリに依存するユーザーが機械学習を実践する際に原理上のどのような点に気を付けて実装・実践すべきかを、僕自身の備忘録のためにだらだらと書いていくシリーズにしてみようと思いますー。厳密な説明は全て参考文献に丸投げするというやる気のなさ全開で行くので、悪しからず。


基本Rでやります。ただし、RだとやりにくいとかどうしてもPythonでやりたいというケースがもしあればPythonでもやるつもりです。なので両方のタグを打っておきますが、もしかしたらRだけで終わるかもなのでそこのところご了承あれ。


そうそう、Pythonコード頑張って書くシリーズは地味~~~~~に継続中です*2。進捗があったらまた書きます。ま、これまた備忘録なので適当です(笑)。


参考文献は基本以下の4点の組み合わせです。どんなテーマでも、どれかを読めば厳密な説明に行き当たるものと勝手に期待してます(笑)。一応、脇に全部置きながら書いていくつもりです。


はじめてのパターン認識

はじめてのパターン認識

通称「はじパタ」「ピンクの薄い本」。勉強会が大人気になるほどの定番テキストです*3。2010年代の現在使われている機械学習アルゴリズムの大半をカバーしています。


イラストで学ぶ 機械学習 最小二乗法による識別モデル学習を中心に (KS情報科学専門書)

イラストで学ぶ 機械学習 最小二乗法による識別モデル学習を中心に (KS情報科学専門書)

通称「イラスト機械学習」「青本」「杉山本」。説明のシンプルさでは一番なんですが、コード例がMatlabだけなのでライセンス持ってる僕は嬉しい普通の人には辛いかも*4


パターン認識と機械学習 上

パターン認識と機械学習 上

パターン認識と機械学習 下 (ベイズ理論による統計的予測)

パターン認識と機械学習 下 (ベイズ理論による統計的予測)


通称「PRML」「黄色い本」。言わずと知れた機械学習テキストの最高峰。豪華翻訳陣によって編まれていて、かなりマニアックなところまで網羅しているので辞書代わりに常備しておきたい2冊。ただし何もかもがベイズ理論ベースなので、ベイズ理論に対する本質的理解が要求される点に要注意。


ということで、ゆるーく始めてみようと思います。1回目は何でやねん?と思われそうですが、「決定木」からいきます。


まずRでどんなものかを見てみる


何でこんなデータセットでやるんじゃお前はーと怒られそうですが、適当に自前で作ってみました。プロットしてみれば分かりますが、これはいわゆるXOR(線形分離不可能)パターンです*5


今回使うRパッケージは定番の{mvpart}です。なおCRANには他にも{tree}や{C50}などの決定木向けパッケージがあり、それぞれに特徴があるので試してみるのも良いでしょう。


ちなみにデータ自体はGitHubのこちらに置いてあります。d1という名前でインポートしてください。

# シンプルなケース
> require("mvpart")
Loading required package: mvpart

> d1.rp<-rpart(label~.,d1)
# rpart()で決定木。目的変数と説明変数の組み合わせはformula式で与えられる
> plot(d1.rp,uniform=T,margin=0.2)
> text(d1.rp,uniform=T,use.n=T,all=F)
# 最新バージョンではtext()にもuniform引数が必要なので要注意


f:id:TJO:20131120155606p:plain


こんな感じで、サックリプロットできると思います。この「分岐条件が明示されて枝分かれしていった先に分類結果が表示される」スタイルが分かりやすいということで、決定木は色々な現場で愛用されているようです。


ちなみに前の現場では、事業サイドの人向けにRベースでデータだけ突っ込むと勝手に決定木が表示されるようなシステムを作ったり、何もわからなくてもSPSSでデータを突っ込めば決定木が描けるような簡易マニュアルが作成されたりしていたものでした。


そもそも決定木って何?


誤解を恐れずに超絶簡略化して説明すると「教師データをもとにある数理的基準にのっとってIF-ELIF-ELIF...-ELSE THEN文の形で表せるように分岐を作って分類するモデル」です。IF分岐だらけになるため、他の機械学習と違ってカクカクと不連続なパターン分けが階層的に作られていくので、これを可視化すると「樹木(ツリー)」状になります。「樹木モデル」と呼ばれる所以ですね。


決定木のアルゴリズムについては、例えば「はじパタ」pp.176-183あたりを読んでもらうのが手っ取り早いですが、直観的な説明は例えば『Rによるデータサイエンス - データ解析の基礎から最新手法まで』pp.231-233を読んだ方が分かりやすいと思います。


ポイントになるのは「不純度」(purity)という概念です。「個々の分岐に来るごとにできるだけその1回で『集団としてのバラバラさ=不純度』が大幅に少なくなるようにする」ことで、分岐が進めば進むほどサンプルが純度が高く綺麗に分類されていくように、木の枝を茂らせていく(条件を変えながら分岐を増やしていく)というのがそのアルゴリズムの基本です。


つまり、ある特徴量の軸に沿ってソートした上*6「不純度の減少幅が最も大きくなる」ような説明変数の条件を選び出して、その条件に従ってサンプルを2つ(もしくはそれ以上)に分けていく。そして分かれた先で、それをまた繰り返す。これを繰り返しながら、全ての分岐ノードが打ち切り基準に達するまで繰り返す。。。これが決定木の具体的なアルゴリズムです。ちなみに不純度の指標として、以下のものが知られています。rpart(){mvpart}はデフォルトではジニ係数を選ぶようになっています。

分類クラス数をKとし、分岐ノードtでi番目の分類クラスのデータに選ばれる確率をP(C_i|t)として、


ジニ係数(Gini index)
GI(t)={\displaystyle \sum_{j \neq i}} P(C_i|t)P(C_j|t)={\displaystyle \sum^K_{i=1}} P(C_i|t)(1-P(C_i|t))
=1-{\displaystyle \sum^K_{i=1}} P^2(C_i|t)


交差エントロピー
E(t)=-{\displaystyle \sum^K_{i=1}} P(C_i|t)logP(C_i|t)


以上のポイントを絵にしてみると、こんな感じになります。これはまさに決定木ですよね?


f:id:TJO:20131121163722p:plain


樹木モデルは必ず終点を決める必要があり、例えば「分岐後に残ったサンプル数が1になったとき」とか色々あります。Rで使われている関数でもこの辺は明示的に引数で指定することができます。


ちなみに、分類に使えば「決定木」ですが回帰に使えば「回帰木」と呼ばれます。演算ルールは殆ど変わりません。


決定境界を描いてみる


ところで僕個人としては、機械学習のアルゴリズムや原理的な側面を最も手っ取り早く理解するには、2次元データに対して決定境界(Decision boundary)もしくは分離超平面(Hyperplane)を描くのが良いと思ってます。


色々なアルゴリズムで試してみると分かりますが、何だかんだで分離超平面こそがそれぞれのアルゴリズムの特徴をよく表している気がするんですよね。


で、決定木の決定境界を描くだけなら{mvpart}ではなく{tree}に入っているpartition.tree()を使った方が手っ取り早いです。ただ、それでは汎用性に乏しいのでちょっと色々手を入れて上記の例でも使ったrpart(){mvpart]でやってみることにします。


決定木は、その条件分岐が不等式やカテゴリ変数選択の形で表されることからも分かるように、決定境界は原則として変数の軸に対して平行(他の軸に対して垂直)になるように描画されます。

# シンプルなケース

> d1.rp<-rpart(label~.,d1)

> plot(d1[1:50,-3],col="blue",pch=19,cex=3,xlim=c(-3,3),ylim=c(-3,3))
> points(d1[51:100,-3],col="red",pch=19,cex=3)
# 一旦プロットしてみた

> px<-seq(-3,3,0.03)
> py<-seq(-3,3,0.03)
> pgrid<-expand.grid(px,py)
> names(pgrid)<-c("x","y")
# コンタープロットを使うので、メッシュグリッドを切る

> out1<-predict(d1.rp,pgrid,type="vector") # 予測する=コンターのデータを作る

> par(new=T)
> contour(px,py,array(out1,dim=c(length(px),length(py))),xlim=c(-3,3),ylim=c(-3,3),col="purple",lwd=3,drawlabels=F)
# コンタープロットで決定境界を描く

f:id:TJO:20131121121239p:plain

f:id:TJO:20131121121249p:plain


シンプルなXORパターンを、くの字型の決定境界が綺麗に切り分けているのがよく分かると思います。ところで、決定木はその性質上「めちゃくちゃ入り組んだ切り分け方」もできることが知られているので*7、試しに入り組んだXORパターンでも試してみましょう。


データ自体はGitHubのこちらに置いてあります。d2という名前でインポートしてください。

# 複雑なケース

> d2.rp<-rpart(label~.,d2) # まず決定木モデルを算出

> plot(d2[1:50,-3],col="blue",pch=19,cex=3,xlim=c(-3,3),ylim=c(-3,3))
> points(d2[51:100,-3],col="red",pch=19,cex=3)
# 一旦プロットしておく

> out2<-predict(d2.rp,pgrid,type="vector")
# 予測する=コンターのデータを作る

> par(new=T)
> contour(px,py,array(out2,dim=c(length(px),length(py))),xlim=c(-3,3),ylim=c(-3,3),col="purple",lwd=3,drawlabels=F)
# コンタープロットで決定境界を描く

f:id:TJO:20131121121257p:plain

f:id:TJO:20131121121304p:plain


すげー、めっちゃくちゃです(笑)。でも一応何とか分離できているのが分かるんじゃないかと思います。


注意点


実データでやってみれば分かると思うんですが、時々全く枝が茂らないケースがあるんですね。これには色々な理由が考えられます。


例えば、直感的には決定境界が描かれる様子から想像できる通り「単に軸に平行な決定境界では分けられないデータだから」ということもあったりします。


またXOR(複雑なケース)パターンの分類例からピンと来るかと思いますが、「どのように分岐条件を決めても不純度が全然下がらない」場合も然りです。


これらのケースでは決定木を使ってもうまくいかないので、異なる機械学習分類器を使う必要があります。と言うか、特に機械学習分類器のバラエティ豊かな現在では、あくまでもデータが分類されていく様子を直接確認する目的に限って使われるのが普通なんじゃないかと思ってます。


樹木の剪定(汎化)


ところで、樹木モデルを際限なく茂らせていくとオーバーフィッティングが気になってきますよね。{mvpart}ではcross validationでこれを防ぐためのやり方が備わっていて、簡単に「剪定」*8することができます。


一旦簡単な例で説明したいので、お決まりのirisデータに登場してもらいます。

> data(iris)
> iris.rp<-rpart(Species~.,iris)
> plot(iris.rp,uniform=T,margin=0.2)
> text(iris.rp,uniform=T,use.n=T,all=F)


f:id:TJO:20131120155655p:plain


そんなに悪そうなツリー構造ではないんですが、試しにcross validationの結果をチェックしてみましょう。{mvpart}ではplotcp()でこれを簡単にチェックすることができます。

> plotcp(iris.rp)


f:id:TJO:20131120155703p:plain


オレンジで表示されているところが「このcp (complexity parameter)値で剪定するといいよ!」のサインです(xerrorの最小値を中心としたその標準偏差1倍の範囲内で最大のxerror値を与える木のサイズを選ぶそうです)。なので、こうして木を推定し直します。

iris.rp2<-rpart(Species~.,iris,cp=0.094)
plot(iris.rp2,uniform=T,margin=0.2)
text(iris.rp2,uniform=T,use.n=T,all=F)


f:id:TJO:20131120155711p:plain


こんな感じで、木の複雑さがちょっと減りました。こんな感じで木の剪定(≒汎化)を行うことができます。なお、同じことを上記のXOR(複雑)パターンに当てはめてやろうとすると、


f:id:TJO:20131121171826p:plain


となって剪定に適したサイズが決まらないのです(それだけデータがぐちゃぐちゃということ笑)。試しにcp=0.059で剪定してやると、


f:id:TJO:20131121171847p:plain


それっぽくなったけど、オーバーフィッティングしなくなった代わりに分類エラーがどう見ても増えてる罠。ま、こういうのはSVMとかがやるべきもの*9だということで、お後がよろしいようで。


次回はロジスティック回帰か、ニューラルネットワークか、SVMのどれかにしようと思ってます。ナイーブベイズとかやるのかなぁ? ランダムフォレストは必ずやるつもりですが。。。


最後に


いつも通り炎上ラーニング(笑)を目論んでおりますので、おかしなところ・間違っているところがあればバシバシ容赦なくツッコミ入れてください! お待ちしておりますー。


おまけ:データセットの中身


こんな感じで生成してます。何の工夫もないデータですが、ご参考までに。

# シンプルなXORパターン
> p11<-cbind(rnorm(n=25,mean=1,sd=0.5),rnorm(n=25,mean=1,sd=0.5))
> p12<-cbind(rnorm(n=25,mean=-1,sd=0.5),rnorm(n=25,mean=1,sd=0.5))
> p13<-cbind(rnorm(n=25,mean=-1,sd=0.5),rnorm(n=25,mean=-1,sd=0.5))
> p14<-cbind(rnorm(n=25,mean=1,sd=0.5),rnorm(n=25,mean=-1,sd=0.5))
> t<-as.factor(c(rep(0,50),rep(1,50)))
> d1<-as.data.frame(cbind(rbind(p11,p13,p12,p14),t))
> names(d1)<-c("x","y","label")

# 複雑なXORパターン
> p21<-cbind(rnorm(n=25,mean=1,sd=1),rnorm(n=25,mean=1,sd=1))
> p22<-cbind(rnorm(n=25,mean=-1,sd=1),rnorm(n=25,mean=1,sd=1))
> p23<-cbind(rnorm(n=25,mean=-1,sd=1),rnorm(n=25,mean=-1,sd=1))
> p24<-cbind(rnorm(n=25,mean=1,sd=1),rnorm(n=25,mean=-1,sd=1))
> t<-as.factor(c(rep(0,50),rep(1,50)))
> d2<-as.data.frame(cbind(rbind(p21,p23,p22,p24),t))
> names(d2)<-c("x","y","label")

*1:おいおいまだこのカテゴリ名使うのかよ

*2:SMO-SVMコードの移植が終わったら進みます。。。優先順位最低なので当分進まない予定ですが

*3:毎回誘われてるのに行ってませんごめんなさい

*4:誰かPythonコード例ボランティアで作ってあげると喜ばれるかも

*5:伝統的にはニューラルネットワークパーセプトロンに対する優位性を示すために用いられてきたデータ例

*6:杉山本p.95参照

*7:際限なく条件分岐を増やせば良い

*8:まさに読んで字のごとく、枝を切り落とすということ

*9:汎化性能に優れた低バリアンスモデルなので