假设我有一个 3D 数组:
>>> a
array([[[7, 0],
[3, 6]],
[[2, 4],
[5, 1]]])
我可以使用 axis=1
获取它的 argmax
>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
[1, 0]])
如何使用 m
作为 a
的索引,以便结果等同于简单地使用 max
?
>>> a.max(axis=1)
array([[7, 6],
[5, 4]])
(这在m
应用于其他相同形状的数组时很有用)
您可以使用 advanced indexing 执行此操作和 numpy broadcasting :
m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]
#array([[7, 6],
# [5, 4]])
m = np.argmax(a, axis=1)
创建第一、第二和第三维索引的数组:
ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])
因为维度不匹配,三个数组会广播,导致每个数组如下:
for x in np.broadcast_arrays(ind1, ind2, ind3):
print(x, '\n')
#[[0 0]
# [1 1]]
#[[0 1]
# [1 0]]
#[[0 1]
# [0 1]]
并且由于所有索引都是整数数组,它会触发 advanced indexing ,因此索引为 (0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 0, 1)
的元素被拾取,即一个元素来自每个组合的数组作为索引。