python - 返回元组的 np.argmax

标签 python numpy argmax

我有一个形状为(n, m, s) 的矩阵A。在第 0 个轴的每个位置,我需要对应于 (m, s) 形数组中最大值的位置。

例如:

np.random.seed(1)
A = np.random.randint(0, 10, size=[10, 3, 3])

A[0] 是:

array([[5, 8, 9],
       [5, 0, 0],
       [1, 7, 6]])

我想获取(0, 2),即这里9的位置。

我愿意做 aa = A.argmax(),这样 aa.shape = (10, 2),并且 aa[0] = [0, 2]

我怎样才能做到这一点?

最佳答案

运行 argmax 通常要快得多数组的部分扁平化版本而不是使用理解:

ind = A.reshape(A.shape[0], -1).argmax(axis=-1)

现在可以申请unravel_index获取二维索引:

coord = np.unravel_index(ind, A.shape[1:])

coord 是一对数组,每个轴一个。您可以使用常见的 Python 习惯用法转换为元组列表:

result = list(zip(*coord))

更好的方法是将数组堆叠在一起形成一个 Nx2 数组。对于所有意图和目的,这将表现为一系列对:

result = np.stack(coord, axis=-1)

强制单行:

result = np.stack(np.unravel_index(A.reshape(A.shape[0], -1).argmax(axis=-1), A.shape[1:]), axis=-1)

关于python - 返回元组的 np.argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67190442/

相关文章:

python - 如何使 PyMe(Python 库)在 Windows 上的 Python 2.4 中运行?

python - 从包含特定单词的文本文件中过滤行

python - Flask-SQLAlchemy backref 函数和 backref 参数

python - scipy.optimize.minimize 返回零维数组?

python - numpy 和 statsmodels 在计算相关性时给出不同的值,如何解释?

python - 计算不同长度的 3 维数组的平均值

python - 如何将 RFC822 转换为 python 日期时间对象?

java - Java 8 流中的 arg 最大值?

python - 在 Pandas 数据框的每一行中查找第一个和最后一个非零列

python - 第一次出现大于 numpy 数组中给定值的值