python - Numpy:如何使用 argmax 结果来获得实际最大值?

标签 python numpy

<分区>

假设我有一个 3D 数组:

>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])

我可以使用 axis=1 获取它的 argmax

>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])

如何使用 m 作为 a 的索引,以便结果等同于简单地使用 max

>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])

(这在m应用于其他相同形状的数组时很有用)

最佳答案

您可以使用 advanced indexing 执行此操作和 numpy broadcasting :

m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]

#array([[7, 6],
#       [5, 4]])

m = np.argmax(a, axis=1)

创建第一、第二和第三维索引的数组:

ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])
​

因为维度不匹配,三个数组会广播,导致每个数组如下:

for x in np.broadcast_arrays(ind1, ind2, ind3):
    print(x, '\n')

#[[0 0]
# [1 1]] 

#[[0 1]
# [1 0]] 

#[[0 1]
# [0 1]] 

并且由于所有索引都是整数数组,它会触发 advanced indexing ,因此索引为 (0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 0, 1) 的元素被拾取,即一个元素来自每个组合的数组作为索引。

关于python - Numpy:如何使用 argmax 结果来获得实际最大值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46840848/

相关文章:

python - 为什么这个正则表达式会经历灾难性的回溯?

python - PyQt:setStyleSheet url() 中的特殊字符

python - 在没有 root 访问权限的 Unix 服务器上安装 pip 和 numpy

python - 根据另一个列表或数组查找数组/列表中元素的索引

python - 获取 numpy 二维数组中包含非屏蔽值的第一行和最后一行和列的索引

python - numpy 掩码数组的高效内存使用

python - ApiTemporaryUnavailableError

python - 张量板中 enqueue_input 和 Globe_step/sec 的含义是什么以及如何解释该数字?

python - 剥离 numpy 数组中的空格

python - 如何获取元组列表并将元组从字符串更改为整数