python - 如何获取稀疏矩阵数据数组的对角线元素的索引

标签 python scipy sparse-matrix

我有一个 csr 格式的稀疏矩阵,例如:

>>> a = sp.random(3, 3, 0.6, format='csr')  # an example
>>> a.toarray()  # just to see how it looks like
array([[0.31975333, 0.88437035, 0.        ],
       [0.        , 0.        , 0.        ],
       [0.14013856, 0.56245834, 0.62107962]])
>>> a.data  # data array
array([0.31975333, 0.88437035, 0.14013856, 0.56245834, 0.62107962])

对于这个特定的示例,我想要获取 [0, 4],它们是非零对角线元素 0.31975333 的数据数组索引0.62107962

执行此操作的简单方法如下:

ind = []
seen = set()
for i, val in enumerate(a.data):
    if val in a.diagonal() and val not in seen:
        ind.append(i)
        seen.add(val)

但实际上矩阵非常大,所以我不想使用 for 循环或使用 toarray() 方法转换为 numpy 数组。有没有更有效的方法来做到这一点?

编辑:我刚刚意识到,当非对角线元素等于并在某些对角线元素之前时,上面的代码会给出错误的结果:它返回该非对角线的索引元素。此外,它不返回重复对角线元素的索引。例如:

a = np.array([[0.31975333, 0.88437035, 0.        ],
              [0.62107962, 0.31975333, 0.        ],
              [0.14013856, 0.56245834, 0.62107962]])
a = sp.csr_matrix(a)

>>> a.data
array([0.31975333, 0.88437035, 0.62107962, 0.31975333, 0.14013856,
       0.56245834, 0.62107962])

我的代码返回ind = [0, 2],但它应该是[0, 3, 6]。 Andras Deak 提供的代码(他的 get_rowwise 函数)返回正确的结果。

最佳答案

我找到了一个可能更有效的解决方案,尽管它仍然循环。但是,它循环遍历矩阵的行而不是元素本身。根据矩阵的稀疏模式,这可能会更快,也可能不会更快。对于具有 N 行的稀疏矩阵,这保证会花费 N 次迭代。

我们只是循环遍历每一行,通过 a.indices 和 a.indptr 获取填充的列索引,并且如果给定行的对角线元素存在于填充的值然后我们计算它的索引:

import numpy as np
import scipy.sparse as sp

def orig_loopy(a):
    ind = []
    seen = set()
    for i, val in enumerate(a.data):
        if val in a.diagonal() and val not in seen:
            ind.append(i)
            seen.add(val)
    return ind

def get_rowwise(a):
    datainds = []
    indices = a.indices # column indices of filled values
    indptr = a.indptr   # auxiliary "pointer" to data indices
    for irow in range(a.shape[0]):
        rowinds = indices[indptr[irow]:indptr[irow+1]] # column indices of the row
        if irow in rowinds:
            # then we've got a diagonal in this row
            # so let's find its index
            datainds.append(indptr[irow] + np.flatnonzero(irow == rowinds)[0])
    return datainds

a = sp.random(300, 300, 0.6, format='csr')
orig_loopy(a) == get_rowwise(a) # True

对于具有相同密度的 (300,300) 形状的随机输入,原始版本在 3.7 秒内运行,新版本在 5.5 毫秒内运行。

关于python - 如何获取稀疏矩阵数据数组的对角线元素的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52839461/

相关文章:

python - 从 csv 中提取奇怪排列的数据并使用 python 转换为另一个 csv 文件

python - Argparse 将命令行字符串传递给变量 python3

python - PyQtGraph 图形布局小部件问题

python - Pandas 如何使用 read_fwf 读取填充为 0 的数字?

python - 如何解决稀疏矩阵的缓慢 groupby 问题?

c++ - 包含 LU 分解的矩阵

python - 将图例或背景图像添加到 Igraph 0.6 (for python) 图中

python - scipy.interpolate.lagrange 在某些数据上失败

python - : import scipy as sp/sc 的官方缩写

python - 从另一个 csr_matrix 的一行创建 csr_matrix 的平铺操作