我遇到的问题可以反射(reflect)如下:
tf.reset_default_graph()
x = tf.placeholder(dtype=tf.int32, shape=())
init = tf.zeros(shape=tf.squeeze(x), dtype=tf.float32)
v = tf.get_variable('foo', initializer=init, validate_shape=False)
v_sig = tf.saved_model.signature_def_utils.build_signature_def(
inputs={"x_input": tf.saved_model.utils.build_tensor_info(x)},
outputs={
'v_output': tf.saved_model.utils.build_tensor_info(v)
},
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)
with tf.Session() as sess:
builder = tf.saved_model.builder.SavedModelBuilder(export_dir="~/test/")
sess.run(tf.global_variables_initializer()) # here leads to problem
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
'v_sig': v_sig
},
main_op=tf.tables_initializer(),
strip_default_attrs=True
)
builder.save()
我有一个变量foo
,其形状是动态计算的(取决于占位符x
的输入)。当我尝试将其另存为图表时,遇到错误:
You must feed a value for placeholder tensor 'Placeholder' with dtype int32
如果我不运行global_variables_initializer
,则会出现错误变量不存在
。
那么如何解决这种情况呢?我已经被这个问题困扰了很长一段时间,感谢您的回答。
最佳答案
您可以将图形保存为元图形对象,而无需初始化变量,如下所示:
import tensorflow as tf
import json
x = tf.placeholder(dtype=tf.int32, shape=(), name='x')
init = tf.zeros(shape=tf.squeeze(x), dtype=tf.float32, name='init')
v = tf.get_variable('foo', initializer=init, validate_shape=False)
tensor_names = {
'x': x.name,
'v': v.name
}
with open('tensor_names.json', 'w') as fo:
json.dump(tensor_names, fo)
fname = 'graph.meta'
proto = tf.train.export_meta_graph(filename=fname,
graph=tf.get_default_graph())
然后恢复此图表:
import tensorflow as tf
import json
with open('tensor_names.json', 'r') as fo:
tensor_names = json.load(fo)
graph = tf.Graph()
with graph.as_default():
tf.train.import_meta_graph(fname)
x = graph.get_tensor_by_name(tensor_names['x'])
v = graph.get_tensor_by_name(tensor_names['v'])
# works as expected:
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer(), {x:5})
print(v.eval()) # [0. 0. 0. 0. 0.]
关于python - 如何在不初始化变量的情况下保存 tensorflow 图?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56729579/