ちょっと前に話題になってたんですが、何でもCRANに確率的勾配降下法(Stochastic Gradient Descent)を実装した{sgd}というパッケージが公開されているそうで。JSS掲載予定のVignetteもあるみたいです。
ということで、ミーハーかもですが試しにちょっと触ってみようかと思います。ちなみに今回検証に用いたマシンはWindows 7 (64bit)、8コア (3.60GHz)、メモリ32GBというスペックです。
確率的勾配降下法について
最近はMLPシリーズとか良いテキストが沢山出ているので、そちらでも読んでみて下さい*1。
- 作者: 海野裕也,岡野原大輔,得居誠也,徳永拓之
- 出版社/メーカー: 講談社
- 発売日: 2015/04/08
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (3件) を見る
pp.29-30に出ていますが、要は毎回全サンプルをメモリに入れて勾配を計算する最急降下法とは異なり、1つサンプルを読み込むごとに1回パラメータを更新して勾配を逐次計算する手法です。一応@shima__shima先生の解説が朱鷺の森にあります(確率的勾配降下法 - 機械学習の「朱鷺の杜Wiki」)。
極めて大ざっぱに言い換えると、これはオンライン学習のためのパラメータ最適化手法だということです。全サンプルをメモリに入れずに毎回1サンプルだけ読み込めば良いようにすることで、必要とするメモリ容量が非常に小さくて済むことになります。その代わり局所最適にハマりやすくなる危険性もあるわけで、その辺は今でも改良する研究が進められているようです。
{sgd}のインストールについて
で、本題の{sgd}なんですが、R 3.0.x系では動かないようですorz 結局これを機にR 3.2.x系にアップデートしました。
そしてWindowsマシンでインストールしようとすると、バイナリ版のv0.1かソース版のv1.0のどちらかを選べと言ってくるんですが、これが正しいバージョンのRtools入れてるはずなのに*2ソース版のビルドが何をやっても通らないんですよねorz
ということで以下の検証ではCRANからinstall.packagesでインストールしたバイナリ版v0.1の挙動しか見ていません。v0.1だと"glm"の選択肢が少ないとか色々物足りないところだらけなんですが、ごめんなさいということで。
小規模データセットの場合
ということでお手軽に試してみましょう。v0.1だと基本的にはlmかglmの二択なので、分かりやすくglmそしてロジスティック回帰(二項ロジット)のみに的を絞ってみます。まずは小規模データセットの場合ということでこちらを。
このconflictデータセットを使うの久しぶりですね。ということで、d1という名前で適当に読み込んでおきましょう。以下ただベタっと回すだけ。
# d1はconflictデータセットを読み込んでおいたもの > dim(d1) [1] 3000 8 > summary(d1$cv) No Yes 1500 1500 # glmの場合 > t<-proc.time() > d1.glm<-glm(cv~.,data=d1,family=binomial) > proc.time()-t ユーザ システム 経過 0.02 0.00 0.02 # sgdの場合 > library(sgd) > d1$cv<-as.integer(d1$cv)-1 # {sgd}はfactorを目的変数に選べないのでnumericに直す > t<-proc.time() > d1.sgd.glm<-sgd(cv~.,data=d1,model="glm",model.control=list(family=binomial())) > proc.time()-t ユーザ システム 経過 0.02 0.00 0.02 > summary(d1.glm) Call: glm(formula = cv ~ ., family = binomial, data = d1) Deviance Residuals: Min 1Q Median 3Q Max -3.6404 -0.2242 -0.0358 0.2162 3.1418 Coefficients: Estimate Std. Error z value Pr(>|z|) (Intercept) -1.37793 0.25979 -5.304 1.13e-07 *** a1 1.05846 0.17344 6.103 1.04e-09 *** a2 -0.54914 0.16752 -3.278 0.00105 ** a3 0.12035 0.16803 0.716 0.47386 a4 -3.00110 0.21653 -13.860 < 2e-16 *** a5 1.53098 0.17349 8.824 < 2e-16 *** a6 5.33547 0.19191 27.802 < 2e-16 *** a7 0.07811 0.16725 0.467 0.64048 --- Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 (Dispersion parameter for binomial family taken to be 1) Null deviance: 4158.9 on 2999 degrees of freedom Residual deviance: 1044.4 on 2992 degrees of freedom AIC: 1060.4 Number of Fisher Scoring iterations: 7 > d1.sgd.glm$coefficients (Intercept) a1 a2 a3 a4 a5 a6 a7 0.0153553149 0.0005093711 0.0001825324 0.0001825045 -0.0024151939 -0.0018321832 -0.0140327344 -0.4150608728 > plot(d1.sgd.glm$coefficients-d1.glm$coefficients,cex=3) > segments(0,0,9,0,col='red',lty=2,lwd=2)
小規模データセットなので計算時間に差はないですね。一方、意外と偏回帰係数の推定が変なところに行っている印象です。試しに偏回帰係数の差分をプロットしてみると、上下にバタついています。
大規模データセットの場合
次に、大規模データセットで検証してみます。こちらはさすがに皆さんにデータセットをシェアするには重過ぎるので、あくまでも回した結果のコンソール出力のみで失礼。
# d2は以前遊びで作ったことのある割と重めの二値分類不均衡データセット > dim(d2) [1] 35643 1028 > summary(d2$label) 0 1 34478 1165 # クラス重み付けベクトルを作る > idx0<-which(d2$label=="0") > idx1<-which(d2$label=="1") > weight_vec<-rep(0,length(d2)) > weight_vec[idx0]<-length(d2)/34478 > weight_vec[idx1]<-length(d2)/1165 # glmの場合 > t<-proc.time() > d2.glm<-glm(label~.,data=d2,family=binomial,weights=weight_vec) > proc.time()-t ユーザ システム 経過 584.68 2.40 587.53 # sgdの場合 > d2$label<-as.integer(d2$label)-1 > t<-proc.time() > d2.sgd.glm<-sgd(label~.,data=d2,model="glm",model.control=list(family=binomial()),sgd.control=list(weights=weight_vec)) > proc.time()-t ユーザ システム 経過 136.41 97.47 233.97 > plot(d2.sgd.glm$coefficients-d2.glm$coefficients) > segments(0,0,1050,0,col='red',lty=2,lwd=2)
計算速度自体は確かに速いです。{bigmemory}を使っているというのも大きいのでしょうが、glmに比べて3分の1ぐらいの計算時間で回り切ってます。
問題は、前のconflictデータセット同様にパラメータ推定精度。単に偏回帰係数だけを比べてみたんですが、ほぼ同じ推定値(0付近に集中)しているものと、何故か20ぐらいずれてるものとに分かれてますね。。。あと{sgd}の方がかなり小さく推定されてしまっているものも(100以上低いとか)あるようです。
ということで、もしかしたら何か局所最適踏んでるのかなぁ感も覚えているところです。もうちょっと良いデータセットがあったら踏み込んでチェックしてみようかなと。
なお今回はバイナリ版のあるv0.1でしか試していませんが、ソース版のv1.0では例えばモーメンタム法やNesterovの加速勾配法やAdaGradなどをsgd.control引数で指定することができます。またv1.0がきちんと入れられたら追記の形でレポート入れるかもしれません。。。
そうそう、Vignetteを見るとMNISTのベンチマークがあるので、開発陣は当然多項ロジットもやってるんだろうと思うんですが、まだ実装はされてないんですかね? それとも他に多クラス分類やる方法があるんでしょうか?