tensorflow - tf.assign 和赋值运算符 (=) 之间的区别

标签 tensorflow

我试图理解 tf.assign 和赋值运算符 (=) 之间的区别。我有三套代码

首先,使用简单的 tf.assign

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  assign_op = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(assign_op)
    print a.eval()
    print a.eval()

预期输出为

2
2
2

二、使用赋值运算符

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = a + 1
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   print sess.run(a)
   print a.eval()
   print a.eval()

结果仍然是 2,2,2。

第三,我使用两者 tf.assign 和赋值运算符

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(a)
    print a.eval()
    print a.eval()

现在,输出变为 2, 3, 4。

我的问题是

  1. 在使用 (=) 的第二个片段中,当我有 sess.run(a) 时,我似乎正在运行一个分配操作。那么“a = a+1”是否在内部创建了一个像 allocate_op = tf.assign(a, a+1) 这样的赋值操作? session 运行的操作真的只是 allocate_op 吗?但是当我运行 a.eval() 时,它不会继续递增 a,因此看起来 eval 正在评估“静态”变量。

  2. 我不知道如何解释第三个片段。为什么两个 eval 会增加 a,但第二个片段中的两个 eval 却不会?

谢谢。

最佳答案

这里的主要困惑是,执行 a = a + 1 会将 Python 变量 a 重新分配给加法运算的结果张量 a + 1。另一方面,tf.assign 是用于设置 TensorFlow 变量值的操作。

a = tf.Variable(1, name="a")
a = a + 1

这相当于:

a = tf.add(tf.Variable(1, name="a"), 1)

考虑到这一点:

In the 2nd snippet using (=), when I have sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? [...]

看起来可能是这样,但事实并非如此。如上所述,这只会重新分配 Python 变量。如果没有 tf.assign 或任何其他更改变量的操作,它的值将保持为 1。每次计算 a 时,程序将始终计算 a + 1 => 1 + 1。

I'm not sure how to explain the 3rd snippet. Why the two evals increment a, but the two evals in the 2nd snippet doesn't?

这是因为在第三个片段中的赋值张量上调用 eval() 也会触发变量赋值(请注意,这与执行 session.run(a)< 没有太大区别)/code> 与当前 session )。

关于tensorflow - tf.assign 和赋值运算符 (=) 之间的区别,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45779307/

相关文章:

python - 运行时错误 : module compiled against API version 0xc but this version of numpy is 0xb

tensorflow - 如何在 TensorFlow 中交换张量的 Axis ?

tensorflow - 多输入深度学习模型中两个输入的平均值

python-2.7 - Tensorflow(python):train_step.run(…)中的“ValueError:设置具有序列的数组元素”

java - Tensorflow Java API 设置分类列的占位符

python - 如何在 tensorflow 中实现多元线性随机梯度下降算法?

python - 不使用命令行训练 Tensorflow 模型

python - Tensorflow 上的简单线性回归

python - 如何在不每次对模型充电的情况下进行预测 - tensorflow ?

python - Google 的 TensorFlow 中的 Theano Dimshuffle 等效?