TensorFlow MobileのHello World [スマホでAIモデルを実行する]
TensorFlowで作成した「学習済みモデル」をAndroidで実行する方法です。
前提条件
TensorFlow.jsのHello World [WebでAIモデルを実行する] |
と同様な事を行います。環境設定などがお済でない方は先にご覧ください。
1. モデル(チェックポイント、ログ、PBファイル)の作成
次のコードをJupyter Notebookで実行します。
import tensorflow as tf with tf.name_scope('X'): x = tf.placeholder(tf.int32) with tf.name_scope('Y'): y = tf.Variable(3) with tf.name_scope('Z'): z = tf.add(x, y) saver = tf.train.Saver() init =tf.global_variables_initializer() with tf.Session() as sess: tf.summary.FileWriter("logs", sess.graph) tf.train.write_graph(sess.graph_def, './', 'graph.pbtxt') sess.run(init) result = sess.run(z, feed_dict={x:5}) saver.save(sess, 'ckpt/my_model') result
TensorBoardで確認するとグラフは次のようになります。
2. summarize_graphで入出力ノードを検査する
--in_graphのファイルパスは適宜、変更してください。
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/foo/graph.pbtxt
入力ノードは「X/Placeholder」。出力ノードは「Z/Add」となっています。
ただし、summarize_graphはあくまでも検査(予測)なのでTensorBoardでも再確認して下さい。また、graph.pbtxtファイルの中身はテキスト形式なのでそちらでも確認可能です。
3. freeze_graphでPBファイルとチェックポイントファイルを固めて「Frozen Model」にする
frozen_graph.pb(Frozen Model形式)を作成します。
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/foo/graph.pbtxt --input_checkpoint=/foo/ckpt/my_model --output_graph=/foo/frozen_graph.pb --output_node_names=Z/Add
TensorFlow.jsでは「Frozen Model」を「Web-friendly format」に変換しましたが、スマホではFrozen Modelをそのまま使用します。
4. Android Studioでプロジェクトを作成する
Android Studioで新規プロジェクトを作成します。
4-1. build.gradleの設定
build.gradle(モジュール:app)の下部に次のコードを追記します。
allprojects { repositories { jcenter() } } dependencies { api 'org.tensorflow:tensorflow-android:+' }
これだけで、モバイルでTensorFlowが使用できるようになります。
4-2. frozen_graph.pbを取り込む
OS側の操作で\app\src\mainにassetsフォルダを作成して、その中にfrozen_graph.pbを移動します。
Android Stduioでは次のように表示されます。
4-3. ソースコード
import android.app.AlertDialog; import android.support.v7.app.AppCompatActivity; import android.os.Bundle; import android.view.View; import org.tensorflow.contrib.android.TensorFlowInferenceInterface; public class MainActivity extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); findViewById(R.id.button).setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "frozen_graph.pb"); int[] inputs = {5}; int[] outputs = new int[1]; String[] outputNames = {"Z/Add"}; // 入力データを設定する inferenceInterface.feed("X/Placeholder", inputs); // モデルの推論(実行) inferenceInterface.run(outputNames); // モデルから結果を取得する inferenceInterface.fetch(outputNames[0], outputs); AlertDialog.Builder alertDialogBuilder = new AlertDialog.Builder(MainActivity.this); alertDialogBuilder.setTitle("結果"); alertDialogBuilder.setMessage(String.valueOf(outputs[0])); alertDialogBuilder.setPositiveButton("OK", null); alertDialogBuilder.show(); } }); } }
[解説]
ボタンを押すとメッセージボックスに8が表示されます。
これは、元のモデルでは
となっているからです。xはプレースフォルダ(placeholder)、yは変数(Variable)で3で定義されています。
xのプレースフォルダはAndroid側で渡す入力値です。
なので、21行目の「int[] inputs = {5};」を「int[] inputs = {7};」にすると10が表示されます。
[TensorFlowInferenceInterface]
現在の所、TensorFlowInferenceInterfaceクラスについては詳細な情報がありません。Googleさんで検索しても約 2,360 件しかヒットしません。
なので、feed/run/fetchの各メソッドの宣言をまとめてみました。
// feed系 public void feed(String inputName, boolean[] src, long... dims) public void feed(String inputName, float[] src, long... dims) public void feed(String inputName, int[] src, long... dims) public void feed(String inputName, long[] src, long... dims) public void feed(String inputName, double[] src, long... dims) public void feed(String inputName, byte[] src, long... dims) public void feedString(String inputName, byte[] src) public void feedString(String inputName, byte[][] src) public void feed(String inputName, FloatBuffer src, long... dims) public void feed(String inputName, IntBuffer src, long... dims) public void feed(String inputName, LongBuffer src, long... dims) public void feed(String inputName, DoubleBuffer src, long... dims) public void feed(String inputName, ByteBuffer src, long... dims) // run系 public void run(String[] outputNames) public void run(String[] outputNames, boolean enableStats) public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) // fetch系 public void fetch(String outputName, float[] dst) public void fetch(String outputName, int[] dst) public void fetch(String outputName, long[] dst) public void fetch(String outputName, double[] dst) public void fetch(String outputName, byte[] dst) public void fetch(String outputName, FloatBuffer dst) public void fetch(String outputName, IntBuffer dst) public void fetch(String outputName, LongBuffer dst) public void fetch(String outputName, DoubleBuffer dst) public void fetch(String outputName, ByteBuffer dst)
引数さえわかれば、なんとかなりますね :-)