tensorflow - 被 `tf.cond`的行为所迷惑

标签 tensorflow

我的图形中需要一个条件控制流。如果predTrue,则图应调用一个操作,该操作会更新变量,然后将其返回,否则它将返回不变的变量。简化的版本是:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

但是,我发现pred=Truepred=False都导致相同的结果y=[2],这意味着当update_x_2未选择tf.cond时,也会调用assign op。怎么解释呢?以及如何解决这个问题?

最佳答案

TL; DR:如果您希望 tf.cond() 在其中一个分支中执行副作用(如赋值),则必须在内部创建执行副作用的操作,该操作将传递给tf.cond()
tf.cond()的行为有点不直观。由于TensorFlow图中的执行在整个图中向前流动,因此在评估条件之前,必须在分支中引用的所有操作都必须执行。这意味着true和false分支都接收对tf.assign() op的控制依赖项,因此y始终设置为2,即使pred is False`。

解决方案是在定义true分支的函数内部创建tf.assign() op。例如,您可以按以下方式组织代码:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]

关于tensorflow - 被 `tf.cond`的行为所迷惑,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37063952/

相关文章:

python - Keras : 'can not import name ' abs' 导入错误

tensorflow - 深度学习回归——巨大的均方误差和损失

python - 张量到变量的切片赋值

python-3.x - 如何在 while 循环中打印张量的值?

random - tensorflow 梯度更新中的确定性?

python - 了解 CNN 超参数

python - 张量板显示语法错误 : can't assign to operator

docker - 当我尝试连接到 docker 镜像时,主机没有运行

machine-learning - 输入队列没有响应 TensorFlow 程序挂起

python - Tensorflow:对数据进行切片并对每个切片应用卷积