学習済みグラフをプロトコルバッファ形式で保存する

2016-09-14
blog

PC上でTensorflowを使って学習させて、そのデータを使ってスマホなどで識別だけさせるというときに、データはPythonで通常出力するチェックポイントファイルじゃなくてプロトコルバッファ形式にする必要がある。

TesorFlow: Pythonで学習したデータをAndroidで実行 - Qiitaを参考にして、Variableevalで値を取り出して、同じ構成なんだけどtf.Variableの代わりにtf.constantを使うようにしたグラフを作成して…ということをやっていたんだけど、面倒だし複雑になってしまう。

TensorFlowで学習済みグラフを保存する方法 | Workpilesに便利な方法が書いてあって、tf. import_graph_defを使う方法と、まさに目的通りのVariableconstantに置き換える便利関数convert_variables_to_constants を使う方法が書いてあった。

どちらも問題なく動いたけど、convert_variables_to_constantsのほうが簡単なのでそちらを使うことにした。

import tensorflow as tf
from tensorflow.python.framework import graph_util

# グラフを構築する関数
# 学習時とも共通で使える
def build_graph():
...
y = tf.nn.softmax(..., name='output') # 出力層の名前

with tf.Graph().as_default() as graph:
build_graph()

with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, 'checkpoint.ckpt') # 学習済みのグラフを読み込み

graph_def = graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), ['output']) # 出力層の名前を指定
# プロトコルバッファ出力
tf.train.write_graph(graph_def, '.',
'graph.pb', as_text=False)

リンク