Python - 查找一个矩阵的每一行中的 K 个最大值并与二进制矩阵进行比较

标签 python numpy matrix

我需要确定矩阵 a 中 k 个最大值的位置(索引)是否与二进制指示矩阵 b 位于同一位置。

import numpy as np
a = np.matrix([[.8,.2,.6,.4],[.9,.3,.8,.6],[.2,.6,.8,.4],[.3,.3,.1,.8]])
b = np.matrix([[1,0,0,1],[1,0,1,1],[1,1,1,0],[1,0,0,1]])
print "a:\n", a
print "b:\n", b

d = argsort(a)
d[:,2:] # Return whether these indices are in 'b'

返回:

a:
[[ 0.8  0.2  0.6  0.4]
 [ 0.9  0.3  0.8  0.6]
 [ 0.2  0.6  0.8  0.4]
 [ 0.3  0.3  0.1  0.8]]
b:
[[1 0 0 1]
 [1 0 1 1]
 [1 1 1 0]
 [1 0 0 1]]

matrix([[2, 0],
        [2, 0],
        [1, 2],
        [1, 3]])

我想比较从最后结果返回的索引,如果 b 在这些位置有索引,则返回计数。 对于此示例,最终期望的结果是:

1
2
2
1

换句话说,在a的第一行中,前2个值仅对应于b中的一个值,依此类推。

有什么想法可以有效地做到这一点吗?也许 argsort 在这里是错误的方法。 谢谢。

最佳答案

当您采用argsort时,您会得到从最小0到最大3的结果,因此您可以通过[反转它: :-1] 获取最大 0 和最小 3:

s = np.argsort(a, axis=1)[:,::-1]   
#array([[0, 2, 3, 1],
#       [0, 2, 3, 1],
#       [2, 1, 3, 0],
#       [3, 1, 0, 2]])

现在您可以使用 np.take 获取最大值所在的 0 和第二最大值所在的 1:

s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
s = np.take(s.flatten(),s2)
#array([[0, 3, 1, 2],
#       [0, 3, 1, 2],
#       [3, 1, 0, 2],
#       [2, 1, 3, 0]])

b 中,0 值应替换为 np.nan,以便 0==np.nan 给出 False:

b = np.float_(b)
b[b==0] = np.nan
#array([[  1.,  nan,  nan,   1.],
#       [  1.,  nan,   1.,   1.],
#       [  1.,   1.,   1.,  nan],
#       [  1.,  nan,  nan,   1.]])

下面的比较将为您提供所需的结果:

print np.logical_or(s==b-1, s==b).sum(axis=1)
#[[1]
# [2]
# [2]
# [1]]

一般情况下,将 an 个最大值与二进制 b 进行比较:

def check_a_b(a,b,n=2):
    b = np.float_(b)
    b[b==0] = np.nan
    s = np.argsort(a, axis=1)[:,::-1]
    s2 = s + (np.arange(s.shape[0])*s.shape[1])[:,None]
    s = np.take(s.flatten(),s2)
    ans = s==(b-1)
    for i in range(n-1):
        ans = np.logical_or( ans, s==b+i )
    return ans.sum(axis=1)

这将在逻辑或中进行成对比较。

关于Python - 查找一个矩阵的每一行中的 K 个最大值并与二进制矩阵进行比较,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/18115783/

相关文章:

python - 对数据框进行操作以将行转换为单独的列

vector - 将向量添加到numpy中的矩阵行

matrix - 在 gnuplot 中使用带有图像的矩阵的文件条目中的多个调色板和空标签

python - 解释 SciPy 的层次聚类树状图的输出? (也许发现了一个错误...)

Python 类方法无法识别输入

python - 如何访问模型的恢复权重?

python - 使用 polyfit 求 30 个点的多项式函数

java - 使用 Spring 框架在矩阵中直接注入(inject) double 值

c - 使用 b^n 次方求矩阵的逆

python - 使用 Scikit-Learn 在 Python 中为随机森林绘制树