python - 在预定义的图形对象上使用 tf.train.Saver 对象

标签 python tensorflow graph

TL;DR:为什么我们不能使用tf.saver.Save(graph=graph_obj)之类的东西来定义一个saver对象?

<小时/>

标题基本上说明了一切......据我所知,为了将保护程序对象链接到您的图表,您需要像这样定义它

with tf.Graph().as_default() as g_def:
    x_input_fun = tf.placeholder(dtype=tf.float32, name='input')
    y_output_fun = tf.placeholder(dtype=tf.float32, name='output')
    w_weights_fun = tf.get_variable('weight_set', dtype=tf.float32, shape=(5,5))
    output = tf.matmul(x_input_fun, w_weights_fun, name='pred')
    loss = tf.subtract(output, y_output_fun, name='loss')
    self.opti = tf.train.AdamOptimizer(loss, name='opti')
    g_def.add_to_collection(tf.GraphKeys.TRAIN_OP, self.opti)

    # Now the saver is linked to this graph when we do saver.save(...)
    saver = tf.train.Saver()

如果你想将它链接到默认图表,你只需要说tf.train.Saver()(当然如果你有可训练/可保存的变量)。

但是为什么我们不能这样做:tf.train.Saver(graph=g_def)

这对我来说感觉更自然。当我们从检查点恢复模型时,类似的情况(对我来说)成立......即使我们执行以下代码

with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph('some_meta_file.meta')
    saver.restore(sess, './some_meta_file')

然后tf.default_graph()仍然从导入的元文件中获得了节点。我可以想到它如何工作的原因...但现在为什么呢?

编辑:

我在检查导入图的节点时犯的一个错误如下。我运行了这段代码

with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph('some_meta_file.meta')
    saver.restore(sess, './some_meta_file')
    print(sess.graph == tf.get_default_graph())

因为我想确保默认图表不包含我刚刚导入到 session 图表中的节点。然而,这个 tf.get_default_graph() 当然是..相对的..因此在这个 session 中默认图实际上是导入的图。

因此,这也使得保护程序对象的工作更加合乎逻辑。由于该对象将始终保存/获取 tf.get_default_graph() 的内容。

最佳答案

为了保存或恢复任何内容,tf.train.Saver 需要一个 session ,并且 session 绑定(bind)到特定的图形实例(如您的示例中所示)。这意味着如果没有 session ,保护程序实际上毫无意义。我想这是不在保护程序中进行显式图形绑定(bind)的主要动机。

我认为您可能感兴趣的是 tf.train.Saver 中的 defer_build 属性:

defer_build: If True, defer adding the save and restore ops to the build() call. In that case build() should be called before finalizing the graph or using the saver.

通过这种方式,您可以创建一个不绑定(bind)到任何图形的 tf.train.Saver ,并稍后针对特定的 tf.Graph 调用 build() 实例。

关于python - 在预定义的图形对象上使用 tf.train.Saver 对象,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48169666/

相关文章:

python - Adagrad 如何在 Keras 中工作? Keras Optimizer 中的 self.weights 是什么意思?

python - 仅在 Facebook Prophet 模型 Python 上绘制预测值

javascript - D3 : Bar chart coloring

python - __reduce__ 函数在 pickle 模块的情况下如何工作?

python - 在 Python 脚本中提升权限

python - Matplotlib 在所有 Jupyter 笔记本中显示空白数字(无错误)

python - "AttributeError: no attribute ' 下载“使用PyTube

python - Amazon SageMaker 中的 Tensorflow 服务

python - 每个示例具有不同权重的 Keras 自定义损失函数

python - 将 pandas 数据框转换为定向 networkx 多图