python - numpy.where 是如何工作的?

标签 python numpy where-clause

我可以理解以下 numpy 行为。

>>> a
array([[ 0. ,  0. ,  0. ],
       [ 0. ,  0.7,  0. ],
       [ 0. ,  0.3,  0.5],
       [ 0.6,  0. ,  0.8],
       [ 0.7,  0. ,  0. ]])
>>> argmax_overlaps = a.argmax(axis=1)
>>> argmax_overlaps
array([0, 1, 2, 2, 0])
>>> max_overlaps = a[np.arange(5),argmax_overlaps]
>>> max_overlaps
array([ 0. ,  0.7,  0.5,  0.8,  0.7])
>>> gt_argmax_overlaps = a.argmax(axis=0)
>>> gt_argmax_overlaps
array([4, 1, 3])
>>> gt_max_overlaps = a[gt_argmax_overlaps,np.arange(a.shape[1])]
>>> gt_max_overlaps
array([ 0.7,  0.7,  0.8])
>>> gt_argmax_overlaps = np.where(a == gt_max_overlaps)
>>> gt_argmax_overlaps
(array([1, 3, 4]), array([1, 2, 0]))

我知道 0.7、0.7 和 0.8 是 a[1,1]、a[3,2] 和 a[4,0] 所以我得到元组 (array[1,3,4] 和 array [1,2,0]) 每个数组由这三个元素的第 0 和第 1 个索引组成。然后我尝试了其他示例以查看我的理解是否正确。

>>> np.where(a == [0.3])
(array([2]), array([1]))

0.3 在 a[2,1] 中,所以结果看起来和我预期的一样。然后我试了一下

>>> np.where(a == [0.3, 0.5])
(array([], dtype=int64),)

??我希望看到 (array([2,2]),array([2,3]))。为什么我会看到上面的输出?

>>> np.where(a == [0.7, 0.7, 0.8])
(array([1, 3, 4]), array([1, 2, 0]))
>>> np.where(a == [0.8,0.7,0.7])
(array([1]), array([1]))

第二个结果我也看不懂。有人可以向我解释一下吗?谢谢。

最佳答案

首先要意识到的是 np.where(a == [whatever]) 只是向您显示 a == [whatever] 为 True 的索引.因此,您可以通过查看 a == [whatever] 的值来获得提示。在你的情况下“有效”:

>>> a == [0.7, 0.7, 0.8]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False,  True],
       [ True, False, False]], dtype=bool)

你并没有得到你认为的那样。您认为这是分别请求每个元素的索引,但它获取的是值匹配的位置在行中的相同位置。这个比较基本上是在说“对于每一行,告诉我第一个元素是否为 0.7,第二个元素是否为 0.7,第三个元素是否为 0.8”。然后它返回那些匹配位置的索引。换句话说,比较是在整行之间进行的,而不仅仅是单个值。对于你的最后一个例子:

>>> a == [0.8,0.7,0.7]
array([[False, False, False],
       [False,  True, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

您现在得到不同的结果。它不要求“a 值为 0.8 的索引”,它只要求行开头有 0.8 的索引——同样还有 a 0.7 在后两个位置中的任何一个。

只有当您比较的值与 a 的单行形状匹配时,才能进行这种类型的逐行比较。因此,当您尝试使用双元素列表时,它会返回一个空集,因为它试图将列表作为标量值与数组中的各个值进行比较。

结果是您不能在值列表上使用 == 并期望它只告诉您任何值出现的位置。相等性将按值 和位置 进行匹配(如果您比较的值与数组的一行形状相同),或者它将尝试将整个列表作为标量进行比较(如果形状不匹配)。如果您想独立搜索值,则需要执行 Khris 在评论中建议的操作:

np.where((a==0.3)|(a==0.5))

也就是说,您需要对单独的值进行两次(或更多次)单独比较,而不是对值列表进行一次比较。

关于python - numpy.where 是如何工作的?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41973222/

相关文章:

mysql - OR 和 AND 与 mysql 中 where 的区别

python - 如何限制 python 中的输入,使其只允许 0 或 1?

python - 如何在 Mac 上将 openCV 安装到 Enthought python 发行版中

python - 将 CSV 文件转换为 pandas 的 'flat file"

wpf - 在 WPF 应用程序中何处放置和配置 IoC 容器?

mysql - 在 where 子句上查找空值时查询不起作用

python - 无法在 while 循环中从列表中排除特定范围内的项目

python - 如何一键运行Python脚本?

python - 在 np.select 中使用字符串条件时出现问题

python - 将解决方案应用于实际数据时结果不正确