渋谷駅前で働くデータサイエンティストのブログ

元祖「六本木で働くデータサイエンティスト」です / 道玄坂→銀座→東京→六本木→渋谷駅前

Stanで統計モデリングを学ぶ(1): まずはStanの使い方のおさらいから

(※Stan v2.4.0以降でインストール方法に若干変更があります!詳しくはこの記事の中ほどをご覧ください)


さて、年初の抱負でも語ったように今年はStanを頑張って会得していこうと思います。理由は簡単で、ありったけの要素を詰め込んでMCMCサンプラーでガンガン推定していくような階層ベイズモデリングに自分の興味としても惹かれる上に、実務でも必要になりそうな見通し*1だからです。


既に以前の記事でも簡単に触れてますが、StanはC++ベースのコンパイラで高速化させたMCMCサンプラーです。文法も簡単でなおかつ高速なので、BUGSでは時間がかかり過ぎて辛かった計算でも比較的サクサク回せます。


このシリーズを通して参考にするのは、@さんのブログです。


本当にもう、かゆいところに手が届くくらい完璧な解説の数々でいつも大変お世話になっております。と言うか、ぶっちゃけこちらのブログをただトレースしていくだけになるんじゃないかという気すらするんですが(笑)、その分こちらでは僕のグダグダな試行錯誤の様子を記録していけば良いかなぁ、という。。。


それから参考図書を挙げておきます。シリーズの後の方でどんどん追加していくかもしれませんが、ひとまず今持っている分だけ。


入口としてはやっぱり久保先生の緑本を挙げないわけにはいかないでしょう。分かりやすさ・とっつきやすさでは他の追随を許しません*2。ただ、BUGSのみでの実装なのでStanコードにする際はその辺を考慮する必要があります。


予測にいかす統計モデリングの基本―ベイズ統計入門から応用まで (KS理工学専門書)

予測にいかす統計モデリングの基本―ベイズ統計入門から応用まで (KS理工学専門書)

ビッグデータ時代のマーケティング―ベイジアンモデリングの活用 (KS社会科学専門書)

ビッグデータ時代のマーケティング―ベイジアンモデリングの活用 (KS社会科学専門書)

統数研所長の樋口先生が著者として関わっていらっしゃるシリーズ。実装面の話はほとんど出てきませんが、その分理論面での説明がコンパクトで分かりやすい上に、「直観的にどう理解すべきか」というポイント*3についても(学術書としては珍しく)触れられていて大変助かります。


The BUGS Book: A Practical Introduction to Bayesian Analysis (Chapman & Hall/CRC Texts in Statistical Science)

The BUGS Book: A Practical Introduction to Bayesian Analysis (Chapman & Hall/CRC Texts in Statistical Science)

BUGS/Stan使いなら必携の一冊。洋書なのでそれなりに読むのは手間ですが、BUGSコーディングをする人向けに実践的な内容のみから構成された良書なので絶対に読む価値があります。事前分布の定め方や、階層ベイズモデルにする際の設定の仕方なども細かく載っています*4。これもBUGSでしかコード例が出ていませんが、Stanであっても参考になるはずです。


ちなみにMCMCの原理的な側面に踏み込むと、頭の回転が遅い上に数学が大の苦手な僕にとっては単に魔境に迷い込むだけになってしまうのでMCMCの原理についての本はほとんど持ってないんですが、唯一持っているのがこちら。

Rによるモンテカルロ法入門

Rによるモンテカルロ法入門

厳密な話*5はある程度割愛した上で、MCMCというアプローチの裏側で何が行われているかを比較的分かりやすく説明してくれている本です。ベースとなっている普通のモンテカルロ法の説明から入っているので、MCMC以外でモンテカルロ法を使うケースでも参考になると思います。


とりあえずこれまでのおさらい


まずはどうやってStanをインストールして動かすかというお話から。このシリーズ記事では原則としてR上からStanを使っていくので、必然的に{rstan}ベースで話を進めていきます。PyStanとかでやりたい方はすみませんが他を当たってください、ということで。


そのStanのインストールなんですが、どうやらTwitter界隈を見ていると「Stanインストールあるある」みたいなハマりパターンがどうもあるっぽいので*6、その辺も踏まえて書いておきます。なお、参考までに僕の環境を書いておくと以下の通りです。

  • Windows 7 64bit
  • R 3.0.2
  • Stan 2.0.1(下記初回インストール時:現在は2.1.0)


実は別にMac Book Airにも入れているのでそこで詰まった時に色々ごにょごにょトラブルシューティングもしてるんですが、何だかんだで僕はマカーではなく間違っていた場合の責任も取りづらいのでやめときます*7


全体の流れ


詳細はRStan Getting Startedを見てもらいたいんですが、要はこれだけです。

  1. {Rcpp}, {inline}をインストールする
  2. Rtoolsをインストールする
  3. {rstan}をリポジトリURLを直に指定してインストールする


これだけ見ると単純そうですが(ってか実際シンプル)、全ステップにコケる要因が散りばめられているので注意が必要です。初めての人は深刻にハマることがあるかもしれないので、ステップを進むごとにその都度色々確認した方が良いかと思います。

{Rcpp}, {inline}をインストールする

install.packages('inline')
install.packages('Rcpp')


C++絡みのCRANパッケージを入れておきます。何も考えずに普通に上の通りにやれば入るはずです。この時点では特に何も依存関係などはないので、トラブルは生じないと思います。


Rtoolsをインストールする


問題はここから。これは以前の記事でも書いてますが、しつこいながらも今一度公式ヘルプを確認しておきましょう。


Rtoolsを入れる理由は簡単で、Rからアクセス可能なC++コンパイラを入れておきたいという、ただそれだけです。なのでMacLinuxでは別のやり方もありますが、WindowsではRtoolsを入れるのが第一選択ということになってます。ところが、上記の公式ヘルプ通りにやってもうまくいかないケースが結構あります。一番多いのが、RtoolsをインストールしたにもかかわらずC++コンパイルが通らずエラーになるというもの。これは基本的には

  1. RStanはgcc-4.6.3で動くことを前提としている
  2. なので、以下のような状況だとエラーになる
    • 4.6.3以下のバージョンのgccが既にシステム内にあり
    • これがWindowsのPATH変数に入ったままになっている


という構図になっています。で、よくあるのがデータ分析をやっている人ほどPython(x,y)もインストールしていて*8、それに同梱されているMinGWが4.5.2などの古いバージョンのgccを同梱していて、こいつがいつの間にかシステム内に入っていてパスも通っているというパターン。なので、予めWindowsのPATH変数を確認した上で、

C:\MinGW32-xy\bin;

があれば(もしくは古いバージョンのC:\...\gcc...\4.5.2などへのパスがあれば)手動で削除し、代わりに公式ヘルプでも指定されているように

C:\Rtools\bin;c:\Rtools\gcc-4.6.3\bin;C:\Windows\system32;

をPATHに追加する。これでうまくいくはずです。ということで

  • システム内の既存のgccに注意
  • 特にPython(x,y)などでMinGWが既に入っている場合は特に要注意


これらにだけ気を付ければ問題なく入るはずです。一応コンパイラのバージョンを確認し、さらにコンパイラが走るかどうか公式ヘルプのhello world実行例でチェックしておきましょう*9

> system('g++ -v')
Using built-in specs.
COLLECT_GCC=c:\Rtools\GCC-46~1.3\bin\G__~1.EXE
COLLECT_LTO_WRAPPER=c:/rtools/gcc-46~1.3/bin/../libexec/gcc/i686-w64-mingw32/4.6.3/lto-wrapper.exe
Target: i686-w64-mingw32
Configured with: /data/gannet/ripley/Sources/mingw-test3/src/gcc/configure --host=i686-w64-mingw32 --build=x86_64-linux-gnu --target=i686-w64-mingw32 --with-sysroot=/data/gannet/ripley/Sources/mingw-test3/mingw32mingw32/mingw32 --prefix=/data/gannet/ripley/Sources/mingw-test3/mingw32mingw32/mingw32 --with-gmp=/data/gannet/ripley/Sources/mingw-test3/mingw32mingw32/prereq_install --with-mpfr=/data/gannet/ripley/Sources/mingw-test3/mingw32mingw32/prereq_install --with-mpc=/data/gannet/ripley/Sources/mingw-test3/mingw32mingw32/prereq_install --disable-shared --enable-static --enable-targets=all --enable-languages=c,c++,fortran --enable-libgomp --enable-sjlj-exceptions --enable-fully-dynamic-string --disable-nls --disable-werror --enable-checking=release --disable-win32-registry --disable-rpath --disable-werror CFLAGS='-O2 -mtune=core2 -fomit-frame-pointer' LDFLAGS=
Thread model: win32
gcc version 4.6.3 20111208 (prerelease) (GCC)

> library(inline) 
> library(Rcpp)

 次のパッケージを付け加えます: ‘Rcpp’ 

 以下のオブジェクトはマスクされています (from ‘package:inline’) : 

     registerPlugin 

> src <- ' 
+ std::vector<std::string> s; 
+ s.push_back("hello");
+ s.push_back("world");
+ return Rcpp::wrap(s);
+ '
> hellofun <- cxxfunction(body = src, includes = '', plugin = 'Rcpp', verbose = FALSE)
cygwin warning:
  MS-DOS style path detected: C:/PROGRA~1/R/R-30~1.2/etc/x64/Makeconf
  Preferred POSIX equivalent is: /cygdrive/c/PROGRA~1/R/R-30~1.2/etc/x64/Makeconf
  CYGWIN environment variable option "nodosfilewarning" turns off this warning.
  Consult the user's guide for more details about POSIX paths:
    http://cygwin.com/cygwin-ug-net/using.html#using-pathnames
> cat(hellofun(), '\n') 
hello world 


このようにhello worldが出ればOKです。


Stanを{rstan}ごとインストールする


Rの場合、{rstan}をインストールすればStan本体も同時にインストールされるので、難しいことは考えなくても公式ヘルプにあるようにやっていけば勝手に入ってくれます。

options(repos = c(getOption("repos"), rstan = "http://wiki.rstan-repo.googlecode.com/git/"))
install.packages('rstan', type = 'source')


とやれば、後はソースファイルがDLされた後延々と時間をかけてパッケージ自体のC++ソースコードコンパイルが行われるので、待つだけです。終了すれば、普通にインストール完了のメッセージが出ます。ちなみにSSH接続が確立できないとか問題があってコマンドラインからDLできない場合は、上記URLに直接アクセスしてrstan_2.1.0.tar.gzをローカルに落としてきて、そこからインストールすればいけます。


全て無事完了すれば、

> require("rstan")
Loading required package: rstan
Loading required package: Rcpp
Loading required package: inline

Attaching package: ‘inline’

The following object is masked from ‘package:Rcpp’:

    registerPlugin

rstan (Version 2.0.1, packaged: 2013-10-25 13:14:25 UTC, GitRev: 1a89615fac00)


という感じで普通にR上で展開できるはずです。


Stan v2.4.0以降の状況


2014年7月末の状況ですが、RStan Getting Started · stan-dev/rstan Wiki · GitHubを見ると以下の通りにやれば一発だよ!と書いてあります。なので、この通りにすれば基本的には必ずStan + {rstan}が入るはずです。

> source('http://mc-stan.org/rstan/install.R', echo = TRUE, max.deparse.length = 2000)
> install_rstan()


なのですが、何故か僕の環境でこれやると毎回{Rcpp}のインストールでRセッションの再起動を要求されて、そこでループが続いちゃってコケるんですよね。。。で、僕の場合は{Rcpp}を一旦RStudioのパッケージコンソールからアンインストールして、改めて{Rcpp}のみをインストールして、それから{rstan}のインストールコマンドを直打ち(というかコピペ)で入れたらうまくいきました*10


適当なサンプルデータで動かしてみる


Stanだからと言って複雑な階層ベイズモデルをやらなきゃいけないというわけでも何でもなくて、ぶっちゃけただの正規線形モデルでも一般化線形モデルでも何でも推定してくれます。ということで以前の記事同様、いつも使っているconflict.datを使ってみようと思います。dという名前で読み込んで、CVカラムは[0,1]のnumericに直しておきます。

> d <- read.table("~/Dev/R/conflict_sample.txt", header=T, quote="\"")
> d$cv<-as.numeric(d$cv)-1


次に、エディタなどで以下のようなStanコードを書いてd.stanというファイル名で保存しておきます。

data {
	int<lower=0> N;
	int<lower=0> M;
	matrix[N, M] X;
	int<lower=0, upper=1> y[N];
}
parameters {
	real beta0;
	vector[M] beta;
}
model {
	for (i in 1:N)
	// X[i] は row_vector, beta は vector だが, dot_product が吸収してくれる
		y[i] ~ bernoulli(inv_logit(beta0+dot_product(X[i],beta)));

	// もちろん単回帰ないし次数が低い簡単なケースではベクトル表現を使わずに
	// y[i] ~ bernoulli(inv_logit(beta0+beta1*x[i]))
	// のようにベタ書きしても良い
}


そして、Stanに{rstan}経由で読み込ませるデータをlistとして作っておきます。

> d.dat<-list(N=dim(d)[1],M=dim(d)[2]-1,X=d[,-8],y=d$cv)


あとはstan(){rstan}関数でStanをkickさせるだけです。

> d.fit<-stan(file='d.stan',data=d.dat,iter=1000,chains=4)

TRANSLATING MODEL 'd' FROM Stan CODE TO C++ CODE NOW.
COMPILING THE C++ CODE FOR MODEL 'd' NOW.
cygwin warning:
  MS-DOS style path detected: C:/PROGRA~1/R/R-30~1.2/etc/i386/Makeconf
  Preferred POSIX equivalent is: /cygdrive/c/PROGRA~1/R/R-30~1.2/etc/i386/Makeconf
  CYGWIN environment variable option "nodosfilewarning" turns off this warning.
  Consult the user's guide for more details about POSIX paths:
    http://cygwin.com/cygwin-ug-net/using.html#using-pathnames
C:/Program Files/R/R-3.0.2/library/rstan/include//stansrc/stan/agrad/rev/var_stack.hpp:49:17: warning: 'void stan::agrad::free_memory()' defined but not used [-Wunused-function]
SAMPLING FOR MODEL 'd' NOW (CHAIN 1).
Iteration: 1000 / 1000 [100%]  (Sampling)
Elapsed Time: 23.447 seconds (Warm-up)
              42.818 seconds (Sampling)
              66.265 seconds (Total)

SAMPLING FOR MODEL 'd' NOW (CHAIN 2).
Iteration: 1000 / 1000 [100%]  (Sampling)
Elapsed Time: 17.502 seconds (Warm-up)
              23.147 seconds (Sampling)
              40.649 seconds (Total)

SAMPLING FOR MODEL 'd' NOW (CHAIN 3).
Iteration: 1000 / 1000 [100%]  (Sampling)
Elapsed Time: 48.873 seconds (Warm-up)
              24.728 seconds (Sampling)
              73.601 seconds (Total)

SAMPLING FOR MODEL 'd' NOW (CHAIN 4).
Iteration: 1000 / 1000 [100%]  (Sampling)
Elapsed Time: 35.494 seconds (Warm-up)
              21.94 seconds (Sampling)
              57.434 seconds (Total)


これで計算が終了しました。普通に結果は以下のようにすれば表示できます。

> d.fit
Inference for Stan model: d.
4 chains, each with iter=1000; warmup=500; thin=1; 
post-warmup draws per chain=500, total post-warmup draws=2000.

          mean se_mean  sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
beta0     -1.4     0.0 0.2   -1.9   -1.5   -1.4   -1.2   -0.9  1193    1
beta[1]    1.1     0.0 0.2    0.7    1.0    1.1    1.2    1.4  1329    1
beta[2]   -0.6     0.0 0.2   -0.9   -0.7   -0.5   -0.4   -0.2  1623    1
beta[3]    0.1     0.0 0.2   -0.2    0.0    0.1    0.2    0.4  1708    1
beta[4]   -3.0     0.0 0.2   -3.4   -3.2   -3.0   -2.9   -2.6  1393    1
beta[5]    1.5     0.0 0.2    1.2    1.4    1.5    1.6    1.9  1438    1
beta[6]    5.4     0.0 0.2    5.0    5.3    5.4    5.5    5.7  1520    1
beta[7]    0.1     0.0 0.2   -0.2    0.0    0.1    0.2    0.4  1580    1
lp__    -526.0     0.1 2.0 -530.7 -527.1 -525.7 -524.5 -523.2   893    1

Samples were drawn using NUTS(diag_e) at Mon Jan 27 19:17:54 2014.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).


パラメータの分布形状はともかく、平均・標準偏差・主要なパーセンタイル点は表示してくれます。なので、パッと見だけでもどのパラメータが有意に得られたかぐらいは分かるという按配です。


おまけ:可視化してみた


もっとも、MCMCサンプラーを使うからにはやっぱりパラメータ分布を可視化してみたいですよね?ってかそれが醍醐味だし。。。


以前の記事では{coda}パッケージを使って可視化してたんですが、それでは芸がないということでまたしても@さんのブログ記事に倣ってやってみました。


要は{ggplot2}パッケージを使うのが肝です。これで、コンパクトながら同時にパラメータ分布を分かりやすく可視化したプロットを描くことができます。上記計算の後、続けて以下のように処理していきます。

> d.ext<-extract(d.fit,permuted=T)
> N.mcmc<-length(d.ext$beta0)


一旦extract(){rstan}関数でMCMCサンプルを抽出してきます。サンプル長も適当にlength()とかで取ってきておきましょう。後は@さんの記事に従ってやっていくだけです。

> require(ggplot2)
> require(reshape2)
> require(plyr)

> b1<-d.ext$beta[1:2000]
> b2<-d.ext$beta[2001:4000]
> b3<-d.ext$beta[4001:6000]
> bs<-data.frame(b1=b1,b2=b2,b3=b3)
> bs.melt<-melt(bs,id=c(),variable.name="param")
> bs.qua.melt<-ddply(bs.melt,.(param),summarize,
+                    median=median(value),
+                    ymax=quantile(value,prob=0.975),
+                    ymin=quantile(value,prob=0.025))
> colnames(bs.qua.melt)[2]<-"value"

> bs.melt<-data.frame(bs.melt,ymax=rep(0,N.mcmc),ymin=rep(0,N.mcmc))
> p<-ggplot(bs.melt,aes(x=param,y=value,group=param,ymax=ymax,ymin=ymin,color=param))
> p<-p+geom_violin(trim=F,fill="#5B423D",linetype="blank",alpha=I(1/3))
> p<-p+geom_pointrange(data=bs.qua.melt,size=0.75)
> p<-p+labs(x="",y="")+theme(axis.text.x=element_text(size=14),axis.text.y=element_text(size=14))
> ggsave(file="d.png",plot=p,dpi=300,width=4,height=3)

f:id:TJO:20140128130412p:plain

まず分かりやすいように最初の3つの偏回帰係数のバイオリンプロットを出してみました。b1とb2が0のラインから正負それぞれに有意に離れた偏回帰係数を示しているのが分かると思います。なお、7つの偏回帰係数のバイオリンプロット全部を出すと以下のようになります。

> b1<-d.ext$beta[1:2000]
> b2<-d.ext$beta[2001:4000]
> b3<-d.ext$beta[4001:6000]
> b4<-d.ext$beta[6001:8000]
> b5<-d.ext$beta[8001:10000]
> b6<-d.ext$beta[10001:12000]
> b7<-d.ext$beta[12001:14000]
> bs2<-data.frame(b1=b1,b2=b2,b3=b3,b4=b4,b5=b5,b6=b6,b7=b7)
> bs2.melt<-melt(bs2,id=c(),variable.name="param")
> bs2.qua.melt<-ddply(bs2.melt,.(param),summarize,
+ median=median(value),
+ ymax=quantile(value,prob=0.975),
+ ymin=quantile(value,prob=0.025))
> colnames(bs2.qua.melt)[2]<-"value"

> bs2.melt<-data.frame(bs2.melt,ymax=rep(0,N.mcmc),ymin=rep(0,N.mcmc))
> p<-ggplot(bs2.melt,aes(x=param,y=value,group=param,ymax=ymax,ymin=ymin,color=param))
> p<-p+geom_violin(trim=F,fill="#5B423D",linetype="blank",alpha=I(1/3))
> p<-p+geom_pointrange(data=bs2.qua.melt,size=0.4)
> p<-p+labs(x="",y="")+theme(axis.text.x=element_text(size=11),axis.text.y=element_text(size=11))
> ggsave(file="d7.png",plot=p,dpi=300,width=4,height=3)

f:id:TJO:20140128131312p:plain

7つ全てまとめるとやっぱりちょっと個々のプロットが小さくなって見づらいかもですね。。。そしてb4とb6の収束が悪い気が(笑)。そんなこともこのプロットから分かります。


最後に


今後は@さんお薦めのMS提供のinfer.netのデータや、もしくは自分の興味のあるタイプの階層構造を持つようなデータを対象に、Stanで推定した時にどのような振る舞いを示すのかを見ながら学んでいこうと思ってます。ま、かなりいい加減なシリーズ記事になる予定なので、あまり期待しないでもらえると有難いです。。。


おまけ

masa_grant55

こういうポストの準備、投稿に何日くらいかけるのだろう。内容に加えて、エントリ化するスキルも並大抵じゃないな

はてなブックマーク - masa_grant55 のブックマーク - 2014年1月28日


今回の記事は過去記事からの引用だらけなので、ぶっちゃけ1~2日で書きました(笑)。

*1:詳細はもちろん今はまだ秘密

*2:一般化線形モデルへの導入としても良書ですね

*3:「数式の感覚的理解」について記述している書籍は僕にとってはこれが初めてです

*4:具体的なBUGSコードとしてその辺をどう表すかまで含めて書かれている

*5:特に収束証明のあたり

*6:最低でも3人ハマったのを見てます

*7:おまけに大体英語圏まで含めてググれば出てくるようなありきたりの話ばかりなので

*8:RとPythonの両刀使いが多いのが原因

*9:挙動確認のために手元のPCでやったので、実際には32bit版の表示になってますごめんなさい

*10:Windows 7 + 32/64bit環境