tensorflow - 使用 output_layer=pool_3 运行 label_image 时出错

标签 tensorflow

图像识别教程中的 tensorflow 练习建议运行 C++ 示例 --output_layer=pool_3。我尝试运行此程序并收到错误:

$ bazel-bin/tensorflow/examples/label_image/label_image --output_layer=pool_3

I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 4
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 4
W tensorflow/core/common_runtime/executor.cc:1076] 0x558ae6a5d210 Compute status: Invalid argument: input must be 2-dimensional
     [[Node: top_k = TopK[T=DT_FLOAT, k=5, _device="/job:localhost/replica:0/task:0/cpu:0"](Const/_0)]]
E tensorflow/examples/label_image/main.cc:311] Running print failed: Invalid argument: input must be 2-dimensional
     [[Node: top_k = TopK[T=DT_FLOAT, k=5, _device="/job:localhost/replica:0/task:0/cpu:0"](Const/_0)]]

我错过了什么?

最佳答案

这里的问题是 image recognition tutorial 中的 TensorFlow 代码在 --output_layer=pool_3 选项起作用之前需要进行额外的修改:

One can specify this by setting --output_layer=pool_3 in the C++ API example and then changing the output tensor handling.

要更改输出张量处理,您需要修改代码 below this line in label_image/main.ccPrintTopLabels()函数调用GetTopLabels() ,它采用单个二维(批处理 x 类)张量 - 假设为 tf.nn.softmax() 的输出包含一批图像中标签的概率分布,并使用 tf.nn.top_k() 构建一个小型 TensorFlow 图操作。 pool_3 层输出一个四维(批处理 x 高度 x 宽度 x 深度)张量,这需要额外的处理。

附加处理已留给读者作为练习。但是,您可以尝试以下一些操作:

  • 将输出 reshape 为二维(批处理 x 特征)矩阵,并训练全连接层(或更多)以识别您自己的训练数据中的特征。

  • 通过沿深度维度对池化层的输出进行切片,并使用tf.image.encode_png()将切片编码为图像来可视化池化层的输出。

注意由于文档更好,我提供了 Python 文档的链接,而不是相应的 C++ API。您可能会发现修改 Python code for Inception inference 更加容易!相反。

关于tensorflow - 使用 output_layer=pool_3 运行 label_image 时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34466356/

相关文章:

python - tensorflow 中的成对求和

python - 如何在完整验证示例上评估 Tensorflow 模型

python-2.7 - 为什么 `tf.constant_initializer`不取常数张量?

tensorflow - 如何在tensorflow中设置特定的gpu?

tensorflow - 使用 tf.app.run() 从类中调用 main 函数

python - Tensorflow:从图像中预测一个点,用点标签训练模型

python - 在Keras中,如何通过模型批量发送每一项?

python - TensorFlow:不可重复的结果

python - 执行 import tensorflow as tf 时出现导入错误