python - (稀疏)2D numpy 数组每行/列的快速非零索引

标签 python arrays numpy scipy sparse-matrix

我正在寻找最快的方法来获取二维数组每行每列的非零索引列表。以下是一段工作代码:

preds = [matrix[:,v].nonzero()[0] for v in range(matrix.shape[1])]
descs = [matrix[v].nonzero()[0] for v in range(matrix.shape[0])]

示例输入:

matrix = np.array([[0,0,0,0],[1,0,0,0],[1,1,0,0],[1,1,1,0]])

示例输出

preds = [array([1, 2, 3]), array([2, 3]), array([3]), array([], dtype=int64)]
descs = [array([], dtype=int64), array([0]), array([0, 1]), array([0, 1, 2])]

(这些列表称为 preds 和 descs,因为当矩阵被解释为邻接矩阵时,它们指的是 DAG 中的前辈和后裔,但这对问题来说不是必需的。)

时序示例: 出于计时目的,以下矩阵是一个很好的代表:

test_matrix = np.zeros(shape=(4096,4096),dtype=np.float32)
for k in range(16):
    test_matrix[256*(k+1):256*(k+2),256*k:256*(k+1)]=1

背景:在我的代码中,对于一个 4000x4000 的矩阵,这两行代码占用了 75% 的时间,而随后的拓扑排序和 DP 算法只占用了四分之一的剩余时间。矩阵中大约 5% 的值是非零值,因此稀疏矩阵解决方案可能适用。

谢谢。

(关于此处发布的建议:https://scicomp.stackexchange.com/questions/35242/fast-nonzero-indices-per-row-column-for-sparse-2d-numpy-array 那里也有答案,我将在评论中提供时间安排。 此链接包含一个可接受的答案,速度是原来的两倍。)

最佳答案

如果您有足够的动力,Numba 可以做出惊人的事情。 这是您需要的逻辑的快速实现。 简而言之,它计算了 np.nonzero() 的等价物,但它包含了稍后将索引分派(dispatch)为您需要的格式的信息。 这些信息的灵感来自 sparse.csr.indptrsparse.csc.indptr

import numpy as np
import numba as nb


@nb.jit
def cumsum(arr):
    result = np.empty_like(arr)
    cumsum = result[0] = arr[0]
    for i in range(1, len(arr)):
        cumsum += arr[i]
        result[i] = cumsum
    return result


@nb.jit
def count_nonzero(arr):
    arr = arr.ravel()
    n = 0
    for x in arr:
        if x != 0:
            n += 1
    return n


@nb.jit
def row_col_nonzero_nb(arr):
    n, m = arr.shape
    max_k = count_nonzero(arr)
    indices = np.empty((2, max_k), dtype=np.uint32)
    i_offset = np.zeros(n + 1, dtype=np.uint32)
    j_offset = np.zeros(m + 1, dtype=np.uint32)
    n, m = arr.shape
    k = 0
    for i in range(n):
        for j in range(m):
            if arr[i, j] != 0:
                indices[:, k] = i, j
                i_offset[i + 1] += 1
                j_offset[j + 1] += 1
                k += 1
    return indices, cumsum(i_offset), cumsum(j_offset)


def row_col_idx_nonzero_nb(arr):
    (ii, jj), jj_split, ii_split = row_col_nonzero_nb(arr)
    ii_ = np.argsort(jj)
    ii = ii[ii_]
    return np.split(ii, ii_split[1:-1]), np.split(jj, jj_split[1:-1])

与您的方法(下面的 row_col_idx_sep())和其他一些方法相比,根据 @hpaulj answer (row_col_idx_sparse_lil()) 和 @knl answer from scicomp.stackexchange.com (row_col_idx_sparse_coo()):

def row_col_idx_sep(arr):
    return (
        [arr[:, j].nonzero()[0] for j in range(arr.shape[1])],
        [arr[i, :].nonzero()[0] for i in range(arr.shape[0])],)
def row_col_idx_zip(arr):
    n, m = arr.shape
    ii = [[] for _ in range(n)]
    jj = [[] for _ in range(m)]
    x, y = np.nonzero(arr)
    for i, j in zip(x, y):
        ii[i].append(j)
        jj[j].append(i)
    return jj, ii
import scipy as sp
import scipy.sparse


def row_col_idx_sparse_coo(arr):
    coo_mat = sp.sparse.coo_matrix(arr)
    csr_mat = coo_mat.tocsr()
    csc_mat = coo_mat.tocsc()
    return (
        np.split(csc_mat.indices, csc_mat.indptr)[1:-1],
        np.split(csr_mat.indices, csr_mat.indptr)[1:-1],)
def row_col_idx_sparse_lil(arr):
    lil_mat = sp.sparse.lil_matrix(arr)
    return lil_mat.T.rows, lil_mat.rows

对于使用以下方法生成的输入:

def gen_input(n, density=0.1, dtype=np.float32):
    arr = np.zeros(shape=(n, n), dtype=dtype)
    indices = tuple(np.random.randint(0, n, (2, int(n * n * density))).tolist())
    arr[indices] = 1.0
    return arr

一个人会得到(你的 test_matrix 有大约 0.06 的非零密度):

m = gen_input(4096, density=0.06)
%timeit row_col_idx_sep(m)
# 1 loop, best of 3: 767 ms per loop
%timeit row_col_idx_zip(m)
# 1 loop, best of 3: 660 ms per loop
%timeit row_col_idx_sparse_coo(m)
# 1 loop, best of 3: 205 ms per loop
%timeit row_col_idx_sparse_lil(m)
# 1 loop, best of 3: 498 ms per loop
%timeit row_col_idx_nonzero_nb(m)
# 10 loops, best of 3: 130 ms per loop

表明这接近于最快的基于 scipy.sparse 的方法的两倍。

关于python - (稀疏)2D numpy 数组每行/列的快速非零索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62065793/

相关文章:

Python:numpy数组大于和小于一个值

python - 从 1D 数组 numpy 的 where 子句创建 2D 数组

python - 尽管是分开的,但在更改另一个数组时意外更改了一个 numpy 数组

python - python 线程问题 (dht22)

python - 将字典写入 csv 文件

python - `self[key] += value` 的魔术方法?

javascript - Angular $scope.$watch - 根据 ng-model/length 创建数组和总数

python - 如果满足条件,从数组中减去一个数字python

python - Nose 工具中的 assert_raises() 并没有真正起作用

c - 用指针编译c代码时出现段错误