numpy - Numpy `where` 子句的奇怪行为

标签 numpy pythonanywhere array-broadcasting

我发现 ufunc where 有奇怪的行为Numpy 1.15.3 的条款.

In [1]: import numpy as np

In [2]: x = np.array([[1,2],[3,4]])

In [3]: y = np.ones(x.shape) * 2

In [4]: print(x, "\n", y)
[[1 2]
 [3 4]]
 [[2. 2.]
 [2. 2.]]

In [5]: np.add(x, y, where=x==3)
Out[5]:
array([[2., 2.],     #<=========== where do these 2s come from???
       [5., 2.]])

In [6]: np.add(x, y, where=x==3, out=np.zeros(x.shape))
Out[6]:
array([[0., 0.],
       [5., 0.]])

In [7]: np.add(x, y, where=x==3, out=np.ones(x.shape))
Out[7]:
array([[1., 1.],
       [5., 1.]])

In [8]: np.add(x, y, where=x==3)
Out[8]:
array([[1., 1.], # <========= it seems these 1s are remembered from last computation.
       [5., 1.]])

添加1

看来我只能用 out 得到合理的结果参数。

下面没有 out参数:

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

x = np.linspace(-2,2,60)
y = np.linspace(-2,2,60)

xx, yy = np.meshgrid(x,y)

r= np.ones((60,60), dtype=float) * 2
z = np.sqrt(r**2 - xx**2 - yy**2, where=(r**2 - xx**2 - yy**2)>=0) # <==== HERE!!

surf = ax.plot_surface(xx, yy, z, cmap="viridis")

这会生成一个荒谬的图像:

enter image description here

如果我添加 out参数如下,一切正常。

z = np.zeros(xx.shape)
np.sqrt(r**2 - xx**2 - yy**2, where=(r**2 - xx**2 - yy**2)>=0, out=z)

enter image description here

最佳答案

由于使用 where,您的输出中最终会出现垃圾数据。正如您所说,修复方法是初始化您自己的输出并将其传递给 out

来自 docs about the out arg :

If ‘out’ is None (the default), a uninitialized return array is created. The output array is then filled with the results of the ufunc in the places that the broadcast ‘where’ is True. If ‘where’ is the scalar True (the default), then this corresponds to the entire output being filled. Note that outputs not explicitly filled are left with their uninitialized values.

因此,您跳过的 out 值(即 whereFalse 的索引)将保留为之前的值在他们之前。这就是为什么 numpy 看起来“记住”了之前计算中的值,例如第一个示例代码块末尾的 1

正如 @WarrenWeckesser 在他的评论中指出的那样,这也意味着当 out 留空时,同一内存块将被重新用于输出,至少在某些情况下是这样。有趣的是,您可以通过将每个输出分配给变量来更改获得的结果:

x = np.array([[1,2],[3,4]])
y = np.ones(x.shape) * 2

arr0 = np.add(x, y, where=x==3)
arr1 = np.add(x, y, where=x==3, out=np.zeros(x.shape))
arr2 = np.add(x, y, where=x==3, out=np.ones(x.shape))
arr3 = np.add(x, y, where=x==3)
print(arr3)

现在您可以清楚地看到输出中的垃圾数据:

[[-2.68156159e+154 -2.68156159e+154]
 [ 5.00000000e+000  2.82470645e-309]]

关于numpy - Numpy `where` 子句的奇怪行为,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53312747/

相关文章:

python - 使用 numpy 方法计算核矩阵

python - 为什么 Numpy 中的数组索引会产生这个结果?

python - 在 pythonanywhere 的 linux 云上对 pymssql 与 Azure mssql 的连接进行故障排除

javascript - 在 PythonAnywhere 上托管基于 Tornado 的 Python 应用程序时出现 "Error running WSGI application"

mysql - 在 python 中使用 pythonanywhere 的 MySQL 数据库

python - Numpy 3d 数组索引

python-3.x - 如何将使用cv2.imread ('img.png',cv2.IMREAD_UNCHANGED)读取的图像转换为cv2.imread ('img.png',cv2.IMREAD_COLOR)的格式

python - 为什么我必须使用 np.string_?为什么我不能使用没有下划线的 np.string?

python - 运行我的系统时具有相同的输出

python - 需要有效的方法将较小的 Numpy 数组广播到较大的数组中