ニューラルネットワークの話

機械学習 機械学習

本当にお話程度の事しか書かないとは思いませんでした。
おすすめの本のリンクを貼っておきます。

マサムネも読んでいる定番です。数学の人には細部が書いてなかったり当たり前の事を長々と書いていたりで物足りないかもしれません。しかし、python に慣れ、ニューラルネットワークを実際に作る体験ができるので、おすすめです。

ニューラルネットワークとは

ニューラルネットワークとは、機械学習のモデルの1つです。 以下の計算グラフで定義されるモデルの事を指します。

ニューラルネットワークを表す計算グラフ

左側から順に入力層、隠れ層、出力層と呼ばれます。 画像には書かれていませんが、線形回帰の切片を作る項が隠れ層と出力層に入ったりします。全てのノードが繋がっていますが、これは各ノードで行列を書けることを表しています。全てのノードが繋がった部分を全結合と言ったりします。 例えば、層Aと層Bは全結合にする。のように使います。機械学習のライブラリだとdenseと表されます。
隠れ層、出力層の活性化関数を[mathjax]\( h, \sigma \)、入力層→隠れ層の行列をA、隠れ層→出力層の行列をB として、出力を表してみましょう。入力データxに対して 出力yは 次のようになります。
[mathjax]$$ \begin{eqnarray}
E = \big( \sigma(Bh(A\vec{x_0} )),\cdots , \sigma (Bh(A \vec{x_N}) ) \big)
\end{eqnarray}$$
隠れ層の活性化関数は任意ですが、reluが使われることが多いです。
[mathjax]$$ \begin{eqnarray}
\bf{relu} (x) = max(x,0)
\end{eqnarray}$$
出力層の活性化関数は、回帰なら恒等写像、分類問題ならソフトマックス関数とする事が多いです。 誤差関数は問題に合わせて選ばれます。 ちゃんとした証明は読んだことが無いのですが、ニューラルネットワークモデルで、滑らかな関数は近似できるらしいです。ちゃんとした証明は理解出来たら記事にするかもしれません。
実際に、ニューラルネットワークを使ってみましょう。

ニューラルネットワークによる分類問題

ニューラルネットワークを使って、タイタニック号の生存者を予測してみたいと思います。使うデータは、kaggle で公開されているtitanicです。そのまま使うのは無理で、ちょっと加工が必要ですが詳細は別の記事で書くかもしれません。

titanic号の乗客データSurvived=1が生存

良い感じに加工したデータを隠れ層のノード数が64のニューラルネットワークモデルに学習させてみます。

正解率のプロット
オレンジ:トレーニング用データ,
青:テスト用データ

100回パラメーターを更新して、正解率の推移をプロットしています。これを見ても上手くモデルが作れているのか分かりません。(良くはない感じが伝わって欲しいです。)クロスエントロピー誤差が更新の度に小さくなっているか確認してみましょう。

クロスエントロピーのプロット

誤差関数を見ると、トレーニング用のデータは順調に誤差が小さくなっていますが、テスト用のデータでは、40回目くらいから頭打ちになっています。正答率をみても、40くらいから殆ど正答率が変わっていないことが分かります。
正解率を指標にするより、誤差関数を指標にしたほうがモデルが上手く機能しているかはっきり分かる例になってしまいました。

過学習と汎化性能

ニューラルネットワークを始めとする機械学習のモデルは、非常に表現力が高く、トレーニング用のデータに対しては、ほぼ100%の精度を誇ります。一方で、トレーニング用のデータしか表現出来ないモデルになりがちです。このような状態を過学習に陥っていると言います。過学習に陥らないようにするにはどうしたら良いのでしょうか。(モデルを作る側として)一番簡単なのは、使えるデータを増やすことです。トレーニングデータにない事象は起こらないという状態なら、過学習しても問題ないのでとても簡単です。
もう一つ過学習を防ぐ方法があります。それは、モデルが学習しにくいように誤差項(ペナルティ)を導入することです。データを増やすことは出来ないことが多いので、大体はこの誤差項で解決します。この誤差項については、有名な物があるので、別の記事で説明します。
有名なリッジ回帰についての解説はこちら。

リッジ回帰分析
リッジ回帰の解説をします。重回帰分析にペナルティを課すモデルです。これによって、パラメーター全体が大きくなることが抑えられ、データ自体のバラツキを無視してくれるようになります。さよなら過学習。

まとめ

・ニューラルネットワークモデルとは、層が三つで、全結合しているモデル。
・適当にモデルを作ると過学習が起こる。
・過学習を防ぐには、誤差項でモデルを制御する。

タイトルとURLをコピーしました