python - TensorFlow 传递梯度不变

标签 python numpy tensorflow

假设我在神经网络中使用了一些自定义操作二值化器。该操作采用一个Tensor并构造一个新的Tensor。我想修改该操作,使其仅在前向传递中使用。在向后传递中,计算梯度时,应该只通过到达它的梯度。

更具体地说,二值化器是:

def binarizer(input):
    prob = tf.truediv(tf.add(1.0, input), 2.0)
    bernoulli = tf.contrib.distributions.Bernoulli(p=prob, dtype=tf.float32)
    return 2 * bernoulli.sample() - 1

我设置了我的网络:

# ...

h1_before_my_op = tf.nn.tanh(tf.matmul(x, W) + bias_h1)
h1 = binarizer(h1_before_b)

# ...

loss = tf.reduce_mean(tf.square(y - y_true))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

如何告诉 TensorFlow 在后向传递中跳过梯度计算?

<小时/>

我尝试定义自定义操作,如 this answer 中所述。 ,但是:py_func无法返回Tensor,这不是它的用途 - 我明白:

UnimplementedError (see above for traceback): Unsupported object type Tensor

最佳答案

您正在寻找tf.stop_gradient(input, name=None) :

Stops gradient computation.

When executed in a graph, this op outputs its input tensor as-is.

h1 = binarizer(h1_before_b)
h1 = tf.stop_gradient(h1)

关于python - TensorFlow 传递梯度不变,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40305498/

相关文章:

python - ctypes.pointer、ctypes.POINTER 和 ctypes.byref 之间有什么区别?

python - 在 Python 中从 URL 下载文本

python - Numpy 从 np 数组中删除一个维度

python - reshape 张量是否保留原始张量的特征?

tensorflow - 您可以在我的 GCP VM 上使用 Jupyter notebook 在 Google Cloud 中运行 TPU 训练吗?

python - 如何使用面向对象的 OpenCV

Python:从 strftime 打印时区

python - 如何从 numpy 数组中获取字符串值?

python - 查找一组数据的中位数并应用于该组的成员

machine-learning - 如何创建包含多个与 CIFAR10 格式相同的图像的数据集?