python - 从文件打开 tensorflow 图

标签 python tensorflow graph

我正在尝试使用 tensorflow 进行研究,但我不明白如何打开和使用早期保存在文件中的 tf.Graph 类型的图表。像这样:

import tensorflow as tf

my_graph = tf.Graph()

with g.as_default():
    x = tf.Variable(0)
    b = tf.constant(-5)
    k = tf.constant(2)

    y = k*x + b

tf.train.write_graph(my_graph, '.', 'graph.pbtxt')

f = open('graph.pbtxt', "r")

# Do something with "f" to get my saved graph and use it below in
# tf.Session(graph=...) instead of dots

with tf.Session(graph=...) as sess:
    tf.initialize_all_variables().run()

    y1 = sess.run(y, feed_dict={x: 5})
    y2 = sess.run(y, feed_dict={x: 10})
    print(y1, y2)

最佳答案

您必须加载文件内容,将其解析为 GraphDef,然后导入。 它将被导入到当前图表中。你可能想用 graph.as_default(): 上下文管理器来包装它。

import tensorflow as tf
from tensorflow.core.framework import graph_pb2 as gpb
from google.protobuf import text_format as pbtf

gdef = gpb.GraphDef()

with open('my-graph.pbtxt', 'r') as fh:
    graph_str = fh.read()

pbtf.Parse(graph_str, gdef)

tf.import_graph_def(gdef)

关于python - 从文件打开 tensorflow 图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40305996/

相关文章:

templates - 使用 boost::graph 实现异构节点类型和边类型

python - 当尝试替换 .csv 文件中的值时,会出现错误

java - 解析未知的 XML 派生文件

tensorflow - 如何在 tensorflow 对象检测中仅检测特定类别的对象

python - 使用队列 Tensorflow 训练模型

database - 从 ArangoDB 检索嵌套 json 或树

python - GAE 中的 worker 角色和 Web 角色对应

python - 如何对 python 中的不同实现使用相同的单元测试?

python - Tensorflow 模型使用 Flask 抛出的不是该图的元素

algorithm - 打印动态规划解决方案中遍历的路径