我的图形中需要一个条件控制流。如果pred
是True
,则图应调用一个操作,该操作会更新变量,然后将其返回,否则它将返回不变的变量。简化的版本是:
pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
with tf.control_dependencies([assign_x_2]):
return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
session.run(tf.initialize_all_variables())
print(y.eval())
但是,我发现
pred=True
和pred=False
都导致相同的结果y=[2]
,这意味着当update_x_2
未选择tf.cond
时,也会调用assign op。怎么解释呢?以及如何解决这个问题?
最佳答案
TL; DR:如果您希望 tf.cond()
在其中一个分支中执行副作用(如赋值),则必须在内部创建执行副作用的操作,该操作将传递给tf.cond()
。tf.cond()
的行为有点不直观。由于TensorFlow图中的执行在整个图中向前流动,因此在评估条件之前,必须在或分支中引用的所有操作都必须执行。这意味着true和false分支都接收对tf.assign()
op的控制依赖项,因此y
始终设置为2
,即使pred is
False`。
解决方案是在定义true分支的函数内部创建tf.assign()
op。例如,您可以按以下方式组织代码:
pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
with tf.control_dependencies([tf.assign(x, [2])]):
return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
session.run(tf.initialize_all_variables())
print(y.eval(feed_dict={pred: False})) # ==> [1]
print(y.eval(feed_dict={pred: True})) # ==> [2]
关于tensorflow - 被 `tf.cond`的行为所迷惑,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37063952/