python - 当使用 .clamp 而不是 torch.relu 时,Pytorch Autograd 会给出不同的渐变

标签 python pytorch backpropagation autograd relu

我仍在努力理解 PyTorch autograd 系统。我正在努力的一件事是理解为什么 .clamp(min=0)nn.functional.relu()似乎有不同的向后传球。

尤其令人困惑的是 .clamp相当于 relu在 PyTorch 教程中,例如 https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-nn .

我在分析具有一个隐藏层和一个 relu 激活(输出层中的线性)的简单全连接网络的梯度时发现了这一点。

据我了解,以下代码的输出应该只是零。我希望有人能告诉我我缺少什么。

import torch
dtype = torch.float

x = torch.tensor([[3,2,1],
                  [1,0,2],
                  [4,1,2],
                  [0,0,1]], dtype=dtype)

y = torch.ones(4,4)

w1_a = torch.tensor([[1,2],
                     [0,1],
                     [4,0]], dtype=dtype, requires_grad=True)
w1_b = w1_a.clone().detach()
w1_b.requires_grad = True



w2_a = torch.tensor([[-1, 1],
                     [-2, 3]], dtype=dtype, requires_grad=True)
w2_b = w2_a.clone().detach()
w2_b.requires_grad = True


y_hat_a = torch.nn.functional.relu(x.mm(w1_a)).mm(w2_a)
y_a = torch.ones_like(y_hat_a)
y_hat_b = x.mm(w1_b).clamp(min=0).mm(w2_b)
y_b = torch.ones_like(y_hat_b)

loss_a = (y_hat_a - y_a).pow(2).sum()
loss_b = (y_hat_b - y_b).pow(2).sum()

loss_a.backward()
loss_b.backward()

print(w1_a.grad - w1_b.grad)
print(w2_a.grad - w2_b.grad)

# OUT:
# tensor([[  0.,   0.],
#         [  0.,   0.],
#         [  0., -38.]])
# tensor([[0., 0.],
#         [0., 0.]])
# 

最佳答案

原因是clamprelu0 处产生不同的梯度.用标量张量检查 x = 0两个版本:(x.clamp(min=0) - 1.0).pow(2).backward()(relu(x) - 1.0).pow(2).backward() .由此产生的 x.grad0relu版本但它是 -2clamp版本。这意味着 relu选择 x == 0 --> grad = 0clamp选择 x == 0 --> grad = 1 .

关于python - 当使用 .clamp 而不是 torch.relu 时,Pytorch Autograd 会给出不同的渐变,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60618346/

相关文章:

python - python 中的 Johansen 协整检验

编写多个正则表达式测试的 Pythonic 方式

Pytorch:为什么在 nn.modules.loss 和 nn.functional 模块中都实现了损失函数?

parallel-processing - Pytorch softmax沿着不同的掩码没有for循环

python - pytorch:来自两个网络时损失的表现如何?

python - opencv 在 osx 上安装没有合适的图像错误

python - 在 Python Google App Engine 中,如何模拟或子类化 File 类,以便编写用于访问文件的软件不会抛出异常?

Python 神经网络 : 'numpy.ndarray' object has no attribute 'dim'

python - 反向传播中 sigmoid 导数输入的困惑

matlab - 误差反向传播 - 神经网络