TensorFlowを使ってひらがな+漢字の文字認識を行うプログラムを試してみた。


cnn-kanji


TensorFlowでのMNIST学習結果を、実際に手書きして試す - すぎゃーんメモがTensorFlowのサンプルで構築するニューラルネットワークに対してインタラクティブに書いた文字を認識させる、ということをしているので、これをベースにさせてもらった。

学習

使用したデータは手書教育漢字データベースETL8で、ひらがな75文字+漢字881の合計956文字を160セット分。白黒2値イメージは64x63の解像度。

ニューラルネットワークの構成はDeep MNISTのまま、5x5の畳み込み+最大プーリングを2段、それぞれ32と64フィーチャーを抽出、全結合層で1024次元にしてドロップアウト、最終的にソフトマックスで956次元を出力。

元画像データをMNISTにあわせて28x28にリサイズ。訓練用150、テスト用10に分割し20,000エポック学習させた。テストデータでの認識率は96.819%になった。

  • TensorFlowでの学習結果を保存したデータ.ckptのサイズ:約16Mバイト (16,978,193)
    • 係数:
      • 畳み込み層1:(5x5x1x32) + 32 = 832
      • 畳み込み層2:(5x5x32x64) + 64 = 51,264
      • 全結合層:(7x7x64)x1024 + 1024 = 3,212,288
      • 最終層:1024x956 + 956 = 979,900
      • 計:4,244,284
      • バイト数: 係数4,244,284個 x float4バイト = 16,977,136
    • だいたい一致

CSVの読み込み

訓練データとして使う画像データをCSVとして保存しているのだけど、データが956x150≒14.3万行になり、Python標準のcsv.readerを使うと69秒もかかってしまう。あらかじめシリアライズしておいて読みこむだけにすれば速くなるんじゃないかと、pickleを使ってみたが、逆に111秒と遅くなってしまった。

pandasというライブラリのread_csvだと12秒とまあマシだったので、これを使うことにした:

import pandas as pd

...
df = pd.read_csv(fileName, header=None)
data = df.as_matrix()  # numpy.arrayとして取得

多クラス判定後、確率の高い順に数件取り出す

MNISTでは認識させる文字は数字10個なので認識させた結果のすべての文字の確率を返しているが、956文字となるとすべてを返すのは無駄なので、高い候補だけを返すようにする。

TensorFlowで多クラス判定のニューラルネットワークの最終出力から確率の高い順にn件取り出すには、numpyのargpartitionを使うと小さい値n件のインデクスを取り出せるので、それと値を組み合わせてソートしてやる:

import numpy as np

def extract_ranking(probabilities, labels, n):
  """
  Returns highest n (probability, label) pairs.
  @param probabilities  np.array(1, N)
  @param labels  label list
  @return [(probability 1, label 1), ...,  (probability n, label n)]
  """
  indices = np.argpartition(-probabilities, n - 1)[:n]
  values = probabilities[indices]
  ranking = np.array([values, indices]).transpose().tolist()
  list.sort(ranking, reverse=True)
  return [(prob, labels[int(label)]) for (prob, label) in ranking]

使う側:

kLabels = ['あ', 'い', 'う', ...]
...
output2 = convolutional(input)
top10_2 = extract_ranking(output2, kLabels, 10)

動かした結果

で実際に手書き入力してみると、そこそこうまく判定してくれる。ナイーブなテストでここまでしっかり認識してくれるとは思っていなかった。学習時のハイパーパラメータもいじってないしネットワーク構成もそのままであまり深くないし、GPUなしの低火力ノートPCで学習、データもたったの150セットで水増しもなし。これもうちょっと物量増やしてちゃんと調整したら相当使い物になるんじゃないか?という手応えだった。

ただ問題もあって、もともとデータにない文字は絶対認識できないし、逆にめちゃくちゃに書いて全く文字に見えなくてもなんかしらと判定してしまう。また人間ではしないような認識間違いをしてしまう。これはCNNが画像を畳み込んで判定しているだけで、ストロークの形とか交差とか分岐とか人間が認識しているような特徴をベクトルとして入力してないからだと思う。またどちらかというと画数の少ない単純な文字のほうが間違えやすい。

  • 学習機能のないTensorFlowフロントエンド版が欲しい。そのJavaScript版があれば入力画像をサーバに投げずに判定できる。