python - 恢复 tensorflow 模型并使用输入运行它

标签 python neural-network tensor

我有一个模型,它接受 int 输入 x 并创建大小为 x 的向量的均值和方差。 我能够保存这个模型,但想要恢复,通过传递 x 值来运行它。我也能够恢复,但不知道如何在行后执行它

saver.restore(sess, './mean_var.ckpt')

对于不同的x。我可以为此使用 feed_dict 吗?请帮我解决这个问题。

import tensorflow as tf
def mean_var(x):
    vec = tf.random_normal([x])
    mean, variance = tf.nn.moments(vec, [0], keep_dims=True)
    return  mean, variance 
with tf.Graph().as_default():
    x = tf.placeholder(tf.int32)
    output = mean_var(x)
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()


    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        #val = sess.run(output, feed_dict={x: 4})
        #print(val[0], val[1])
        save_path = saver.save(sess, "./mean_var.ckpt")

tf.reset_default_graph()

with tf.Graph().as_default():
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        saver.restore(sess, './mean_var.ckpt')

最佳答案

用它来恢复和预测:

with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./mean_var.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")   
        output = mean_var(x)
        y_pred = sess.run(output, feed_dict={x:4})
        print(y_pred)

而且,还有一件事为占位符 x 命名,如下所示:

x = tf.placeholder(tf.int32, name="x")

完整代码:

import tensorflow as tf
def mean_var(x):
    vec = tf.random_normal([x])
    mean, variance = tf.nn.moments(vec, [0], keep_dims=True)
    return  mean, variance 

with tf.Graph().as_default():
    x = tf.placeholder(tf.int32, name="x")
    output = mean_var(x)
    init = tf.initialize_all_variables()
    _ = tf.Variable(initial_value='fake_variable')
    saver = tf.train.Saver()


    with tf.Session() as sess:
        sess.run(init)
        sess.run(_.initializer)
        val = sess.run(output, feed_dict={x: 4})
        print(val[0], val[1])
        save_path = saver.save(sess, "./mean_var/mean_var.ckpt")

tf.reset_default_graph()

with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./mean_var/mean_var.ckpt.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./mean_var/'))
        #saver.restore(sess, './mean_var/mean_var.ckpt')
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name("x:0")   
        output = mean_var(x)
        y_pred = sess.run(output, feed_dict={x:4})
        print(y_pred)

关于python - 恢复 tensorflow 模型并使用输入运行它,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56881438/

相关文章:

python - 如何在 Selenium 中找到这些复选框元素

python - 无法从 Python 连接到在 Azure Linux VM 上运行的 Redis 实例

tensorflow - 如何在 tensorflow playground 的第 4 个数据集上实现 0(训练和测试)损失?

python - 沿最后一个维度不同大小的Tensorflow张量运算

c++11 - 连接(堆叠)Eigen::Tensors 以创建另一个张量

python - 属性错误 : module 'tensorflow' has no attribute 'feature_column'

python - 测试不纯函数

c - 范恩错误 20 : The number of output neurons in the ann (4196752) and data (1) don't match Epochs

python - Tensorflow:在 LSTM 中显示或保存遗忘门值

python - 从 tensorflow 中的张量检索值的最快方法是什么?