我有一段代码,需要在矩阵之间进行大量乘法运算。该代码旨在用于任意维度 n 的二维矩阵,原则上该矩阵可能非常大,从而使程序非常慢。 到目前为止,为了进行乘法运算,我一直使用 np.dot,如下例
def getV(csi, e, e2, k):
ktrans = k.transpose()
v = np.dot(csi, ktrans)
v = np.dot(v, e)
v = np.dot(v, k)
v = np.dot(v, csi)
v = np.dot(v, ktrans)
e2trans = e2.transpose()
v = np.dot(v, e2trans)
v = np.dot(v, k)
traceV = 2*v.trace()
return traceV
其中输出应是乘积迹线的两倍:
csi*ktrans*e*k*csi*ktrans*e2trans*k
(它们都是矩阵相乘)。 我确信有一种更快的方法来制作这么长的产品,可能在一次 channel 中。有人能解释一下怎么做吗?我已经尝试过,但似乎 np.dot 在任何单个段落中总是只需要两个矩阵。
最佳答案
因为 properties of the trace该计算可以重写如下,从而将矩阵乘法的次数从 7 次减少到 4 次:
def getV(csi, k, e, e2):
temp = k.dot(csi).dot(k.T)
trace_ = (temp.dot(e).dot(temp) * e2).sum()
return 2 * trace_
根据您当前的设置,您还可以尝试安装不同的 BLAS 库或在显卡而不是 CPU 上计算。
关于python - np.dot 用于二维矩阵之间的多个乘积,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29148421/