python - Pytorch:更新 numpy 数组而不更新相应的张量

标签 python numpy deep-learning pytorch tensor

当我运行以下代码时,

import numpy as np
a = np.ones(3)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)

a 和 b 都是 2。

但是,当我运行时:

import numpy as np
a = np.ones(3)
b = torch.from_numpy(a)
a = a+1
print(a)
print(b)

b 仍为 1s,而 a 已更新为 2s。

这是预期的行为吗?

最佳答案

是的,如@hpaulj pointed out in his comment 、操作

a = a + 1

创建原始数组a的副本并使用broadcasting加1 。并且在加法之后,由于我们将其分配给a,因此a被更新为加法运算的结果。但是,b仍然共享原始数组a的内存(即更新之前创建的数组a。)

所以,我们看到的结果是这样的:

In [75]: a = np.ones(3)
    ...: b = torch.from_numpy(a)
    ...: a = a+1     # <========= creates copy of `a` and modifies it
    ...: print(a)
    ...: print(b)
    ...: 
[ 2.  2.  2.]

 1
 1
 1
[torch.DoubleTensor of size 3]

但是,看看当您愿意这样做时会发生什么:

In [72]: a = np.ones(3)
    ...: b = torch.from_numpy(a)
    ...: a += 1      # <========== in-place modification of `a`
    ...: print(a)
    ...: print(b)
    ...:

[ 2.  2.  2.]

 2
 2
 2
[torch.DoubleTensor of size 3]

观察 += 操作如何就地对原始数组进行修改,而 somearr = somearr + 1 创建数组的副本 somearray,然后对其进行修改。

关于python - Pytorch:更新 numpy 数组而不更新相应的张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48370286/

相关文章:

python - keras中的Flatten()和GlobalAveragePooling2D()有什么区别

python - 有没有更有效的方法将多行字符串转换为 numpy 数组?

python - 使用 numpy 广播/矢量化从其他数组构建新数组

python - 如何将时区偏移量添加到 pandas datetime?

python - 在预测期间,数据规范化如何在 keras 中工作?

python - openCV:使用findContours检测圆

python - TensorFlow:带轴选项的 bincount

machine-learning - 将长一维矢量数据、一维矢量标签输入 Caffe

python - 基准测试 : does python have a faster way of walking a network folder?

python - 无法让 INSERT 在 SPARQLWrapper 中工作