我创建了一个 tensorflow Graph
.我可以加载它,例如
with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
它将 protobuffer 文件中定义的图作为当前的默认图。如果我现在创建一个 session ,该图将用作当前图。
正在尝试保存序列化 graph_def
反对一个变量并启动一个Session
作为
with tf.Session(graph=graph_def) as sess:
以预期错误结束
TypeError: graph must be a tf.Graph, but got <class 'tensorflow.core.framework.graph_pb2.GraphDef'>
我有一个用例,我必须在多个图表之间切换。使用所提供的方法,我可以清除默认图并加载一个新图,但缺点是必须重复调用导入函数。
问题是,来 self 的 graph.pb
,如何是Graph
对象 my_graph
获得,所以可以使用
with tf.Session(graph=my_graph) as sess:
并在不从 graph.pb
重新加载图表的情况下创建 session 文件?
最佳答案
您可以创建自己的图表并将其设置为导入操作的默认值:
import tensorflow as tf
graph1 = tf.Graph()
graph2 = tf.Graph()
with graph1.as_default():
tf.import_graph_def(graph_def1) # graph_def1 loaded somewhere
with graph2.as_default():
tf.import_graph_def(graph_def2) # graph_def2 loaded somewhere
session1 = tf.Session(graph=graph1)
session2 = tf.Session(graph=graph2)
关于python - 如何从保存的graph.pb中获取Session的Graph对象,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43649558/