python - 有没有一种好方法可以在保留 autograd 功能的同时修改 pytorch 张量中的某些值?

标签 python pytorch tensor

有时我需要修改 pytorch 张量中的一些值。例如,给定一个张量 x,我需要将其正数部分乘以 2,并将其负数部分乘以 3:

import torch

x = torch.randn(1000, requires_grad=True)
x[x>0] = 2 * x[x>0]
x[x<0] = 3 * x[x<0]

y = x.sum()
y.backward()

但是,此类就地操作总是会破坏 autograd 的图表:

Traceback (most recent call last):
  File "test_rep.py", line 4, in <module>
    x[x>0] = 2 * x[x>0]
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

因此,到目前为止我一直在使用以下解决方法:

import torch

x = torch.randn(1000, requires_grad=True)
y = torch.zeros_like(x, device=x.device)

y[x>0] = 2 * x[x>0]
y[x<0] = 3 * x[x<0]

z = y.sum()
z.backward()

这会导致手动创建新的张量。我想知道是否有更好的方法来做到这一点。

最佳答案

关注怎么样?

import torch

x = torch.randn(1000, requires_grad=True)
x = torch.where(x>0, x*2, x)
x = torch.where(x<0, x*3, x)

y = x.sum()
y.backward()

关于python - 有没有一种好方法可以在保留 autograd 功能的同时修改 pytorch 张量中的某些值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66363241/

相关文章:

python - 如何将 txt 读取为没有符号的数据框

python - 子进程在 Linux 上不起作用

pytorch - 为什么 nn.CrossEntropyLoss 在我验证输入为非 0 维时抛出 "TypeError: iteration over a 0-d tensor"?

lstm - 如何正确地为 PyTorch 中的嵌入、LSTM 和线性层提供输入?

python - Python 输入与 input_signature 不兼容是什么意思

python - Tkinter Canvas 不在无限循环中更新

python - 有没有办法让这段代码更简洁?

pytorch - 在 pytorch 中使用数据加载器进行替换采样

python - 如何将整数的pytorch张量转换为 boolean 值的张量?

python - 如何加速将tensorflow_datasets中的张量转换为numpy数组的代码?