python - 在全局上下文中使用一个 GradientTape

标签 python tensorflow tensorflow2.0

我想使用GradientTape在急切执行模式下观察梯度。是否可以创建一个 GradientTape 一次,然后记录所有内容,就好像它具有全局上下文一样?

这是我想做的一个例子:

import numpy as np
import tensorflow as tf

x = tf.Variable(np.ones((2,)))
y=2*x
z=2*y
tf.gradients(z, x) # RuntimeError, not supported in eager execution

现在,这个问题可以轻松解决:

with tf.GradientTape() as g:
    y = 2*x
    z = 2*y
    
g.gradient(y, x) # this works

但问题是我经常没有紧接着彼此的 y 和 z 的定义。例如,如果代码在 Jupyter Notebook 中执行并且它们位于不同的单元格中怎么办?

我可以定义一个 GradientTape 来监视全局的所有内容吗?

最佳答案

我找到了这个解决方法:

import numpy as np
import tensorflow as tf

# persistent is not necessary for g to work globally
# it only means that gradients can be computed more than once,
# which is important for the interactive jupyter notebook use-case
g = tf.GradientTape(persistent=True)

# this is the workaround
g.__enter__()

# you can execute this anywhere, also splitted into separate cells
x = tf.Variable(np.ones((2,)))
y = 2*x
z = 2*y

g.gradient(z, x)

关于python - 在全局上下文中使用一个 GradientTape,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58612362/

相关文章:

python-3.x - tensorflow api 2.0 张量对象仅在启用急切执行时才可迭代。要迭代此张量,请使用 tf.map_fn

python - tensorflow: "tf.gfile.GFile"中有问题, "' utf- 8' codec can' t 解码...”

Python,在 smtplib 参数中添加逗号会导致错误

tensorflow - tf.compat.v1 和 tf.compat.v2 之间的别名和区别是什么?

python - 纯 Tensorflow 中的 Gram-Schmidt 正交化 : performance for iterative solution is much slower than numpy

python - 计算损失时检查标签( tensorflow )

python - 关于 tf.function 的跟踪是什么

python - 如何从 pubkey_hash 获取比特币地址?

python - 我可以使用 IMDbPY 检索 IMDb 对给定电影的电影推荐吗?

tensorflow - 修改张量值