python - np.dot 用于二维矩阵之间的多个乘积

标签 python performance numpy matrix multiplication

我有一段代码,需要在矩阵之间进行大量乘法运算。该代码旨在用于任意维度 n 的二维矩阵,原则上该矩阵可能非常大,从而使程序非常慢。 到目前为止,为了进行乘法运算,我一直使用 np.dot,如下例

def getV(csi, e, e2, k):
    ktrans = k.transpose()
    v = np.dot(csi, ktrans)
    v = np.dot(v, e)
    v = np.dot(v, k)
    v = np.dot(v, csi)
    v = np.dot(v,  ktrans)
    e2trans = e2.transpose()
    v = np.dot(v, e2trans)
    v = np.dot(v, k)
    traceV = 2*v.trace()
    return traceV

其中输出应是乘积迹线的两倍:

csi*ktrans*e*k*csi*ktrans*e2trans*k

(它们都是矩阵相乘)。 我确信有一种更快的方法来制作这么长的产品,可能在一次 channel 中。有人能解释一下怎么做吗?我已经尝试过,但似乎 np.dot 在任何单个段落中总是只需要两个矩阵。

最佳答案

因为 properties of the trace该计算可以重写如下,从而将矩阵乘法的次数从 7 次减少到 4 次:

def getV(csi, k, e, e2):
    temp = k.dot(csi).dot(k.T)
    trace_ = (temp.dot(e).dot(temp) * e2).sum()
    return 2 * trace_

根据您当前的设置,您还可以尝试安装不同的 BLAS 库或在显卡而不是 CPU 上计算。

关于python - np.dot 用于二维矩阵之间的多个乘积,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29148421/

相关文章:

python - matplotlib streamplot 上的箭头数

mysql - 从单表性能问题创建多对多

python - 如何拟合一些系数受限的多项式?

javascript - Javascript 驱动动画的每秒帧数?

database - i18n 性能 : resx vs. 数据库?

python - Numpy 来自数组,为每个元素创建一个矩阵 N*M,其中所有值都设置为元素,无需 for 循环

python - 在pytorch中分解一批会导致不同的结果,为什么?

python - 使用 s3boto 的 Django 存储忽略 MEDIA_URL

python - pytest-cov 在 Tavern 测试中始终显示 0 覆盖率

python - 从 pandas 数据框中提取多列的组合