さて、折角Deep Learningなんて使うんだったらもうちょっと面白いデータでやってみようよ!ということで、多次元データの代表たるMNIST手書き文字データ*1を使って試してみようかと思います。
で、MNISTデータなんですが真面目に取ってこようとするとえらく大変なので、前処理済みのものをKaggleからもらってきます。
もちろんKaggle側の前処理の過程で失われている情報もあるかもしれませんが、今回はそんなところまで目くじら立てても仕方ないので放っておきます。なので必然的に1次元方向に特徴量を並べただけの素性でやることになるので、例えばConvNetsみたいなことは今回は想定しません。悪しからずご了承を。
MNIST手書き文字データについて
PRMLのAppendix Aでも紹介されている、非常に有名な機械学習分類器向けのサンプルデータセットです。
Description - Digit Recognizer | Kaggle
MNIST databaseとは米標準技術局(The National Institute of Standards and Technology: NIST)が提供する手書き数字に関する混合データベースで、米国勢調査局の従業員と米国内の高校生から収集されたものだそうです*2。
元データはあくまでも白地に黒のインクで数字を手で書いてもらっただけの二値データですが、これを画像のアスペクト比を保ちながら20×20ピクセルの枠にはめ込み、さらにこれを解像度を変更しつつアンチエイリアシング処理をかけたことによって28×28ピクセルの256階調グレースケール画像に変換したものです。よって、これを前処理することで0-255の整数値を取る784次元の特徴ベクトルと、0-9のどの文字であったかという分類ラベルから成るデータテーブルが出来上がることになります。
なお、オリジナルのMNISTデータは6万サンプルの学習データと1万サンプルのテストデータから成っており、学習データを手書きした個人とテストデータを手書きした個人とが重ならないようにきちんと振り分けられているとのこと。このようにきちんと設計されたサンプルデータセットゆえ、あちこちで機械学習分類器の性能ベンチマークとして愛用されているのは皆さんもご存知の通りでしょう*3。
ちなみに、R上で個々のMNISTサンプルを描画したいという奇特な人のためにStackOverFlowに描画コードの例が出ていますので、興味のある人はどうぞw
噂ではテストデータの文字全てを描画した上で目視で判別して手入力でラベルを打っていっている変態もいるとか何とか。。。
DL、ロジスティック回帰、SVM、ランダムフォレストの3種類で比べてみる
ということで、早速やってみようと思います。まず、Kaggleのコンペサイトからtrain.csvを落としてきます。その上で、自前でtrain.csvを32000サンプルの学習データと10000サンプルのテストデータとにランダムに割り振ります。残念ながらこの操作によって元々のMNISTデータが持っている「手書きした個々人レベルでCVが成立する」という特徴は消えてしまいますが、まぁ置いといてということで。。。
> dat<-read.csv("train.csv", header=TRUE) > labels<-dat[,1] > test_idx<-c() > for (i in 1:10) { + tmp1<-which(labels==(i-1)) + tmp2<-sample(tmp1,1000,replace=F) + test_idx<-c(test_idx,tmp2) + } > test<-dat[test_idx,] > train<-dat[-test_idx,] > write.table(train,file="prac_train.csv",quote=F,col.names=T,row.names=F,sep=",") > write.table(test,file="prac_test.csv",quote=F,col.names=T,row.names=F,sep=",")
これでデータの準備はできました。ということで、以下のように実行してみましょう。まずは普通にベンチマークとしてRandomForestで。
> library(randomForest) > prac_train$label<-as.factor(prac_train$label) > prac.rf<-randomForest(label~.,prac_train) > prd.rf<-predict(prac.rf,newdata=prac_test[,-1],type="response") > sum(diag(table(test_labels,prd.rf))) [1] 9658
あくまでもベンチマークですが、いきなり96.58%。ちなみにこれでKaggle本番にsubmitすると、大体270位ぐらいにランクインします*4。後は多項ロジットとSVMなんですが、これが結構厄介でして。まず多項ロジットは{VGAM}のvglm関数で行います。のですが。。。
> library(VGAM) > res.mlg<-vglm(label~.,multinomial,train) Error: cannot allocate vector of size 15.2 Gb In addition: Warning messages: 1: In dim(ans) <- c(R, nrow(xmat)) : Reached total allocation of 32714Mb: see help(memory.size) 2: In dim(ans) <- c(R, nrow(xmat)) : Reached total allocation of 32714Mb: see help(memory.size) 3: In dim(ans) <- c(R, nrow(xmat)) : Reached total allocation of 32714Mb: see help(memory.size) 4: In dim(ans) <- c(R, nrow(xmat)) : Reached total allocation of 32714Mb: see help(memory.size)
コケたorz 仕方ないので普通にSVMでやってみます。これは{e1071}のsvm関数でやっていきます。
> library(e1071) > res.svm<-svm(label~.,train,scale=F) > pred.svm<-predict(res.svm,newdata=test[,-1]) > sum(diag(table(test[,1],pred.svm))) [1] 1000
何と、繰り返してみたものの何故かどれか1つのクラスに固定されてしまいうまくいかず。。。ということで多項ロジットとSVMはJapan.Rまでの宿題とさせてくださいorz
さて、いよいよh2o.deeplearningの出番。こんな感じでkickできます。
> library(h2o) > localH2O <- h2o.init(ip = "localhost", port = 54321, startH2O = TRUE, nthreads=-1) > prac_train <- read.csv("prac_train.csv") > prac_test <- read.csv("prac_test.csv") > trData<-h2o.importFile(localH2O,path = "prac_train.csv") > tsData<-h2o.importFile(localH2O,path = "prac_test.csv") > res.dl <- h2o.deeplearning(x = 2:785, y = 1, data = trData, activation = "Tanh",hidden=rep(160,5),epochs = 20) > pred.dl<-h2o.predict(object=res.dl,newdata=tsData[,-1]) > pred.dl.df<-as.data.frame(pred.dl) > sum(diag(table(test_labels,pred.dl.df[,1]))) [1] 9711
とりあえずRandomForestでのベンチマークを超えることができました。実はHinton先生の2012年の論文(PDF注意)でDeep RBMでMNISTデータを扱った際の結果が紹介されていて、その中に最適パラメータ設定がいくつか例示されています。で、その通りにやるとこんな感じに。
> res.dl <- h2o.deeplearning(x = 2:785, y = 1, data = trData, activation = "Tanh",hidden=c(500,500,1000), + epochs = 20,rate=0.01,rate_annealing = 0.001) > pred.dl<-h2o.predict(object=res.dl,newdata=tsData[,-1]) > pred.dl.df<-as.data.frame(pred.dl) > sum(diag(table(test_labels,pred.dl.df[,1]))) [1] 9726
正答率97.26%までやってきました。でもこれだとConvNetsには全然敵わないんですよね。。。と思いながらMNISTの元ネタを公開しているLeCun先生のところを見たらもっと色々書いてあるし。。。
その上でHinton先生の2012年の別の論文(PDF注意)見るともっと細かく書いてあるのに気が付いて、さらにさらにこれを援用したNIPS2013の"Understanding dropout"論文(PDF注意)での設定にも気が付いたんですが、これが意外と精度が上がらず。。。特に何故かdropoutをつけると精度が逆に下がってしまったり。ちなみに内容は以下の通り。
> res.dl<-h2o.deeplearning(x=2:785, y=1, data=trData, activation = "TanhWithDropout", hidden = c(784, 500, 500, 2000), + hidden_dropout_ratios = rep(0.5, 5), input_dropout_ratio = 0.2, momentum_start = 0.5, momentum_ramp = 500, + momentum_stable = 0.99, epochs = 20, rate = 0.01, rate_annealing = 0.001)
これでは92%ぐらいまでしか行かなくてダメだったので、LeCun先生のところで引用されている「お前らautoencoderは甘え、dropoutも甘え、RBMも甘え、ただひたすら多層多ユニットNNを組め」という2010年の論文(PDF注意)に従って以下のようにしてみました。
> res.dl<-h2o.deeplearning(x=2:785, y=1, data=trData, activation="Tanh", hidden=c(2500,2000,1500,1000,500,10), epochs=20, + rate=0.001, rate_annealing=0.02, initial_weight_distribution="Normal", initial_weight_scale=0.01)
Kaggleコンペ本番データ(42000行)でやったところ、手元のデスクトップで普通に4時間半かかりました。。。しかも精度が96.543%まで下がりやがったというorz これではどうにもならんですね。。。ということで、H2Oの開発者であるArno Candelが今年7月のトークで紹介したパラメータ設定に合わせてみます*5。
> res.dl <- h2o.deeplearning(x = 2:785, y = 1, data = trData, activation = "RectifierWithDropout", + hidden=c(1024,1024,2048),epochs = 100, adaptive_rate = FALSE, rate=0.01, rate_annealing = 1.0e-6, + rate_decay = 1.0, momentum_start = 0.5,momentum_ramp = 32000*6, momentum_stable = 0.99, input_dropout_ratio = 0.2, + l1 = 1.0e-5,l2 = 0.0,max_w2 = 15.0, initial_weight_distribution = "Normal",initial_weight_scale = 0.01, + nesterov_accelerated_gradient = T, loss = "CrossEntropy", fast_mode = T, diagnostics = T, ignore_const_cols = T, + force_load_balance = T) > pred.dl<-h2o.predict(object=res.dl,newdata=tsData[,-1]) > pred.dl.df<-as.data.frame(pred.dl) > sum(diag(table(test_labels,pred.dl.df[,1]))) [1] 9811
ようやく98.11%。ちなみにこれでさらにepochs = 200, momentum_ramp = 42000*12としたところ、手元のデータでは98.16%。これをKaggleコンペ本番に回したところ、98.3%(56位:11月12日時点)に上がりました。計算時間はfast_mode = Tとしたのもあるせいか70分程度。
とまぁ、こんな感じです。Arno Candelの資料を見ると「epochs増やせばどんどん上がるよ」という由で、後は計算時間との勝負。。。と思ったら! epochs = 1000にしてやってみたところ、
マジかよorz 下がった。。。やっぱりこの手の画像ネタはConvNets使うしかないのかなぁ。。。ということで、non-ConvNetsの限界を感じたこの2週間でした。