Tensorflow freeze_graph 脚本在使用 Keras 定义的模型上失败

标签 tensorflow

我正在尝试将使用 Keras 构建和训练的模型导出到 protobuffer,我可以将其加载到 C++ 脚本中(如本例所示)。我生成了一个包含模型定义的 .pb 文件和一个包含检查点数据的 .ckpt 文件。但是,当我尝试使用 freeze_graph 脚本将它们合并到单个文件中时,出现错误:

ValueError: Fetch argument 'save/restore_all' of 'save/restore_all' cannot be interpreted as a Tensor. ("The name 'save/restore_all' refers to an Operation not in the graph.")

我像这样保存模型:

with tf.Session() as sess:
    model = nndetector.architecture.models.vgg19((3, 50, 50))
    model.load_weights('/srv/nn/weights/scratch-vgg19.h5')
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    graph_def = sess.graph.as_graph_def()
    tf.train.write_graph(graph_def=graph_def, logdir='.',   name='model.pb', as_text=False)
    saver = tf.train.Saver()
    saver.save(sess, 'model.ckpt')

nnDetector.architecture.models.vgg19((3, 50, 50)) 只是 Keras 中定义的类似 vgg19 的模型。

我像这样调用 freeze_graph 脚本:

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=[path-to-model.pb] --input_checkpoint=[path-to-model.ckpt] --output_graph=[output-path] --output_node_names=sigmoid --input_binary=True

如果我运行 freeze_graph_test 脚本,一切正常。

有人知道我做错了什么吗?

谢谢。

致以诚挚的问候

菲利普

编辑

我尝试打印 tf.train.Saver().as_saver_def().restore_op_name ,它返回 save/restore_all

此外,我尝试了一个简单的纯 tensorflow 示例,但仍然遇到相同的错误:

a = tf.Variable(tf.constant(1), name='a')
b = tf.Variable(tf.constant(2), name='b')
add = tf.add(a, b, 'sum')

with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.',     name='simple_as_binary.pb', as_text=False)
tf.train.Saver().save(sess, 'simple.ckpt')

而且我实际上也无法在 python 中恢复图形。如果我与保存图形分开执行以下代码,则使用以下代码会抛出ValueError:没有要保存的变量(也就是说,如果我在同一脚本中保存和恢复模型,则一切正常)。

with gfile.FastGFile('simple_as_binary.pb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    tf.import_graph_def(graph_def)
    saver = tf.train.Saver()
    saver.restore(sess, 'simple.ckpt')

我不确定这两个问题是否相关,或者我是否只是没有在 python 中正确恢复模型。

最佳答案

问题在于原始程序中这两行的顺序:

tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.',     name='simple_as_binary.pb', as_text=False)
tf.train.Saver().save(sess, 'simple.ckpt')

调用tf.train.Saver()向图中添加一组节点,其中包括一个名为“save/restore_all”的节点。但是,该程序在写出图形之后调用它,因此传递给 freeze_graph.py 的文件不包含进行重写所必需的那些节点。

颠倒这两行应该使脚本按预期工作:

tf.train.Saver().save(sess, 'simple.ckpt')
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.',     name='simple_as_binary.pb', as_text=False)

关于Tensorflow freeze_graph 脚本在使用 Keras 定义的模型上失败,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37704362/

相关文章:

python - TensorFlow 在实现逻辑回归时返回 nan

tensorflow - TensorFlow Timeline 中 GPU_0_bfc 分配器和 GPU_host_bfc 分配器的区别

python - 在 TensorFlow 中对整数张量进行卷积

tensorflow - 在 tensorflow 中卡住部分神经网络

cuda - 带有 tensorflow 、Windows 的多进程多 GPU

python - 从 C++ 自动生成 python 模块的 Tensorflow 源代码

python - Keras model.predict() 花费了不合理的时间

r - 如何在 LSTM 中有效地使用批量归一化?

python - 在 Python 中变量后面使用括号有什么作用?

python - 设置 TensorFlow GPU 支持的问题