python - 知道矩阵是对称和正半定的矩阵求逆的更有效方法

标签 python numpy scipy linear-algebra

我在 python 中使用 numpy 反转协方差矩阵。协方差矩阵是对称的半正定矩阵。

我想知道是否存在针对对称半正定矩阵优化的算法,比 numpy.linalg.inv() 更快(当然,如果可以从 python 轻松访问它的实现! ).我没有设法在 numpy.linalg 中找到任何东西,也没有在网上搜索。

编辑:

正如@yixuan 所观察到的,半正定矩阵通常不是严格可逆的。我检查过在我的情况下我只是得到了正定矩阵,所以我接受了一个适用于正定性的答案。无论如何,在 LAPACK 低级例程中,我发现 DSY* 例程仅针对 simmetric/hermitian 矩阵进行了优化,尽管它们似乎在 scipy 中缺失(也许是只是安装版本的问题)。

最佳答案

我尝试了@percusse 的答案,但是当我对其执行计时时,我发现它比 np.linalg.inv 慢了大约 33%(使用 100K 随机正定 4x4 np.float64 矩阵)。这是我的实现:

from scipy.linalg import lapack

def upper_triangular_to_symmetric(ut):
    ut += np.triu(ut, k=1).T

def fast_positive_definite_inverse(m):
    cholesky, info = lapack.dpotrf(m)
    if info != 0:
        raise ValueError('dpotrf failed on input {}'.format(m))
    inv, info = lapack.dpotri(cholesky)
    if info != 0:
        raise ValueError('dpotri failed on input {}'.format(cholesky))
    upper_triangular_to_symmetric(inv)
    return inv

我试着分析它,令我惊讶的是,它花费了大约 82% 的时间调用 upper_triangular_to_symmetric(这不是“困难”部分)!我认为发生这种情况是因为它正在进行浮点加法以组合矩阵,而不是简单的复制。

我尝试了一个快了大约 87% 的 upper_triangular_to_symmetric 实现(参见 this question and answer):

from scipy.linalg import lapack

inds_cache = {}

def upper_triangular_to_symmetric(ut):
    n = ut.shape[0]
    try:
        inds = inds_cache[n]
    except KeyError:
        inds = np.tri(n, k=-1, dtype=np.bool)
        inds_cache[n] = inds
    ut[inds] = ut.T[inds]


def fast_positive_definite_inverse(m):
    cholesky, info = lapack.dpotrf(m)
    if info != 0:
        raise ValueError('dpotrf failed on input {}'.format(m))
    inv, info = lapack.dpotri(cholesky)
    if info != 0:
        raise ValueError('dpotri failed on input {}'.format(cholesky))
    upper_triangular_to_symmetric(inv)
    return inv

这个版本比 np.linalg.inv 快大约 68%,并且只花费大约 42% 的时间调用 upper_triangular_to_symmetric

关于python - 知道矩阵是对称和正半定的矩阵求逆的更有效方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40703042/

相关文章:

python - TensorFlow - 在读取和写入 TFRecords 文件时设置图像的形状?

python - 从一个十六进制字符串构造一个 Numpy 数组

python:在顶部绘制带有函数线的直方图

python - Matplotlib 不显示数字

python - 确定 numpy 数组的 2 个(垂直或水平)相邻元素是否具有相同值的最快方法

python - pandas 中删除的列重新出现

Python UnicodeDecodeError : 'utf-8' codec can't decode byte 0x8c in position 2: invalid start byte

python - Scikit-learn χ²(卡方)统计量和相应的列联表

python - 根据需要的值将其他行插入到数据框中

python - 为什么准确率和损失在训练时保持完全相同?