python - RNN 的 tf.clip_by_value 和 tf.clip_by_global_norm 之间的区别以及如何确定要剪辑的最大值?

标签 python tensorflow deep-learning

想了解 TensorFlow 中梯度裁剪实现过程中 tf.clip_by_valuetf.clip_by_global_norm 的作用差异。首选哪一个以及如何确定要剪辑的最大值?

最佳答案

TL;DR:使用tf.clip_by_global_norm进行渐变裁剪。

按值剪辑

tf.clip_by_value 剪辑一个张量内的每个值,无论张量中的其他值如何。例如,

tf.clip_by_value([-1, 2, 10], 0, 3)  -> [0, 2, 3]  # Only the values below 0 or above 3 are changed

因此,它可以改变张量的方向,因此如果张量中的值相互去相关(梯度裁剪不是这种情况),或者避免张量中的零/无限值,则应该使用它。张量可能会在其他地方导致 Nan/无限值(例如通过使用最小值 epsilon=1e-8 和非常大的最大值进行裁剪)。

按规范剪辑

tf.clip_by_norm 如有必要,会重新调整一个张量,使其 L2 范数不超过某一阈值。通常,避免一个张量上的梯度爆炸很有用,因为您可以保持梯度方向。例如:

tf.clip_by_norm([-2, 3, 6], 5)  -> [-2, 3, 6]*5/7  # The original L2 norm is 7, which is >5, so the final one is 5
tf.clip_by_norm([-2, 3, 6], 9)  -> [-2, 3, 6]  # The original L2 norm is 7, which is <9, so it is left unchanged

但是,clip_by_norm 仅适用于一个梯度,因此,如果您在所有梯度张量上使用它,您将使它们不平衡(有些将被重新缩放,其他则不会,并且并非所有梯度都具有相同的缩放比例)规模)。

请注意,前两个仅适用于一个张量,而最后一个则适用于张量列表。

clip_by_global_norm

tf.clip_by_global_norm 重新缩放张量列表,以便所有其范数的向量的总范数不超过阈值。目标与clip_by_norm相同(避免梯度爆炸,保持梯度方向),但它同时作用于所有梯度,而不是单独作用于每个梯度(也就是说,所有梯度都通过如有必要,使用相同的因子,或者不重新调整它们)。这更好,因为保持了不同梯度之间的平衡。

例如:

tf.clip_by_global_norm([tf.constant([-2, 3, 6]),tf.constant([-4, 6, 12])] , 14.5)

会将两个张量重新缩放一个因子 14.5/sqrt(49 + 196),因为第一个张量的 L2 范数为 7,第二个张量的 L2 范数为 14,而 sqrt(7^ 2+ 14^2)>14.5

这个 (tf.clip_by_global_norm) 是您应该用于渐变裁剪的那个。请参阅this例如了解更多信息。

选择值

选择最大值是最难的部分。您应该使用最大值,这样就不会出现梯度爆炸(其影响可能是张量中出现的 Nan无限 值,在训练步骤很少)。 tf.clip_by_global_norm 的值应该比其他值更大,因为由于隐含的张量数量,全局 L2 范数在机械上会比其他范数更大。

关于python - RNN 的 tf.clip_by_value 和 tf.clip_by_global_norm 之间的区别以及如何确定要剪辑的最大值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44796793/

相关文章:

python - ABC ----> ABC、CAB、BCA 仅顺序重要但不是真正的 itertools 排列

python - 限制 MNIST 训练数据的大小

python - 如何将输入 dim 从 fit 方法传递到 skorch 包装器?

machine-learning - tensorflow word2vec 示例中权重和偏差的目的是什么?

python - 如何在Python中生成具有指定均值、方差、偏度、峰度的数据?

python - Python 中的 Formlets 实现

gpu - 为 Tensorflow 推荐的 GPU

python - 如何在 Unet 架构 PyTorch 中处理奇数分辨率

python - 初始化和销毁​​ Python 多处理 worker

python-3.x - tensorflow gpu docker 图像中的 Python 3.6