arrays - 二维数组中的前 N ​​个值,要屏蔽重复项

标签 arrays numpy sorting duplicates

我有二维 numpy 数组:

arr = np.array([[0.1, 0.1, 0.3, 0.4, 0.5], 
                [0.06, 0.1, 0.1, 0.1, 0.01], 
                [0.24, 0.24, 0.24, 0.24, 0.24], 
                [0.2, 0.25, 0.3, 0.12, 0.02]])
print (arr)
[[0.1  0.1  0.3  0.4  0.5 ]
 [0.06 0.1  0.1  0.1  0.01]
 [0.24 0.24 0.24 0.24 0.24]
 [0.2  0.25 0.3  0.12 0.02]]

我想过滤前 N 个值,所以我使用 argsort :
N = 2
arr1 = np.argsort(-arr, kind='mergesort') < N
print (arr1)
[[False False False  True  True]
 [ True False False  True False] <- first top 2 are duplicates
 [ True  True False False False]
 [False  True  True False False]]

它工作得很好,至少不是顶部重复,比如第 2 行。

预期输出:
print (arr1)
[[False False False  True  True]
 [False  True  True False False]
 [ True  True False False False]
 [False  True  True False False]]

有没有更快的方法来处理它?

最佳答案

切片以获取前 N 个索引并使用它们创建最终掩码 -

idx = np.argsort(-arr, kind='mergesort')[:,:N]
mask = np.zeros(arr.shape, dtype=bool)
np.put_along_axis(mask, idx, True, axis=-1)

关于arrays - 二维数组中的前 N ​​个值,要屏蔽重复项,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61517878/

相关文章:

python-3.x - imshow 的有效数据

datetime - 在 Pandas 中绘制 TimeDeltas

java - 使用 Lambda 解引用 Int 值

java - 在 Pivot 上对数组进行分区

java - 使用用户定义数组时出现空指针异常

c++ - 将文件读入数组

python - Numpy 从 2 个数组中选择元素

algorithm - 以递增顺序迭代数字对

php - 创建基于二维合并数组的值

javascript - 包含具有 "ing"的字符串的过滤器数组?