多層ニューラルネットでBatch Normalizationの検証 - Qiitaでクォートされていた、
バッチ正規化使ってないなら人生損してるで
If you aren’t using batch normalization you should
というのを見て初めてニューラルネットワークでのバッチ正規化というものを知った。 なんか使うだけでいいことずくめらしいので調べてみた。
イントロ
論文はBatch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shiftにある。 中身はよく理解してないけど、学習時のミニバッチごとに、各レイヤーの各要素ごとに正規化して学習させることで内部共変量シフト(Internal Covariate Shift)を減らすことができて、それによって学習率を高く設定できて速く学習が進み、またウェイトの初期値に敏感にならなくて済む。 またRegularizerとしても機能するためドロップアウトを使わなくてもよい、ということらしい。 論文では、その時点での最高の画像分類のモデルに対して14分の1のステップ数で正解率に達したとのこと。
実装
ということで論文の中身を読んでみるが意味もよく理解できないし、自力ではまず実装に落とし込めない。 そこでキーワードでググってみて、Batch Normalizationによる収束性能向上 - Qiitaやpython - How could I use Batch Normalization in TensorFlow? - Stack Overflowの解答などを見て、実際に動かしたりしてなんとなく動作が掴めた。 実際のところ便利関数を用意して各層に挟んでやればそのまま適用できて、Tensorflowなどのフレームワークを使えば自動微分で逆誤差伝播も勝手に計算してくれるので、詳細に仕組みを理解しなくても使えてしまうのだった。
学習時のミニバッチごとの平均と分散を計算するにはtf.nn.momentsを使う。
評価時には訓練データ全体の平均と分散を使…いたいところだけど計算するのが大変なので、tf.train.ExponentialMovingAverage(指数移動平均)を使う方法が一般的のようだ。 これだと学習を進めていくうちに自動的に値が得られ、また個々の値を保持しておく必要がないので都合がいいのだろう。
学習結果の保存・復帰
学習とテストデータでの評価はできたけど、状態を保存するところで躓いた。
学習時にはそれまでに与えている訓練データの平均と分散を使えるが、それらのVariable
をどうやって保存したらいいのかよくわからなかった。
クロージャを配列として返しておいて学習が終わったら取り出せるようにしてtf.identityで名前をつけて別のグラフを構築して…とか力づくでやろうとしたらえらく複雑になってしまった。
でうろついてたところ、Implementing Batch Normalization in Tensorflow - R2RTのやり方がスマートだった(コメント)。
訓練データ全体の平均と分散を保持する変数のpop_mean
とpop_var
をtrainable=False
として生成することでチェックポイントに保存されるようになるらしい。
そして学習時にはそれらの変数に対してtf.Variable.assignすることで値がセットされ、tf.Saver
で保存・復帰ができる。
ソース
Deep MNIST for Expertsにバッチ正規化を適用してみた。 以下ブロックごとに解説:
インポート、設定
import tensorflow as tf |
flags
でデフォルトのパラメータを設定しつつ、コマンドラインから変更できるようにする
バッチ正規化ルーチン
# this is a simpler version of Tensorflow's 'official' version. See: |
- 学習時:
phase_train
にVariable
を渡してもらい、tf.nn.batch_normalizationを呼び出してバッチ正規化を行うtf.nn.batch_normalization
を呼び出さずに自前で計算することも可能:scale * (inputs - mean) / tf.sqrt(variance + epsilon) + beta
- tf.condで分岐させる:
- 学習時
true
の場合には、ミニバッチの平均と分散 - 学習中にテストデータで正解率を調べる場合には
false
にして、それまでの学習データの指数移動平均
- 学習時
- 識別時:
phase_train
にNone
を渡してもらい、計算済みの訓練データの平均を使う
グラフ構築
def build_graph(is_training): |
- Deep MNIST for Expertsのモデルにバッチ正規化を適用
- バイアス項は不要なので削除し、活性化関数に渡す前に
batch_norm_wrapper
を呼び出す - 出力層はバッチ正規化はしない
- ドロップアウトはなくてもいい場合がある、ということなので適用しないでみる
駆動部分
def train(mnist): |
train
で学習させて、チェックポイントファイルに保存test
でチェックポイントファイルから読み込み、テストデータでの正解率を計算is_training
を切り替えて、学習時のグラフとは別のグラフを作っているが、学習時と同じ正解率になれば望み通り保存・復帰ができている
感想
- グラフとか取ってないのでフィーリングの比較だけど、論文の通り学習率を高く設定できて、学習がなかなか進まないということも少なくて、学習率や初期値の調整にわずらわされなくなるのでとてもよい
- TFLearnでも使えるようなので、そちらで動かせるようにしたい
- 「BatchNormalizationの仕組みとその直感的な理解 - Qiita」という記事がよさそうなんだけど見れなくなっていて残念…