渋谷駅前で働くデータサイエンティストのブログ

元祖「六本木で働くデータサイエンティスト」です / 道玄坂→銀座→東京→六本木→渋谷駅前

一般化加法モデル(GAM)のknotsはどう決めるべきか

この記事は、以前MMM (Media/Marketing Mix Modeling)について概説した記事の続きです。

今年ローンチされたMMMフレームワークのMeridianでは、従来の様々なMMMフレームワークとは異なり、トレンド・季節調整をモデリングする際に一般化加法モデル(Generalized Additive Models: GAM)を用いています。ただ、そのハイパーパラメータである"knots"の決め方が、当初ドキュメントを一瞥しただけではよく分からなかったので、一通り調べてみた結果を備忘録代わりにここに書き記しておこうと思います。

一般化加法モデルは「関数の線形和で作る線形モデルの拡張版」


言わずもがなですが、線形モデル族は以下のような式で表されます。

 y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_p x_p + \epsilon

これに対して、非線形なデータセットに対してより柔軟性の高い当てはめを目指して、説明変数を非線形なスプライン関数で変換したものを配置するという考え方があります。

 y = \beta_0 + \beta_1 f_1(x_1) + \beta_2 f_2(x_2) + \cdots + \beta_p f_p(x_p) + \epsilon

こうすることで、「変数ごとの影響度が分かりやすい」という線形モデルに由来する説明性の高さと、非線形なスプライン関数がもたらす当てはまりの良さとが、一挙両得できるわけです。このモデルはスプライン関数で変換したものを足し合わせていくことで成るため、加法モデルと呼ばれます。スプライン関数には、多項式やガウシアンなど様々なものが一般には使われます。

そもそもGAMはカステラ本の原著ESLやISLを執筆したHastie & Tibshiraniが考案したものであり、それらのテキストでは詳細な解説がなされています*1

ちなみにザッとググったところ、ISLの例題をRとPythonで試してみたという分かりやすい記事がありましたので、詳細とその雰囲気についてはこちらをお読みいただいた方が早道かと思います。

ハイパーパラメータknotsは「一貫性のある区間ごとに区切る結び目」


https://r.qcbs.ca/workshop08/book-en/how-gams-work.html

これまたザッとググったところ、GAMをRで実践する系のワークショップのサイトが出てきました。



(Cited from https://r.qcbs.ca/workshop08/book-en/how-gams-work.html)

5.2節の3次スプライン関数を用いた例が分かりやすいですが、GAMではデータセットを(互いに一貫性のあるパターンを持つ)複数の領域に分割した上でそれぞれの領域に合わせた異なるスプライン関数を当てはめるというアプローチを取っており、その領域の分割の仕方を決めるのがknotsということなんですね*2。イメージとしては「極値」「変曲点」にknotsを打つ、という感じになるかと思います。


時系列データであればknotsは具体的なベクトルとしてx軸の座標として定めることもできますし、その場合は原始的ですが「目検」で決めることになります。特に、regime switchを含めたベースラインシフトが顕著な場合は尚更です。ただ、一般には「個数」だけ指定すれば上手く当てはまることが多いようです。Meridianもこの立場でknotsを引数に含めているようです。


直感的に考えればお分かりかと思いますが、knotsが多ければ多いほど多数のスプライン関数を用意して無理やり当てはめるということになり、過学習につながりかねません。そこで、最適なknots数を決めるという考え方が出てきます。実際のGAMではL1正則化と同じアイデアで、ハイパーパラメータを持つ正則化項を加えてからknotsを多めに設定した上で最適化計画を解くことで、スムージング性能を保ちながら事後的にknotsを抑制するというアプローチもある*3ようですが、今回は割愛します。


knotsは交差検証で最適値を決める


これまたググるとCross Validatedの質疑が見つかるのでわざわざここで麗々しく書くほどの話ではないのですが、最適なknotsは交差検証で決めるのが一般的なようです。本来のGAMではGCV (generalized cross validation)という手法で自動的に決めることができるようですが、Meridianにはそんなものはないので手動でコードを組んでやることになります。


一つ注意点があって、knotsを例えば時系列データセット長と同じ値(つまり最大値)に設定すると、それだけ多数のスプライン関数の演算をさせることになり、計算がOOMで落ちる危険性があります。このため、交差検証をする際は必ずknots = 1から1つずつ増やしていくアプローチの方が現実的と思われます。


ただし、時系列データセットに対してGAMのような手法は「未来」方向への予測は不安定になるであろうことは容易に想像できる*4わけで、交差検証のやり方は要検討かと。Meridianのドキュメントでは時系列モデルであるにもかかわらずrandom splitを推奨していますが、これがその理由なのかもしれません。もっとも、個人的には「時系列モデルの性能は未来予測性能によってのみ評価されるべき」だと考えていますので、やはりpast-and-future splitでやるべきだという意見です。


なお、今回は「MMMでトレンド&季節調整の当てはめにGAMを用いる」ケースのみを想定したため、目的変数も説明変数も1つずつのパターンをメインに論じましたが、通常のGAMで多変量すなわち説明変数が複数の場合*5はそれぞれにknotsを決めることになって煩雑なため、ある程度決め打ちで例えばknots = 4のように固定するのが現実的な対応策とされるようです。

Conflict of interests

筆者はMeridianの配布元企業に勤務しています。

*1:ちなみに今回もカステラ本を引っ張り出して読もうと思ったんですが、本箱の奥深くに埋もれてしまっていたので泣く泣く諦めました

*2:ESL/ISLをNotebookLMに読み込ませて解説させたらこう言ってました笑

*3:ただし正則化項のハイパーパラメータは交差検証で決める必要があるため、この後の下りとプロセスは同じ

*4:そもそも既定のスプライン関数を恣意的に割り当てる時点で予測可能な範囲が限られる

*5:Meridianならgeo-levelで複数地域を設定する場合