python - 如何在 numpy ndarray 中获取每行的 N 个最大值?

标签 python numpy

我们知道当 N = 1 时该怎么做

import numpy as np

m = np.arange(15).reshape(3, 5)
m[xrange(len(m)), m.argmax(axis=1)]    # array([ 4,  9, 14])

当 N > 1 时,获得前 N 个的最佳方法是什么? (比如说,5)

最佳答案

使用 np.partition 进行部分排序可以比完整排序便宜得多:

gen = np.random.RandomState(0)
x = gen.permutation(100)

# full sort
print(np.sort(x)[-10:])
# [90 91 92 93 94 95 96 97 98 99]

# partial sort such that the largest 10 items are in the last 10 indices
print(np.partition(x, -10)[-10:])
# [90 91 93 92 94 96 98 95 97 99]

如果您需要对最大的 N 个项目进行排序,您可以在部分排序的最后 N 个元素上调用 np.sort数组:

print(np.sort(np.partition(x, -10)[-10:]))
# [90 91 92 93 94 95 96 97 98 99]

如果您的数组足够大,这仍然比对整个数组进行完整排序要快得多。


要对二维数组的每一行进行排序,您可以将 axis= 参数用于 np.partition 和/或 np.sort:

y = np.repeat(np.arange(100)[None, :], 5, 0)
gen.shuffle(y.T)

# partial sort, followed by a full sort of the last 10 elements in each row
print(np.sort(np.partition(y, -10, axis=1)[:, -10:], axis=1))
# [[90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]
#  [90 91 92 93 94 95 96 97 98 99]]

基准:

In [1]: %%timeit x = np.random.permutation(10000000)
   ...: np.sort(x)[-10:]
   ...: 
1 loop, best of 3: 958 ms per loop

In [2]: %%timeit x = np.random.permutation(10000000)
np.partition(x, -10)[-10:]
   ....: 
10 loops, best of 3: 41.3 ms per loop

In [3]: %%timeit x = np.random.permutation(10000000)
np.sort(np.partition(x, -10)[-10:])
   ....: 
10 loops, best of 3: 78.8 ms per loop

关于python - 如何在 numpy ndarray 中获取每行的 N 个最大值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37125495/

相关文章:

python - 在Python数据框中复制Excel计算

python - Cython 指定固定长度字符串的 numpy 数组

python - 为什么 scipy.linalg 的方法在 9x9 矩阵上运行速度较慢?

python-3.x - Python Numpy 在爆炸性传播中跟踪所有 0

python - Python中微分方程的并行求解

python - 如何沿给定轴取元素,由它们的索引给出?

python - 如何避免 mako %def 中的重复过滤器规范?

python - 如何让 Tornado websocket 客户端接收服务器通知?

python - 删除 pandas 数据框中的列会删除父数据框中的列

python - 将 Plumbr 与在 Python 脚本中使用 R 制作图表的其他选项进行比较