python - ndarray 中最大值的索引列表

标签 python numpy indexing

我有一个 ndarray。我需要从这个数组中选择具有最大值的 N 个数字的列表。我发现 heapq.nlargest 来查找 N 个最大的条目,但我需要提取索引。 我想构建一个新数组,其中只有第一列中权重最大的 N 行才能生存。其余行将被随机值替换

import numpy as np
import heapq   # For choosing list of max values
a = [[1.1,2.1,3.1], [2.1,3.1,4.1], [5.1,0.1,7.1],[0.1,1.1,1.1],[4.1,3.1,9.1]]
a = np.asarray(a)
maxVal = heapq.nlargest(2,a[:,0])

if __name__ == '__main__':
    print a
    print maxVal

我的输出是:

[[ 1.1  2.1  3.1]
[ 2.1  3.1  4.1]
[ 5.1  0.1  7.1]
[ 0.1  1.1  1.1]
[ 4.1  3.1  9.1]]

[5.0999999999999996, 4.0999999999999996]

但我需要的是 [2,4] 作为构建新数组的索引。索引是行,因此如果在此示例中我想将其余部分替换为 0,我需要以以下内容结束:

[[0.0  0.0  0.0]
[ 0.0  0.0  0.0]
[ 5.1  0.1  7.1]
[ 0.0  0.0  0.0]
[ 4.1  3.1  9.1]]

我陷入了需要索引的地步。原始数组有 1000 行和 100 列。权重是标准化浮点,我不想做类似 if a[:,1] == maxVal[0]: 的事情,因为有时我的权重非常接近,可以用更多来完成值maxVal[0]比我原来的N。

是否有任何简单的方法可以提取此设置上的索引以替换数组的其余部分?

最佳答案

如果只有 1000 行,我会忘记堆并在第一列上使用 np.argsort:

>>> np.argsort(a[:,0])[::-1][:2]
array([2, 4])

如果你想把它们放在一起,它看起来像:

def trim_rows(a, n) :
    idx = np.argsort(a[:,0])[:-n]
    a[idx] = 0

>>> a = np.random.rand(10, 4)
>>> a

array([[ 0.34416425,  0.89021968,  0.06260404,  0.0218131 ],
       [ 0.72344948,  0.79637177,  0.70029863,  0.20096129],
       [ 0.27772833,  0.05372373,  0.00372941,  0.18454153],
       [ 0.09124461,  0.38676351,  0.98478492,  0.72986697],
       [ 0.84789887,  0.69171688,  0.97718206,  0.64019977],
       [ 0.27597241,  0.26705301,  0.62124467,  0.43337711],
       [ 0.79455424,  0.37024814,  0.93549275,  0.01130491],
       [ 0.95113795,  0.32306471,  0.47548887,  0.20429272],
       [ 0.3943888 ,  0.61586129,  0.02776393,  0.2560126 ],
       [ 0.5934556 ,  0.23093912,  0.12550062,  0.58542137]])
>>> trim_rows(a, 3)
>>> a

array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.84789887,  0.69171688,  0.97718206,  0.64019977],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.79455424,  0.37024814,  0.93549275,  0.01130491],
       [ 0.95113795,  0.32306471,  0.47548887,  0.20429272],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ]])

对于您的数据大小来说,它可能足够快:

In [7]: a = np.random.rand(1000, 100)

In [8]: %timeit -n1 -r1 trim_rows(a, 50)
1 loops, best of 1: 7.65 ms per loop

关于python - ndarray 中最大值的索引列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/15124318/

相关文章:

Python错误列表索引超出范围

python - 引用过滤数据框的列的可扩展方式

python - 如何修改此正则表达式以不匹配不间断空格?

python - 使用 merge 和 groupby 将 DF 引入新方案

python - 从图的顶部到底部更改 matshow xticklabel 位置

python - 创建一个距离中心有欧氏距离的二维 Numpy 数组

python - 替换 Django 模板中的字符

python - 从 str 或 int 继承

python - SymPy 表达式的 `subs` 方法的 `evalf` 参数到底是做什么用的?像 `s.evalf(subs={...})` 吗?

sql - 用于提高 SQL Server 上多个连接性能的索引 View