CNNで畳み込み/プーリング後のテンソルのサイズ(Shape)を確認する
CNN(畳み込みニューラルネットワーク)で畳み込み、プーリング後のテンソルのサイズは非常にわかりにくいです。
そこで、Tensorオブジェクトのget_shape()メソッドを使用すると、テンソルのサイズを簡単に確認する事が可能です。
※全コードは後述する参考文献を参照してください。
[5x5]ストライド1でパディングあり
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='SAME') print(conv1.get_shape())
(?, 28, 28, 32)
[5x5]ストライド1でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID') print(conv1.get_shape())
(?, 24, 24, 32)
[3x3]ストライド1でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,1,1,1], padding='VALID') print(conv1.get_shape())
(?, 26, 26, 32)
[3x3]ストライド2でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='VALID') print(conv1.get_shape())
(?, 13, 13, 32)
[3x3]ストライド2でパディングあり
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,2,2,1], padding='SAME') print(conv1.get_shape())
(?, 14, 14, 32)
[3x3]ストライド3でパディングなし
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='VALID') print(conv1 .get_shape())
(?, 9, 9, 32)
[3x3]ストライド3でパディングあり
x = tf.placeholder(tf.float32, [None, 784]) img = tf.reshape(x,[-1,28,28,1]) f1 = tf.Variable(tf.truncated_normal([3,3,1,32], stddev=0.1)) conv1 = tf.nn.conv2d(img, f1, strides=[1,3,3,1], padding='SAME') print(conv1 .get_shape())
(?, 10, 10, 32)
最後に
これらの例は畳み込みでしたが、プーリングでも同様に確認可能です。
参考文献
TensorFlowではじめるDeepLearning実践入門のサンプルコード
スポンサーリンク
関連記事
前の記事: | TensorFlow.jsのHello World [WebでAIモデルを実行する] |
次の記事: | 画像内の物体を検出するObject Detection APIの使用方法 [TensorFlow] |
公開日:2018年08月02日 最終更新日:2018年08月24日
記事NO:02710