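I have a model that takes an int input x and computes the mean and variance of a random vector of size x. I can save this model, and I can also restore it, but I don't know how to run it for different values of x after the line
saver.restore(sess, './mean_var.ckpt')
Can I use feed_dict for this? Please help me with this.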
import tensorflow as tf

def mean_var(x):
    vec = tf.random_normal([x])
    mean, variance = tf.nn.moments(vec, [0], keep_dims=True)
    return mean, variance

with tf.Graph().as_default():
    x = tf.placeholder(tf.int32)
    output = mean_var(x)
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        #val = sess.run(output, feed_dict={x: 4})
        #print(val[0], val[1])
        save_path = saver.save(sess, "./mean_var.ckpt")
tf.reset_default_graph()
with tf.Graph().as_default():
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        saver.restore(sess, './mean_var.ckpt')
Best answer
Use this to restore the model and run a prediction:
with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./mean_var.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")
        output = mean_var(x)
        y_pred = sess.run(output, feed_dict={x: 4})
        print(y_pred)
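One more thing: give the placeholder x a name, as shown below. Without a name it gets the default name Placeholder, so graph.get_tensor_by_name("x:0") cannot find it after the restore.
x = tf.placeholder(tf.int32, name="x")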
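The effect of the name can be checked directly. A minimal sketch, assuming TensorFlow 2.x used through the `tensorflow.compat.v1` API in graph mode (the file name `demo.meta` is an arbitrary choice), exports a graph with one unnamed and one named placeholder, imports it into a fresh graph, and lists the placeholder tensor names. The same loop over `get_operations()` is one way to discover which names to pass to `get_tensor_by_name` for an existing `.meta` file.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions under TF 2.x

meta_path = os.path.join(tempfile.mkdtemp(), "demo.meta")

# Build a tiny graph with one unnamed and one named placeholder, then export it.
with tf.Graph().as_default():
    tf.placeholder(tf.int32)            # gets the default name "Placeholder"
    tf.placeholder(tf.int32, name="x")  # explicitly named "x"
    tf.train.export_meta_graph(filename=meta_path)

# Import the exported structure into a fresh graph and list every placeholder.
with tf.Graph().as_default() as g:
    # No variables exist, so no Saver is created here; we only need the ops.
    tf.train.import_meta_graph(meta_path)
    placeholders = [op.outputs[0].name for op in g.get_operations()
                    if op.type == "Placeholder"]
    print(placeholders)
```

Only the named placeholder shows up as `x:0`, which is the string the answer's code looks up.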
Full code:
import os
import tensorflow as tf

def mean_var(x):
    vec = tf.random_normal([x])
    mean, variance = tf.nn.moments(vec, [0], keep_dims=True)
    return mean, variance

with tf.Graph().as_default():
    # Named so that it can be looked up as "x:0" after restoring.
    x = tf.placeholder(tf.int32, name="x")
    output = mean_var(x)
    init = tf.initialize_all_variables()
    # tf.train.Saver() raises an error if the graph has no variables, hence this dummy one.
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        val = sess.run(output, feed_dict={x: 4})
        print(val[0], val[1])
        # Saver.save does not create a missing parent directory.
        os.makedirs("./mean_var", exist_ok=True)
        save_path = saver.save(sess, "./mean_var/mean_var.ckpt")
tf.reset_default_graph()
with tf.Graph().as_default():
    with tf.Session() as sess:
        # Rebuild the saved graph structure from the .meta file, then load variable values.
        saver = tf.train.import_meta_graph('./mean_var/mean_var.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./mean_var/'))
        # Equivalent: saver.restore(sess, './mean_var/mean_var.ckpt')
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")
        # Adds new mean/variance ops that are fed by the restored placeholder.
        output = mean_var(x)
        y_pred = sess.run(output, feed_dict={x: 4})
        print(y_pred)
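The code above calls `mean_var(x)` again after the restore, so it builds fresh mean/variance ops instead of using the ones saved in the `.meta` file. A minimal sketch of an alternative, assuming TensorFlow 2.x used through `tensorflow.compat.v1` in graph mode, gives the outputs their own names with `tf.identity` and fetches them by name after restoring. The names `mean`, `variance` and `dummy`, and the temporary checkpoint directory, are illustrative choices that do not come from the original answer.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions under TF 2.x


def mean_var(x):
    vec = tf.random_normal([x])
    mean, variance = tf.nn.moments(vec, [0], keepdims=True)
    # Hypothetical names "mean"/"variance" so the outputs can be fetched after restore.
    return tf.identity(mean, name="mean"), tf.identity(variance, name="variance")


ckpt_dir = tempfile.mkdtemp()
ckpt_path = os.path.join(ckpt_dir, "mean_var.ckpt")

# Build the graph and save it.
with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, shape=[], name="x")
    mean_var(x)
    tf.Variable(0, name="dummy")  # a Saver needs at least one variable
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Restore in a fresh graph and run the saved ops for several values of x,
# without calling mean_var again.
results = {}
with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(ckpt_path + ".meta")
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")
        mean = graph.get_tensor_by_name("mean:0")
        variance = graph.get_tensor_by_name("variance:0")
        for n in (4, 1000):
            results[n] = sess.run([mean, variance], feed_dict={x: n})
            print(n, results[n])
```

Because the outputs are looked up by name, the restoring script does not need the definition of `mean_var`, and feed_dict works for any value of x within the same session.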
For "python - Restore tensorflow model and run it with input", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/56881438/