python - tensorflow只保存初始化值

标签 python machine-learning tensorflow

我正在尝试保存一些变量,看看以后是否可以恢复它。 这是我的保存代码:

   import tensorflow as tf;
   my_a = tf.Variable(2,name = "my_a");
   my_b = tf.Variable(3,name = "my_b");
   my_c = tf.Variable(4,name = "my_c");
   my_c = tf.add(my_a,my_b);

   with tf.Session() as sess:
       init = tf.initialize_all_variables();
       sess.run(init);
       print("my_c =  ",sess.run(my_c));
       saver = tf.train.Saver();
       saver.save(sess,"test.ckpt");

打印出:

    my_c =   5

当我恢复它时:

   import tensorflow as tf;
   c = tf.Variable(3100,dtype = tf.int32);
   with tf.Session() as sess:
       sess.run(tf.initialize_all_variables());
       saver = tf.train.Saver({"my_c":c});
       saver.restore(sess, "test.ckpt");
       cc= sess.run(c);
       print(cc);

这给了我:

    4

my_c 的恢复值应为 5,因为它是 my_a 和 my_b 的和。然而它给了我 4,这是 my_c 的初始化值。谁能解释为什么会发生这种情况,以及如何保存对变量的更改?

最佳答案

在原始代码中,您并未真正将名为 my_c 的变量(请注意,TensorFlow name)分配给 my_a + my_b

通过编写 my_c = tf.add(my_a,my_b),Python 变量 my_c 现在与具有以下内容的 tf.Variable 不同: name='my_c'

当您执行sess.run()时,您只是执行操作,而不是更新该变量。

如果您希望此代码正确运行,请使用此代码 - (请参阅更改的注释)

import tensorflow as tf
my_a = tf.Variable(2,name = "my_a")
my_b = tf.Variable(3,name = "my_b")
my_c = tf.Variable(4,name="my_c")
# Use the assign() function to set the new value
add = my_c.assign(tf.add(my_a,my_b))

with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)
    # Execute the add operator
    sess.run(add)
    print("my_c =  ",sess.run(my_c))
    saver = tf.train.Saver()
    saver.save(sess,"test.ckpt")

关于python - tensorflow只保存初始化值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41251385/

相关文章:

python - 如何通过将另外两个 tf.feature_column 相乘来创建 tf.feature_column?

python - TensorFlow 仅适用于 GPU 0

php - Python 中的 API 调用身份验证(工作 PHP 示例)

python - Django动态页面功能和url

python - 发生异常 : AttributeError 'int' object has no attribute 'shape' when calling dqn. fit()

python - RandomizedSearchCv 导致属性错误

python-3.x - tensorflow ValueError : features should be a dictionary of `Tensor` s. 给定类型:<class 'tensorflow.python.framework.ops.Tensor' >

python - json 值是一个 html 字符串 - 如何在 python 中解析它?

python - 我怎样才能让我的 python 文件显示其 mercurial 标签或修订作为模块版本?

machine-learning - 如何保存没有变量的 tensorflow 模型?