当我运行以下代码时,
import numpy as np
a = np.ones(3)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)
a 和 b 都是 2。
但是,当我运行时:
import numpy as np
a = np.ones(3)
b = torch.from_numpy(a)
a = a+1
print(a)
print(b)
b 仍为 1s,而 a 已更新为 2s。
这是预期的行为吗?
最佳答案
是的,如@hpaulj pointed out in his comment 、操作
a = a + 1
创建原始数组a
的副本并使用broadcasting加1 。并且在加法之后,由于我们将其分配给a
,因此a
被更新为加法运算的结果。但是,b
仍然共享原始数组a
的内存(即更新之前创建的数组a
。)
所以,我们看到的结果是这样的:
In [75]: a = np.ones(3)
...: b = torch.from_numpy(a)
...: a = a+1 # <========= creates copy of `a` and modifies it
...: print(a)
...: print(b)
...:
[ 2. 2. 2.]
1
1
1
[torch.DoubleTensor of size 3]
但是,看看当您愿意这样做时会发生什么:
In [72]: a = np.ones(3)
...: b = torch.from_numpy(a)
...: a += 1 # <========== in-place modification of `a`
...: print(a)
...: print(b)
...:
[ 2. 2. 2.]
2
2
2
[torch.DoubleTensor of size 3]
观察 +=
操作如何就地对原始数组进行修改,而 somearr = somearr + 1
创建数组的副本 somearray
,然后对其进行修改。
关于python - Pytorch:更新 numpy 数组而不更新相应的张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48370286/