計算グラフと関数の合成

機械学習 機械学習

計算グラフというものがあります。機械学習の勉強をしていると、丸と線だけで出来た図に出くわします。あれを見てどんなモデルなのか理解できる必要があります。関数の合成というのを意識しておくと、誤差逆伝搬法や、誤差関数を取り換えるといった操作もピンとくると思います。

関数の合成

関数の合成という観点で、多クラス分類を振り返ります。関数の合成という観点を持つと、機械学習に関する議論の見通しがよくなります。クラス数をK, データ数をN, 説明変数の数をMとしておきます。クラスに関わる添え字はk 、データの番号はi で表していきます。多クラス分類の誤差関数はクラスkに関するパラメーターを[mathjax] \( \vec{w_k}\)として、以下の式で表すことが出来ました。
[mathjax]$$ \begin{eqnarray}
a_{ik} &=&\vec{w_k} \cdot \vec{x_i} \\
\pi _{ik} &=& \frac{\exp (a_{ik}) }{\sum_{j}^{K}\exp (a_{ij} )} \\
L &=& -\sum_{i} \sum_{k} t_{ik} \pi_{ik}
\end{eqnarray}$$
ベクトル[mathjax]\( \vec{w} , \vec{x} \), 行列[mathjax]\( t= \{t_{ij} \}, A=\{ a_{ij} \} \)に対して 以下のように関数(写像)を定義しましょう。
[mathjax]$$ \begin{eqnarray}
f( \vec{w} , \vec{x} )&=& \vec{w} \cdot \vec{x} \\
g(A)&=& \pi = \{ \frac{\exp (a_{ij}) }{\sum_{k}^{K}\exp (a_{ik} )} \} \\
L(t, A) &=& -\sum_{i} \sum_{k} t_{ik} \ln (a_{ik})
\end{eqnarray}$$
クラス分類の尤度は、[mathjax]\( W=\{ f(\vec{w_j} , \vec{x_i}) \} \)として、以下のように関数の合成で書くことが出来ます。
[mathjax]$$ \begin{eqnarray}
L= L(t, g( W))
\end{eqnarray}$$

関数の合成と連鎖律

連鎖律を使う練習として多クラス分類の尤度の勾配を計算してみましょう。
[mathjax]$$ \begin{eqnarray}
\nabla _{\vec{w_j}} L &=& -\sum_i \sum_k \frac{1}{\pi_{ik}} \nabla_{\vec{w_j}}(a_{ij})\frac{\partial \pi_{ik}}{\partial a_{ij}} \\
&=& -\sum_i \sum_k \frac{1}{\pi_{ik}} \vec{x_i} \pi_{ik} \frac{\delta _{kj}\sum _l \exp (a_{il}) – \exp (a_{ij}) }{\sum _l \exp (a_{il})} \\
&=& -\sum_i \sum_k \vec{x_i} (\delta _{kj} -\pi_{ij} ) \\
&=& \sum_i (\pi_{ij}-t_{ij})\vec{x_i}
\end{eqnarray}$$
最後の式変形では, ラベルはkについて足すと1になることを使いました。

関数と計算グラフの対応

演算の操作と図を対応させることで、どんなモデルを使ったか可視化することが出来ます。次のような、円と矢印からなるグラフを描く事が多々あります。

関数の合成を図に表した

〇をノードと呼び、矢印はエッジと呼びます。ノード上の、誤差関数以外の関数は活性化関数と呼びます。機械学習の文脈では、内積を取ったり、行列をかけたりする関数は、活性化関数として扱わず、エッジ上に[mathjax]\( w_{jk} \)とか [mathjax]\( W \) とだけ書いておく事が多いです。
このようなグラフを計算グラフと呼んだりします。グラフ化することの利点は、使っている関数が何か明確になる点です。また、モデルが複雑になったときに、見たい場所だけ見る事が出来ます。
誤差関数の値や予測値は、最初から最後まで矢印を辿らないと得ることは出来ませんが、パラメーターを更新するための勾配はどうでしょうか。勾配の式と、グラフを見るとそうでもないことが分かりますが、別の記事で解説したいと思います。

まとめ

・パラメーターを求める手順を、関数の合成という観点から整理した。
・連鎖律を使って、勾配を計算した。
・グラフを使うことで、モデルを表せる。

タイトルとURLをコピーしました