我有一个形状为(n, m, s)
的矩阵A
。在第 0 个轴的每个位置,我需要对应于 (m, s)
形数组中最大值的位置。
例如:
np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])
A[0]
是:
array([[5, 8, 9],
[5, 0, 0],
[1, 7, 6]])
我想获取(0, 2)
,即这里9
的位置。
我愿意做
aa = A.argmax()
,这样 aa.shape = (10, 2)
,并且 aa[0] = [0, 2]
我怎样才能做到这一点?
最佳答案
运行 argmax
通常要快得多数组的部分扁平化版本而不是使用理解:
ind = A.reshape(A.shape[0], -1).argmax(axis=-1)
现在可以申请unravel_index
获取二维索引:
coord = np.unravel_index(ind, A.shape[1:])
coord
是一对数组,每个轴一个。您可以使用常见的 Python 习惯用法转换为元组列表:
result = list(zip(*coord))
更好的方法是将数组堆叠在一起形成一个 Nx2 数组。对于所有意图和目的,这将表现为一系列对:
result = np.stack(coord, axis=-1)
强制单行:
result = np.stack(np.unravel_index(A.reshape(A.shape[0], -1).argmax(axis=-1), A.shape[1:]), axis=-1)
关于python - 返回元组的 np.argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67190442/