python - 在 NumPy 数组中沿轴获取 N 个最大值和索引

标签 python numpy matrix

我认为这对于有经验的 numpy 用户来说是一个简单的问题。

我有一个分数矩阵。原始索引对应样本,列索引对应项目。例如,

score_matrix = 
  [[ 1. ,  0.3,  0.4],
   [ 0.2,  0.6,  0.8],
   [ 0.1,  0.3,  0.5]]

我想为每个样本获取项目的前 M 个索引。我也想获得最高分。例如,

top2_ind = 
  [[0, 2],
   [2, 1],
   [2, 1]]

top2_score = 
  [[1. , 0.4],
   [0,8, 0.6],
   [0.5, 0.3]]

使用 numpy 执行此操作的最佳方法是什么?

最佳答案

这是一种使用 np.argpartition 的方法-

idx = np.argpartition(a,range(M))[:,:-M-1:-1] # topM_ind
out = a[np.arange(a.shape[0])[:,None],idx]    # topM_score

sample 运行-

In [343]: a
Out[343]: 
array([[ 1. ,  0.3,  0.4],
       [ 0.2,  0.6,  0.8],
       [ 0.1,  0.3,  0.5]])

In [344]: M = 2

In [345]: idx = np.argpartition(a,range(M))[:,:-M-1:-1]

In [346]: idx
Out[346]: 
array([[0, 2],
       [2, 1],
       [2, 1]])

In [347]: a[np.arange(a.shape[0])[:,None],idx]
Out[347]: 
array([[ 1. ,  0.4],
       [ 0.8,  0.6],
       [ 0.5,  0.3]])

或者,使用 np.argsort -

获取 idx 的代码可能更慢,但更短一些
idx = a.argsort(1)[:,:-M-1:-1]

这是一个 post包含一些比较 np.argsortnp.argpartition 类似问题的运行时测试。

关于python - 在 NumPy 数组中沿轴获取 N 个最大值和索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40574071/

相关文章:

python - Ubuntu 通过 apt-get 安装 apache spark

Python3 中缺少 Python 2's ` exceptions` 模块,它的内容去哪儿了?

python - 如何使用 csv 数据作为变量将其应用于公式?

matlab - 从矩阵列中减去相应的向量值

python - 如何用Python中的另一个字符串替换函数中的字符串?

php - 图像处理(带 PHP 的 OpenCV)- exec 命令问题

python - 在 numpy 中用一个数组索引另一个数组

对 pandas 数据框进行 Python 曲线拟合,然后将 coef 添加到新列

javascript - 缩放旋转的矩形并找到新的控制点和中心

在单个目标映射到多个源的情况下定位最接近源的目标并解决冲突的算法