渋谷駅前で働くデータサイエンティストのブログ

元祖「六本木で働くデータサイエンティスト」です / 道玄坂→銀座→東京→六本木→渋谷駅前

Stanで統計モデリングを学ぶ(5): とりあえず階層ベイズモデルを試してみる(応用編:トレンドのあるモデル) *追記2件あり

このシリーズ記事、全然真面目に事前分布の勉強をしていない人間がStanで無理やりフルベイズをやろうという無謀な代物でございますが、何だかんだで段々佳境に入ってまいりました。


ということで、今回は階層ベイズモデルをこんな感じでやってみましたという例を挙げてみようかと思います。ちなみに内容的には@さんのこちらの記事(「RStanで『予測にいかす統計モデリングの基本』の売上データの分析をトレースしてみた」)をグレードダウンさせた感じのものだったりします(笑)。そして先日招待講演させていただいた時の最後の方で取り上げた例でもあります。


そんなわけで、どのようにしてやっていったかを含めてサクサク見ていきましょう。階層ベイズについて忘れちゃったという人は、前回の記事あたりを読んで復習してもらえれば。


データをインポートする


いつも通り、サンプルデータをGitHubに上げてあるので持ってきてRにインポートしてください。dとかいう名前にでもしておきましょう。

x1 x2 x3 y
2988021 3029541 3429375 2387
4331957 2996819 4128007 2625
2492737 3027725 4200477 2371
1683820 2957989 6376299 2351
... ... ... ...

f:id:TJO:20140528151532p:plain


想定としては、とあるECサイトのコンバージョン(CV)件数に与えるオンライン広告の投下額の影響を推定したい、というもの。x1, x2, x3が別々の種類のオンライン広告の日次の投下額、yがその日のCV件数です。まぁこれ自体はどうってことないデータです。


単純に正規線形モデルを当てはめてみると?


上記のデータは物凄く見た感じ単純なので、一見「重回帰(正規線形モデル)でちゃちゃっとやればええんちゃう?」という気がします。実際、やってみるとこうなります。

> d.lm<-lm(y~.,d)
> summary(d.lm)

Call:
lm(formula = y ~ ., data = d)

Residuals:
    Min      1Q  Median      3Q     Max 
-1043.3  -739.2  -236.4   685.4  2032.7 

Coefficients:
              Estimate Std. Error t value Pr(>|t|)    
(Intercept) -2.638e+02  2.761e+03  -0.096    0.924    
x1           1.504e-04  3.186e-05   4.721    8e-06 ***
x2           1.041e-03  9.005e-04   1.156    0.251    
x3           2.651e-05  5.731e-05   0.463    0.645    
---
Signif. codes:  0***0.001**0.01*0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 877 on 96 degrees of freedom
Multiple R-squared:  0.2084,	Adjusted R-squared:  0.1836 
F-statistic: 8.423 on 3 and 96 DF,  p-value: 5.025e-05


簡単にできますよね。そこで、実際に推定したモデルをpredict関数を使って元データに対して当てはめてみましょう。

> matplot(cbind(d$y,predict(d.lm,d[,-4])),type='l',lty=1,lwd=c(3,2),col=c('red','blue'),ylab="CV")
> legend(3,5500,c("Data","Predicted"),col=c('red','blue'),lty=1,lwd=c(3,1),cex=1.2)

f:id:TJO:20140530171915p:plain


うがー、全然うまく当てはまってませんね。。。それもそのはず、このデータにはどう見ても右肩上がりのトレンドがあります。ただの正規線形モデルでは、このトレンドを表現することができないというわけです。


Stanでトレンドを分離したモデルを推定してみる


ならば、トレンドを線形モデルとは別に表現してやればいいじゃないか!ということでこれをStanで推定してみましょう。考え方としては、以下のようにCV数yを分解するようなモデルを置けば良いだけです。


CV_t = Q_t + trend_t

trend_t - trend_{t-1} = trend_{t-1} - trend_{t-2} + \epsilon_t

 Q_t = \alpha x_{1t} + \beta x_{2t} + \gamma x_{3t} + \epsilon_t


大本の式は1本目で、単にCV数yを線形モデルで表現可能な成分 Q_tとトレンド成分 trend_tとに分解するだけです。で、 trend_tは2次のラグまで含む差分形として、1期前の時点での増分に正規分布からなるばらつきを加えたものが、現時点での増分になるものと仮定します。そして残った Q_tは普通に正規線形モデルとして推定する、ただそれだけです。


これを表現したStanコードが以下の通りです。GitHubにも上げてあるので、普通にDLしてもらってOKです。

data {
	int<lower=0> N;
	real<lower=0> x1[N];
	real<lower=0> x2[N];
	real<lower=0> x3[N];
	real<lower=0> y[N];
}

parameters {
	real trend[N];
	real s_trend;
	real s_q;
	real<lower=0> a;
	real<lower=0> b;
	real<lower=0> c;
	real d;
}

model {
	real q[N];
	trend~normal(30,10);
	for (i in 3:N)
		trend[i]~normal(2*trend[i-1]-trend[i-2],s_trend);
	for (i in 1:N)
		q[i]<-y[i]-trend[i];
	for (i in 1:N)
		q[i]~normal(a*x1[i]+b*x2[i]+c*x3[i]+d,s_q);
}

トレンドなんですが、馬鹿正直に無情報事前分布からやるとデタラメな結果になる可能性もあるので、一応「x軸方向に100増えるごとにy軸方向に3000増える」という見かけ上の特徴に基づいて、平均30の正規分布を事前分布として与えてやります。


あとは以下のようにR側でkickしてやるだけです。収束具合は適宜{coda}パッケージを利用して確認するとして、一気にパラメータを求めるところまでやってしまいましょう。

> dat<-list(N=100,x1=d$x1,x2=d$x2,x3=d$x3,y=d$y)
> fit<-stan(file='hb_trend.stan',data=dat,iter=1000,chains=4)
# 略
> fit.coda<-mcmc.list(lapply(1:ncol(fit),function(x) mcmc(as.array(fit)[,x,])))
> plot(fit.coda)
# 収束具合を確認しておく
> fit.smp<-extract(fit)
> dens_a<-density(fit.smp$a)
> dens_b<-density(fit.smp$b)
> dens_c<-density(fit.smp$c)
> dens_d<-density(fit.smp$d)
> a_est<-dens_a$x[dens_a$y==max(dens_a$y)]
> b_est<-dens_b$x[dens_b$y==max(dens_b$y)]
> c_est<-dens_c$x[dens_c$y==max(dens_c$y)]
> d_est<-dens_d$x[dens_d$y==max(dens_d$y)]
> trend_est<-rep(0,100)
> for (i in 1:100) {
+ 	tmp<-density(fit.smp$trend[,i])
+ 	trend_est[i]<-tmp$x[tmp$y==max(tmp$y)]
+ }
> pred<-a_est*d$x1+b_est*d$x2+c_est*d$x3+d_est+cumsum(trend_est)
> matplot(cbind(d$y,pred),type='l',lty=1,lwd=c(3,2),col=c('red','blue'))
> legend(3,6500,c("Data","Predicted"),col=c('red','blue'),lty=1,lwd=c(3,1),cex=1.2)

f:id:TJO:20140530171938p:plain


お?トレンドは表現できてる気がするんですが、肝心の切片項が合ってない気がします。これは実は何度か出くわしている現象で、多分切片項dについて広めに事前分布取っておかないとダメなんじゃないかという気も。。。仕方ないので、これはちょっとズルをしてd_estを無理やり合わせます。

> fr<-function(x) {
+ 	sum((a_est*d$x1+b_est*d$x2+c_est*d$x3+x+cumsum(trend_est)-d$y)^2)
+ }
# 残差平方和を切片項d_estの関数とする

> res<-optim(d_est,fr,method="Brent",lower=-5000,upper=5000)
> res$par
[1] -1517.909
# optim関数で残差平方和最小となるd_estの値を求める

> d_est<-res$par
> pred_optim<-a_est*d$x1+b_est*d$x2+c_est*d$x3+d_est+cumsum(trend_est)
> matplot(cbind(d$y,pred_optim),type='l',lty=1,lwd=c(3,2),col=c('red','blue'),ylab="CV")
> legend(3,6500,c("Data","Predicted"),col=c('red','blue'),lty=1,lwd=c(3,1),cex=1.2)

> cor(d$y,pred_optim)
[1] 0.9620495
# 元データと予測値の相関は0.96

f:id:TJO:20140530172023p:plain


これでほぼぴったり合いました(ズルをしてますが)。ところで、推定したパラメータはこんな感じでした。

> a_est
[1] 0.0001440766
> b_est
[1] 0.0009194789
> c_est
[1] 5.112106e-05
> d_est
[1] -1517.909


実はこれにはもちろん答えがあって、元のデータd$yを生成した時のパラメータはというと。。。

> a
[1] 0.00015
> b
[1] 0.00025
> c
[1] 5e-05
> d
[1] 1000
> y<-a*x1+b*x2+c*x3+d+cumsum(c(rep(10,30),rep(20,30),rep(50,40)))


全く違うorz ちなみにトレンドもtrend_estだけ見るとほぼ一様に見えますが、実は本当は一様ではなくて3段階に分かれて後になるほど急峻になるような値を持っています。ということで、aとcの推定は良かったんですがbとdとトレンドの推定はダメでしたー、というお粗末な結果に。まぁこんなもんです。


というわけで、まとめ


意外と階層ベイズって綺麗に収束しないんだよコラ難しいんですよねー。この辺はもうちょっと色々良い方法を勉強してみようと思います。ただ、実務でも切片項の推定が合わないケースをよく見かけるので、切片項に関してだけはoptimとか使って合わせるのはアリなのかなぁとも。詳しい人教えて下さい。。。


追記1


@氏がこんなナイスな指摘をしてくれました。



これに従って上記Stanコードを直すとこうなります。

model {
	real q[N];
	real cum_trend[N];
	trend~normal(30,10);
	for (i in 3:N)
		trend[i]~normal(2*trend[i-1]-trend[i-2],s_trend);

	cum_trend[1]<-trend[1];
	for (i in 2:N)
		cum_trend[i]<-cum_trend[i-1]+trend[i];

	for (i in 1:N)
		q[i]<-y[i]-cum_trend[i];
	for (i in 1:N)
		q[i]~normal(a*x1[i]+b*x2[i]+c*x3[i]+d,s_q);
}

f:id:TJO:20140624215515p:plain

f:id:TJO:20140624215524p:plain


意地悪でトレンド項を3段階に分けたところまで綺麗~~~に推定できています。ちなみに偏回帰係数を見ると、

> a_est_mod
[1] 0.0001499713
> b_est_mod
[1] 0.0002487178
> c_est_mod
[1] 5.012751e-05
> d_est_mod
[1] 1000.706


と生成した時の元のパラメータとほぼほぼ一致。そして元データと推定値との相関係数も0.9873とほぼ完璧。@氏、有難う!!!!!


追記2


さらにさらに@先生からこんなツッコミもいただくことに。



ということで、{dlm}使ってカルマンフィルタとかそろそろやりますかね。。。確かに割と大したことのない問題にStanぶち込むのはオーバーキル気味な気もするし。これぞ炎上ラーニングの醍醐味ということで(笑)、そのうちネタ仕込んでみます。