python - 强制依赖变量更新

标签 python tensorflow

假设我有一些变量 x 的函数 f:

x = tf.Variable(1.0)
fx = x*x

以及更新x的操作:

new_x = x.assign(2.0)

我想获取更新后的 x 产生的 f 值。我本来以为

with tf.control_dependencies([new_x,]):
    new_fx = tf.identity(fx)    

将强制new_fx依赖于更新new_x,但情况似乎并非如此:

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# prints 1.0, expected 4.0
print "new fx", sess.run(new_fx)

还有其他方法可以定义 fx 的更新值吗?

显然,我可以通过编写类似 new_fx = new_x * new_x 的内容来创建一个新的独立副本,但这会增大图形大小,并且还需要访问 fx 的定义>,我更愿意将其视为黑匣子。

编辑:为了激发这一点,这里是我要编写的代码的草图:

# Hamiltonian Monte Carlo update, simplified
def hmc_step(x, momentum, logpdf, n_steps=50): 
    # x and momentum are Variables
    # logpdf is a Tensor with potentially complicated dependence on x

    grad = tf.gradients(logpdf, x)[0]

    # initial position        
    new_x = x

    for i in range(n_steps):
        # update position
        new_x = x.assign(new_x + momentum)

        # update momentum using gradient at *current* position
        with tf.control_dependencies([new_x]):
             momentum = momentum + grad # DOESN'T WORK

        # DOES WORK BUT IS UGLY
        # new_logpdf = define_logpdf(new_x)
        # new_grad = tf.gradients(new_logpdf, new_x)[0]
        # momentum = momentum + new_grad

    # (do some stuff to accept/reject the new x)
    # ....

    return new_x

每次循环定义一个新的 logpdf 副本并重新导出梯度感觉真的很不优雅:它需要访问 Define_logpdf() 并将图形大小放大 50 倍。有没有更好的方法来做到这一点(除了 theano.scan 的等价物)?

最佳答案

with tf.control_dependency([op]) block 强制对 op 对 with block 内创建的其他操作的控制依赖。在您的情况下,x*x是在外部创建的,而tf.identity只是获取旧值。这就是您想要的:

with tf.control_dependencies([new_x,]):
  new_fx = x*x

关于python - 强制依赖变量更新,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35693687/

相关文章:

python - 如何在图像中找到签名?

python - 如何打印列表而不是列表列表

python - 我在理解 round() 函数时缺少什么?

tensorflow - 操作链的自定义渐变

python - 如何使用 TF1.3 中的新数据集 api 映射具有附加参数的函数?

python - Flask-Security 无法正确初始化

python - 使用 numpy/scipy 将 3D 点的最近点投影到 3D 三角形

neural-network - tensorflow.equal() op 上的不兼容形状用于正确的预测评估

python - 从 Keras 检查点加载

python - 为什么我们将图像归一化为mean=0.5, std=0.5?