python - NumPy 数组下三角区域中 n 个最大值的索引

标签 python arrays numpy distance

我有一个 numpy 余弦相似度矩阵。我想找到 n 个最大值的索引,但排除对角线上的 1.0,并且仅排除其中的下三角区域。

similarities = [[ 1.          0.18898224  0.16903085]
 [ 0.18898224  1.          0.67082039]
 [ 0.16903085  0.67082039  1.        ]]

在这种情况下,如果我想要两个最高值,我希望它返回 [1, 0][2, 1]

我尝试使用argpartition,但这不会返回我正在寻找的内容

n_select = 1
most_similar = (-similarities).argpartition(n_select, axis=None)[:n_select]

如何获得排除对角线 1 的 n 个最高值并排除上三角元素?

最佳答案

方法#1

使用np.tril_indices的一种方法 -

def n_largest_indices_tril(a, n=2):
    m = a.shape[0]
    r,c = np.tril_indices(m,-1)
    idx = a[r,c].argpartition(-n)[-n:]
    return zip(r[idx], c[idx])

示例运行 -

In [39]: a
Out[39]: 
array([[ 1.  ,  0.4 ,  0.59,  0.15,  0.29],
       [ 0.4 ,  1.  ,  0.03,  0.57,  0.57],
       [ 0.59,  0.03,  1.  ,  0.9 ,  0.52],
       [ 0.15,  0.57,  0.9 ,  1.  ,  0.37],
       [ 0.29,  0.57,  0.52,  0.37,  1.  ]])

In [40]: n_largest_indices_tril(a, n=2)
Out[40]: [(2, 0), (3, 2)]

In [41]: n_largest_indices_tril(a, n=3)
Out[41]: [(4, 1), (2, 0), (3, 2)]

方法#2

为了性能,我们可能希望避免生成所有下三角索引,而是使用掩码,为我们提供第二种方法来解决我们的情况,就像这样 -

def n_largest_indices_tril_v2(a, n=2):
    m = a.shape[0]
    r = np.arange(m)
    mask = r[:,None] > r
    idx = a[mask].argpartition(-n)[-n:]

    clens = np.arange(m).cumsum()    
    grp_start = clens[:-1]
    grp_stop = clens[1:]-1    

    rows = np.searchsorted(grp_stop, idx)+1    
    cols  = idx - grp_start[rows-1]
    return zip(rows, cols)

运行时测试

In [143]: # Setup symmetric array 
     ...: N = 1000
     ...: a = np.random.rand(N,N)*0.9
     ...: np.fill_diagonal(a,1)
     ...: m = a.shape[0]
     ...: r,c = np.tril_indices(m,-1)
     ...: a[r,c] = a[c,r]

In [144]: %timeit n_largest_indices_tril(a, n=2)
100 loops, best of 3: 12.5 ms per loop

In [145]: %timeit n_largest_indices_tril_v2(a, n=2)
100 loops, best of 3: 7.85 ms per loop

对于n个最小索引

要获取 n 个最小的,只需使用 ndarray.argpartition(n)[:n] 代替这两种方法。

关于python - NumPy 数组下三角区域中 n 个最大值的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47289086/

相关文章:

python - 在 django 中重置密码?

python - 如何在 Apple M1 芯片的 Mac 上使用 Tensorflow 检查 GPU 可访问性?

python - 寻找图中的负循环

c - 当您写入内存超出数组范围时会发生什么?

python - 根据其他数据帧的条件创建数据帧

python - 使用 matplotlib 创建每周时间表

Python list.append 作为参数

将一个结构体数组复制到另一个更小的结构体数组

java - 如何访问 JSONObject、JSONObject 和 JSONArray 中的内容? Java/安卓

python - 尝试在不使用 for(或类似)循环的情况下对 numpy 数组内的所有子数组执行操作