python - Tensorflow 的 gradient_override_map 函数

标签 python tensorflow

谁能给我解释一下 TensorFlow 中的 gradient_override_map 函数? 我无法准确理解它的用法。

我看到的代码用法是:

with G.gradient_override_map({"Floor": "Identity"}):
    return tf.reduce_mean(SomeVals) * SomeOtherVal

这里究竟发生了什么?什么是身份

最佳答案

“Floor”和“Identity”都是操作类型串,前者对应tf.floor,后者对应tf.identity所以你的代码的功能,我猜,是用tf.identity的back-propagated gradient(简称BPG)计算机制代替tf.floor的BPG计算机制 图 G 中的操作,同时传递 tf.reduce_mean 的正向输出。这似乎有点奇怪,因为在我迄今为止发现的 gradient_override_map 的所有应用程序中,op_type_map 的键始终与用于在上下文中生成输出的操作的类型字符串相同。我的意思是我更熟悉返回 tf.floor(SomeVals) 的场景,而不是 tf.reduce_mean(SomeVals)

gradient_override_map({op_A_type: op_B_type})的作用是将op_A的BPG计算机制替换为op_B,同时保留op_A_type的前向传播计算机制。 lahwran 的回答中显示了 gradient_override_map 的常见应用。

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity")

通过

@tf.RegisterGradient("CustomGrad")
def _const_mul_grad(unused_op, grad):
    return 5.0 * grad

装饰器 tf.RegisterGradient("CustomGrad") 为自定义操作类型注册由 _const_mul_grad(unused_op, grad) 定义的梯度函数 -- "CustomGrad",

同时

g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "CustomGrad"}):
    output = tf.identity(input, name="Identity") 

确保字符串类型为“Identity”(tf.identity) 的所有操作(在图 g 中)的输出保持原样,而 tf.identity 的 BPG 计算机制s 替换为字符串类型“CustomGrad”的BPG计算操作机制。

附言

  1. 操作的类型字符串对应于定义操作的原型(prototype)的 OpDef.name 字段。要查找操作的 OpDef.name ,请引用 MingXing 在 this question 下的回答。

  2. 没有必要声明tf.identity 操作的名称,因为 tf.identity 中的 arg 'name' 是可选的。

关于python - Tensorflow 的 gradient_override_map 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41391718/

相关文章:

python - 如何替换pandas中的部分数据框

python - 如何根据条件替换 Panda 数据框列中的单元格

python - Clojure 中的 map 是有序的吗?

python - 如何动态更新 tf.ones_like() 的形状?

python - 尝试写入文件时出现 "FileNotFoundError: [Errno 2] No such file or directory"

python - 更改类对象中的字典键

python - 使用 Keras : All layer names should be unique for discriminator 在 GPU 上训练 GAN

python - 如何将参数传递给 Tensorflow 中 tf.cond 中的函数?

tensorflow - tensorflow 中数据集管道中的高斯模糊图像

python - 在 tf.nn.top_k 中加入 torch.topk 的 dim 参数