有时我需要修改 pytorch 张量中的一些值。例如,给定一个张量 x,我需要将其正数部分乘以 2,并将其负数部分乘以 3:
import torch
x = torch.randn(1000, requires_grad=True)
x[x>0] = 2 * x[x>0]
x[x<0] = 3 * x[x<0]
y = x.sum()
y.backward()
但是,此类就地操作总是会破坏 autograd 的图表:
Traceback (most recent call last):
File "test_rep.py", line 4, in <module>
x[x>0] = 2 * x[x>0]
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
因此,到目前为止我一直在使用以下解决方法:
import torch
x = torch.randn(1000, requires_grad=True)
y = torch.zeros_like(x, device=x.device)
y[x>0] = 2 * x[x>0]
y[x<0] = 3 * x[x<0]
z = y.sum()
z.backward()
这会导致手动创建新的张量。我想知道是否有更好的方法来做到这一点。
最佳答案
关注怎么样?
import torch
x = torch.randn(1000, requires_grad=True)
x = torch.where(x>0, x*2, x)
x = torch.where(x<0, x*3, x)
y = x.sum()
y.backward()
关于python - 有没有一种好方法可以在保留 autograd 功能的同时修改 pytorch 张量中的某些值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66363241/