python - Numpy/Scipy 稀疏与密集乘法

标签 python numpy scipy sparse-matrix matrix-multiplication

scipy 稀疏矩阵类型和普通 numpy 矩阵类型之间似乎存在一些差异

import scipy.sparse as sp
A = sp.dia_matrix(tri(3,4))
vec = array([1,2,3,4])

print A * vec                        #array([ 1.,  3.,  6.])

print A * (mat(vec).T)               #matrix([[ 1.],
                                     #        [ 3.],
                                     #        [ 6.]])

print A.todense() * vec              #ValueError: matrices are not aligned

print A.todense() * (mat(vec).T)     #matrix([[ 1.],
                                     #        [ 3.],
                                     #        [ 6.]])

为什么稀疏矩阵可以计算出数组应该被解释为列向量,而普通矩阵不能?

最佳答案

spmatrix 类(您可以在 scipy/sparse/base.py 中查看)中的 __mul__() 有一组“ifs”可以回答你的问题:

class spmatrix(object):
    ...
    def __mul__(self, other):
        ...
        M,N = self.shape
        if other.__class__ is np.ndarray:
            # Fast path for the most common case
            if other.shape == (N,):
                return self._mul_vector(other)
            elif other.shape == (N, 1):
                return self._mul_vector(other.ravel()).reshape(M, 1)
            elif other.ndim == 2  and other.shape[0] == N:
                return self._mul_multivector(other)

对于一维数组,它将始终从 compressed.py 转到 _mul_vector(),在类 _cs_matrix 中,代码如下:

def _mul_vector(self, other):
    M,N = self.shape

    # output array
    result = np.zeros(M, dtype=upcast_char(self.dtype.char,
                                           other.dtype.char))

    # csr_matvec or csc_matvec
    fn = getattr(sparsetools,self.format + '_matvec')
    fn(M, N, self.indptr, self.indices, self.data, other, result)

    return result

请注意,它假设输出具有稀疏矩阵的行数。基本上,它将您的输入一维数组视为适合稀疏数组的列数(没有转置或非转置)。但是对于 ndim==2 的 ndarray 它不能做这样的假设,所以如果你尝试:

vec = np.array([[1,2,3,4],
                [1,2,3,4]])

A * vec.T 将是唯一可行的选项。

对于一维矩阵,稀疏模块也不假定它适合列数。要检查您是否可以尝试:

A * mat(vec)
#ValueError: dimension mismatch

A * mat(vec).T 将是您唯一的选择。

关于python - Numpy/Scipy 稀疏与密集乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/16839840/

相关文章:

javascript - 重访 Python 私有(private)实例数据

python - 将一个数组的数据替换为第二个数组的 2 个值

python - 如何计算每个 bin 中的点数?

python - SciPy 中的 spearmanr 使用什么显着性检验?

python - 为什么我的 AWS SQS 消息没有被删除?

python - 更改 python 解释器中间脚本

python - 名称错误 : free variable 'd' referenced before assignment in enclosing scope

python setuptools 和 easy_install numpy

python - 使用 curvefit 拟合双曲函数和调和函数

python - 具有广播的稀疏 Scipy 矩阵和向量的元素最大值