我正在看Google's example关于如何在 Android 上部署和使用预训练的 Tensorflow 图(模型)。此示例使用位于以下位置的 .pb
文件:
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
这是自动下载的文件的链接。
该示例展示了如何将.pb
文件加载到Tensorflow session 并使用它来执行分类,但它似乎没有提及如何生成这样的.pb
文件,在图被训练之后(例如,在 Python 中)。
有关于如何做到这一点的示例吗?
最佳答案
编辑: freeze_graph.py
脚本是 TensorFlow 存储库的一部分,现在用作从现有 TensorFlow GraphDef 和保存的检查点生成表示“卡住”训练模型的 Protocol Buffer 的工具。它使用与下面描述的相同的步骤,但使用起来更容易。
目前该过程没有很好的记录(并且有待完善),但大致步骤如下:
- 将您的模型构建并训练为名为
g_1
的tf.Graph
。 - 获取每个变量的最终值并将其存储为 numpy 数组(使用
Session.run()
)。 - 在名为
g_2
的新tf.Graph
中,创建tf.constant()
每个变量的张量,使用步骤 2 中获取的相应 numpy 数组的值。 使用
tf.import_graph_def()
将节点从g_1
复制到g_2
,并使用input_map
参数将g_1
中的每个变量替换为相应的变量在步骤 3 中创建的tf.constant()
张量。您可能还想使用input_map
指定新的输入张量(例如,将 input pipeline 替换为tf.placeholder()
)。使用return_elements
参数指定预测输出张量的名称。调用
g_2.as_graph_def()
以获取图形的 Protocol Buffer 表示形式。
(注意:生成的图将在图中有额外的节点用于训练。虽然它不是公共(public) API 的一部分,但您可能希望使用内部 graph_util.extract_sub_graph()
函数来剥离这些节点图中的节点。)
关于protocol-buffers - 有没有关于如何生成包含经过训练的 TensorFlow 图的 protobuf 文件的示例,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34343259/