我想形成一个数组,其中包含数组中 k 个最小值的索引:
import heapq
import numpy as np
a= np.array([[1, 3, 5, 2, 3],
[7, 6, 5, 2, 4],
[2, 0, 5, 6, 4]])
[t[0] for t in heapq.nsmallest(2,enumerate(a[1]),lambda(t):t[1])]
===[3, 4]
但这失败了:
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
Traceback (most recent call last):
File "<pyshell#19>", line 1, in <module>
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
TypeError: 'numpy.bool_' object is not iterable
最佳答案
您的问题出在a.all()
中:
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
这会检查数组中所有元素的真实性,即False
(因为您有一个 0)。
如果数组与 k 相比不是太大,您可以使用 .argsort
获取值。这里我将为每行选择两个最大的位置:
print a.argsort()[:,:2]
array([[0, 3],
[3, 4],
[1, 0]])
如果您想要全局最小值的位置,请首先展平数组:
a.flatten().argsort()[:2]
如果数组非常大,您可以使用 np.argpartition
获得更好的性能,这将仅执行部分排序。
关于python - 如何在多维数组中找到 k 个最小数字的索引?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24344976/