因此,我尝试使用 numpy.ma.where
为我创建一个数组,就像 numpy.where
函数一样。 where
函数广播我的列数组,然后用零替换一些元素。我得到以下信息:
>>> import numpy
>>> condition = numpy.array([True,False, True, True, False, True]).reshape((3,2))
>>> print (condition)
[[ True False]
[ True True]
[False True]]
>>> broadcast_column = numpy.array([1,2,3]).reshape((-1,1)) # Column to be broadcast
>>> print (broadcast_column)
[[1]
[2]
[3]]
>>> numpy.where(condition, broadcast_column, 0) \
... # Yields the expected output, column is broadcast then condition applied
array([[1, 0],
[2, 2],
[0, 3]])
>>> numpy.ma.where(condition, broadcast_column, 0).data \
... # using the ma.where function yields a *different* array! Why?
array([[1, 0],
[3, 1],
[0, 3]], dtype=int32)
>>> numpy.ma.where(condition, broadcast_column.repeat(2,axis=1), 0).data \
... # The problem doesn't occur if broadcasting isnt used
array([[1, 0],
[2, 2],
[0, 3]], dtype=int32)
非常感谢您的帮助!
我的numpy版本是1.6.2
最佳答案
np.ma.where 的核心是这样的语句: (在 Ubuntu 上,请参阅/usr/share/pyshared/numpy/ma/core.py)
np.putmask(_data, fc, xv.astype(ndtype))
_data
是要返回的屏蔽数组中的数据。
fc
是 bool 数组,当条件为 True 时,该数组为 True。
xv.astype(ndtype)
是要插入的值,例如广播_列
。
In [90]: d = np.empty(fc.shape, dtype=ndtype).view(np.ma.MaskedArray)
In [91]: _data = d._data
In [92]: _data
Out[92]:
array([[5772360, 5772360],
[ 0, 17],
[5772344, 5772344]])
In [93]: fc
Out[93]:
array([[ True, False],
[ True, True],
[False, True]], dtype=bool)
In [94]: xv.astype(ndtype)
Out[94]:
array([[1],
[2],
[3]])
In [95]: np.putmask(_data, fc, xv.astype(ndtype))
In [96]: _data
Out[96]:
array([[ 1, 5772360],
[ 3, 1],
[5772344, 3]])
注意数组中间行的 3 和 1。
问题是 np.putmask
不广播值,而是重复它们:
来自 np.putmask
的文档字符串:
putmask(a, mask, values)
Sets
a.flat[n] = values[n]
for each n wheremask.flat[n]==True
.If
values
is not the same size asa
andmask
then it will repeat. This gives behavior different froma[mask] = values
.
当您显式广播时,flat
返回所需的扁平值:
In [97]: list(broadcast_column.repeat(2,axis=1).flat)
Out[97]: [1, 1, 2, 2, 3, 3]
但是如果你不广播,
In [99]: list(broadcast_column.flat) + list(broadcast_column.flat)
Out[99]: [1, 2, 3, 1, 2, 3]
正确的值不在所需的位置。
PS。在最新版本的 numpy 中,the code reads
np.copyto(_data, xv.astype(ndtype), where=fc)
我不确定这会对行为产生什么影响;我没有足够新的 numpy 版本来测试。
关于arrays - Numpy "ma.where"与 "where"具有不同的行为...我怎样才能使其行为相同?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/13099350/