我编写了一个脚本,用于评估 arr
的某些条目是否在 check_elements
中。我的方法不比较单个条目,而是比较arr
内的整个向量。因此,脚本检查 [8, 3]
、[4, 5]
, ... 是否在 check_elements
中。
这是一个例子:
import numpy as np
# arr.shape -> (2, 3, 2)
arr = np.array([[[8, 3],
[4, 5],
[6, 2]],
[[9, 0],
[1, 10],
[7, 11]]])
# check_elements.shape -> (3, 2)
# generally: (n, 2)
check_elements = np.array([[4, 5], [9, 0], [7, 11]])
# rslt.shape -> (2, 3)
rslt = np.zeros((arr.shape[0], arr.shape[1]), dtype=np.bool)
for i, j in np.ndindex((arr.shape[0], arr.shape[1])):
if arr[i, j] in check_elements: # <-- condition is checked against
# the whole last dimension
rslt[i, j] = True
else:
rslt[i, j] = False
现在:
print(rslt)
...将打印:
[[False True False]
[ True False True]]
为了获取我使用的索引:
print(np.transpose(np.nonzero(rslt)))
...打印以下内容:
[[0 1] # arr[0, 1] -> [4, 5] -> is in check_elements
[1 0] # arr[1, 0] -> [9, 0] -> is in check_elements
[1 2]] # arr[1, 2] -> [7, 11] -> is in check_elements
如果我检查单个值的条件(例如 arr > 3
或 np.where(...)
),此任务将变得简单且高效,但我我对单一值(value)观不感兴趣。我想检查整个最后一个维度(或其片段)的条件。
我的问题是:是否有更快的方法来达到相同的结果?我是否正确,矢量化尝试和诸如 np.where 之类的东西不能用于解决我的问题,因为它们总是对单个值进行操作,而不是对整个维度或切片进行操作那个维度?
最佳答案
这是使用 broadcasting 的 Numpythonic 方法:
>>> (check_elements == arr[:,:,None]).reshape(2, 3, 6).any(axis=2)
array([[False, True, False],
[ True, False, True]], dtype=bool)
关于python - Numpy ndarray 中的成员资格检查,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39164636/