マサムネの部屋

変分推論入門

ベイズ統計学に、変分推論という技があります。この技はEMアルゴリズムを知っていると分かりやすくなります。原理を解説し、ポアソン混合モデルに対して変分推論を適用します。1

スポンサーリンク

変分推論

未知の確率分布\(p\)があるとしましょう。この確率分布を近似したい時、 良く知った確率分布\(q(Z) \)で表すことが出来れば、少しは情報が得られそうです。確率分布同士の距離はKLダイバージェンスで測れる 2 ので、\( KL(p \| q ) \)を最小化するようなパラメーター\(Z =\{ z_1 , \cdots , z_m \} \)を推定できれば、\( p\)についての情報が得られる事になります。
一般的な場合では話が前に進まないので、\(q(Z) =\prod q(z_i ) \)と分解できるような場合を考えます。このような仮定を置くことを平均場近似とか、変分推論と呼んでいます。
平均場近似の記事によると、以下の式のように確率分布を更新するとKLダイバージェンスを最小化できます。
$$\begin{eqnarray}
E_{i\neq j}[\log p(X, Z )] &=& \int \log p(X,Z) \prod_{i\neq j} q_i dz_i \\
q^{\ast}_j (z_j) &=& \frac{\exp( E_{i\neq j} [\log p(X, Z)] )}{ \int \exp( E_{i\neq j} [\log p (X, Z)] ) dz_j }
\end{eqnarray}$$
変分推論をポアソン混合モデルに 適用しましょう。

ポアソン混合モデル

ポアソン混合モデルについてまとめます。
データ全体がクラスターに分かれていて、それぞれのクラスターがポアソン分布に従っているのがポアソン混合モデルです。
ポアソン分布は自然数\(x\) と正の実数\( \lambda \) を用いて、\( {\rm Poi} (x|\lambda ) = \frac{\lambda ^{x}}{x!}e^{-\lambda} \)と表されます。
各データ\(x_n \)が所属するクラスター \(k \) を決めるのはカテゴリー分布\(q(s_n | \pi ) = \prod \pi_k ^{s_{n,k}} \)です。
ただし、\( \sum_k \pi _k =1 \)で、 \(s_{n,k} \)で\(s_n \)の第k成分を表します。 また、\( s_n \)はある成分\( m\) だけが1で、他の成分は0のベクトルです。\(m \)が\(x_n \)の所属するクラスターを表しています。
ポアソン混合モデル\( p(X , S, \lambda ,\pi ) \) は以下のように書き表す事が出来ます。
$$\begin{eqnarray}
p(X, S , \lambda , \pi ) &=& p(X|S, \lambda )p(S|\pi)p(\pi ) p(\lambda )
\end{eqnarray}$$
それぞれの確率分布は、計算しやすいように設定します。3
$$\begin{eqnarray}
p(X| S , \lambda ) &=& \prod_{n} p(x_n | s_n , \lambda ) \\
p(x_n | s_n , \lambda ) &=& \prod_{k} {\rm Poi} (x_n |\lambda _k ) ^{s_{n,k} } \\
p(S |\pi ) &=& \prod p(s_n|\pi) \\
p(s_n|\pi) &=& {\rm Cat}(s_n , \pi ) = \prod \pi_k ^{s_{n,k}} \\
p(\lambda _k ) &=&{ \rm Gam } (\lambda _k | a,b ) \\
&=& C_{G}(a, b ) \lambda ^{a-1} \exp(-b\lambda _k )\\
p(\pi )= &=& { \rm Dir} (\pi |\alpha ) \\
&=&C_{D}(\alpha ) \prod \pi _{k} ^{\alpha _k -1}
\end{eqnarray}$$
\(C_G (a,b) \)や \(C_D (\alpha ) \)は積分して1になるようにするための定数 4 です。
\( a, b, \alpha \)の更新式を得る事で、混合ポアソン分布のパラメーター\(S, \lambda , \pi \)をサンプリングすることが出来ます。

変分推論によるポアソン混合モデルのパラメーター推定

ポアソン混合分布において、隠れ変数は\(S = \{s_1 , \cdots , s_N \} \)です。
パラメーターは、\( \pi =\{\pi _1, \cdots , \pi _K \} , \lambda = \{ \lambda _1 , \cdots , \lambda _K \} \)です。
隠れ変数\(S \)と パラメーター\( (\lambda , \pi )\)は、\(X \)について条件付き独立と仮定しましょう。
$$\begin{eqnarray}
p(S, \lambda , \pi |X ) &=& p(S|X )p(\lambda , \pi |X )
\end{eqnarray}$$
この仮定によって、データからパラメーター\(S, \lambda , \pi \)を推定するには、\(p(S), p(\lambda , \pi) \)を求めれば良い事になります。
\(S \)と\( ( \lambda , \pi ) \)について平均場近似による変分推論を行います。
$$\begin{eqnarray}
\log q^{\ast } (S) & \propto & E_{\lambda ,\pi }[\log p(X, S, \lambda , \pi )] \\
& \propto & E_{\lambda }[\log p(X|S,\lambda ) ] + E_{\pi }[\log p(S |\pi )] \\
\log q^{\ast}(\lambda, \pi ) & \propto  & E_{S }[\log p(X, S, \lambda , \pi )] \\
& \propto & E_{S }[\log p(X|S,\lambda ) ] + \log p(\lambda ) + E_{S }[\log p(S |\pi )]
\end{eqnarray}$$
\( \log ^{\ast} q(\lambda , \pi ) \)の形から、\( \lambda \)と\(\pi \)は独立である事が分かります。
$$\begin{eqnarray}
\log q^{\ast}(\lambda, \pi ) = \log q^{\ast}(\lambda) + \log q^{\ast} ( \pi )
\end{eqnarray}$$
後は、それぞれの期待値を計算すれば、更新された確率分布の指数部分が分かり、良く知った形になれば嬉しい!となるわけです。一つずつ計算しましょう。

\( q^{\ast} (S) \)について

\( q^{\ast} (S) \) を知るには、 ガンマ分布での\( \lambda , \log \lambda \)の期待値と、 \( \log \pi _k \)のディリクレ分布での期待値を計算しなくてはなりません。\( n \) については積を取るだけなので、\( E_{\lambda } [ (p(x_n |s_n , \lambda )] , E_{\pi } [ p(s_{n,k} |\pi) ] \)を計算しましょう。
$$\begin{eqnarray}
E_{\lambda } [\log p(x_n |s_{n} , \lambda ) ] &=& \sum _{k} s_{n,k}
E_{\lambda _k }[ \log {\rm Poi} (x_n | \lambda _k ) ] \\
&=&
\sum _{k} s_{n,k} \left( x_n E_{\lambda _k }[ \log \lambda _{k} ]- E_{\lambda _k }[ \lambda _k ] \right) – \log x_n ! \\
E_{\pi } [ \log p(s_{n} |\pi) ] &=& \sum_{k} s_{n,k} E_{\pi _k} [\log \pi _k]
\end{eqnarray}$$
この計算を見ると、\( q^{\ast} (S) \)は, カテゴリー分布に従うことが分かります。5
$$\begin{eqnarray}
q^{\ast}(s_n ) &=& {\rm Cat }(s_n|\eta _n )\\
\eta _{n,k} &\propto & x_n E_{\lambda _k }[ \log \lambda _{k} ]- E_{\lambda _k }[ \lambda _k ]
+ E_{\pi _k } [\log \pi _k]
\end{eqnarray}$$
\( \eta _{n,k} \)はベクトル \( \eta _n \)の第k成分を表しています。 他の確率分布の形も決定し、更新式を導きます。

\( q^{\ast} (\lambda ) \)について

\( q^{\ast} (\lambda ) \) を決定するには、以下の式を計算すれば良いです。
$$\begin{eqnarray}
E_{S }[\log p(X|S,\lambda ) ] + \log p(\lambda )
\end{eqnarray}$$
\( p(X|S, \lambda ) = \prod_{n} \prod_{k} {\rm Poi }(x_n|\lambda _{k} )^{s_{n ,k}} \)に注意すれば簡単に計算出来ます。
$$\begin{eqnarray}
q^{\ast} (\lambda ) \propto \sum _{k} \left( ( \sum_{n} s_{n,k} x_n + a-1 )\log \lambda _{k} – (\sum_{n} E_{S}[s_{n,k} ] +b ) \lambda _{k} \right)
\end{eqnarray}$$
この計算から、 各\( q^{\ast} (\lambda _{k} ) \) は再びガンマ分布に従うこと6 が分かります。
$$\begin{eqnarray}
q^{\ast} (\lambda _k ) &=& { \rm Gam } (\lambda _k|\hat{a}_k , \hat{b}_k )\\
\hat{a}_k &=& \sum _{n} E_S [s_{n,k} ]x_n +a \\
\hat{b}_k &=& \sum_{n} E_S [s_{n,k} ] + b
\end{eqnarray}$$

\( q^{\ast} (\pi ) \)について

\( q^{\ast} (\lambda ) \) を決定するには、以下の式を計算すれば良いです。
$$\begin{eqnarray}
E_{S }[\log p(S |\pi )] +\log (\pi | \alpha )
\end{eqnarray}$$
$$\begin{eqnarray}
q^{\ast} (\pi ) \propto \sum _{k} ( \sum_{n} E_{S} [s_{n,k} ] +\alpha _{k} -1 )\log \pi _{k}
\end{eqnarray}$$
この式から、\( q^{\ast} (\pi ) \)はディリクレ分布に従うことが分かります。
$$\begin{eqnarray}
q^{\ast} (\pi ) &=& {\rm Dir } (\pi |\hat{\alpha } )\\
\hat{\alpha _k } &=& \sum_{n} E_{S} [s_{n,k} ] +\alpha _{k}
\end{eqnarray}$$

更新式

以上の計算によって、パラメーター\(S, \pi , \lambda \)の
パラメーターを更新しましょう。
\( \psi (z ) = \frac{1}{\Gamma (z) } \int_{0}^{\infty} t^{z-1} e^{-t} \log t dt \)をディガンマ関数として、パラメーターの更新に必要な式は以下のように計算出来ます。 7
$$\begin{eqnarray}
E_{\lambda _k }[ \lambda _{k} ] &=& \hat{a} /\hat{b}\\
E_{\lambda _k }[ \log \lambda _{k} ] &=& \psi (\hat{a} ) – \log \hat{b} \\
E_{\pi _k }[ \log \pi _{k} ] &=& \psi (\hat{\alpha _k }) -\psi (\sum \hat{ \alpha _k })\\
E_S[ s_{n,k} ] &=& \eta _{n,k}
\end{eqnarray}$$
これで、確率分布を更新することが出来ます。8

変分EMアルゴリズム

変分推論では、KL ダイバージェンスが計算をやめる指標になります。EMアルゴリズムは、対数尤度\(\log p(X) \)が更新されているかが指標になっていました。
対数尤度は、KLダイバージェンスと汎関数の和\( \mathcal{L}(q) \) に分けることが出来ました。 9
$$\begin{eqnarray}
\log p(X| \theta ) &=& \mathcal{L}(q, \theta) +KL(q \| p) \\
\mathcal{L}(q, \theta) &=& \sum_{Z} q(Z) \log \left\{ \frac{p(X,Z| \theta )}{q(Z)} \right\}
\end{eqnarray} $$
上の式の\( \mathcal{L } (q , \theta ) \)をELBO ( evidence lower bound ) と呼びます。
KLダイバージェンスが常に0以上なので、 \( \log p(X| \theta ) \geq \mathcal{L}(q, \theta) \)となるから、lower bound です。この値を計算を止める指標にすることが出来ます。つまり、パラメーターの更新の度にELBOを計算し、変化率が小さくなったら更新をやめるのです。

まとめ

  1. 正規分布だと計算が大変なのです。別の記事で計算だけするかもしれません。
  2. 測れるけど、実際にどれくらい似ているかは分かりません。
  3. この確率分布の設定をした段階で、色々な仮定をしていることに注意した方が良いです。例えば、データ\(x_ n\)はi.i.d だったり、\( \lambda _k \)たちは独立だったりしてます。グラフィカルモデルを描くとスッキリするのですが、別の記事で説明したら追記します。
  4. 楽をするために\(a, b \)を\( k\)に寄らない定数として置いていますが、\( k \)毎に\( a_k , b_k \)と置いても良いです。
  5. カテゴリー分布の\( \log \)を取ると、\( \log p(s|\pi ) = \sum_{k} s_{k} \log \pi _k \)です。
  6. ガンマ分布の\( \log \)を取ると、\( (a-1)\log \lambda -b \lambda +{\rm const} \)です
  7. ガンマ関数の計算は、過去の記事で計算しています。
  8. 別の記事で、適当なデータで試してみます。
  9. 計算はEMアルゴリズムの記事にあります。 KLダイバージェンスが0以上で、0になるのは同じ確率分布の距離を測った時だけ、という事が大事でした。