我需要对两个 4D 数组 (m & n) 执行矩阵乘法,m 和 n 的维度分别为 2x2x2x2 和 2x3x2x2,这应该会产生一个 2x3x2x2 数组。经过大量研究(主要是在本网站上),这似乎可以通过 np.einsum 或 np.tensordot 有效地完成,但我无法复制答案我从 Matlab 得到(手工验证)。我了解这些方法(einsum 和 tensordot)在对 2D 数组执行矩阵乘法时如何工作(清楚地解释了 here ),但我无法获得正确的 4D 数组轴索引。显然我错过了什么!我的实际问题涉及两个 23x23x3x3 复数数组,但我的测试数组是:
a = np.array([[1, 7], [4, 3]])
b = np.array([[2, 9], [4, 5]])
c = np.array([[3, 6], [1, 0]])
d = np.array([[2, 8], [1, 2]])
e = np.array([[0, 0], [1, 2]])
f = np.array([[2, 8], [1, 0]])
m = np.array([[a, b], [c, d]]) # (2,2,2,2)
n = np.array([[e, f, a], [b, d, c]]) # (2,3,2,2)
我意识到复数可能会带来更多问题,但现在,我只是想了解索引如何与 einsum 和 tensordot 一起工作。我要寻找的答案是这个 2x3x2x2 数组:
+----+-----------+-----------+-----------+
| | 0 | 1 | 2 |
+====+===========+===========+===========+
| 0 | [[47 77] | [[22 42] | [[44 40] |
| | [31 67]] | [27 74]] | [33 61]] |
+----+-----------+-----------+-----------+
| 1 | [[42 70] | [[24 56] | [[41 51] |
| | [10 19]] | [ 6 20]] | [ 6 13]] |
+----+-----------+-----------+-----------+
我最接近的尝试是使用 np.tensordot:
mn = np.tensordot(m,n, axes=([1,3],[0,2]))
这给了我一个 2x2x3x2 数组,数字正确但顺序不正确:
+----+-----------+-----------+
| | 0 | 1 |
+====+===========+===========+
| 0 | [[47 77] | [[31 67] |
| | [22 42] | [24 74] |
| | [44 40]] | [33 61]] |
+----+-----------+-----------+
| 1 | [[42 70] | [[10 19] |
| | [24 56] | [ 6 20] |
| | [41 51]] | [ 6 13]] |
+----+-----------+-----------+
我还尝试实现了 here 中的一些解决方案但没有任何运气。
任何关于我如何改进这一点的想法将不胜感激,谢谢
最佳答案
您可以简单地交换 tensordot
结果上的轴,这样我们仍然可以利用 BLAS
和 tensordot
-
np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)
或者,我们可以在 tensordot
调用中交换 m
和 n
的位置并转置以重新排列所有轴 -
np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)
使用 reshape 和交换轴的人工,我们也可以引入2D
矩阵乘法与np.dot
,就像这样-
m0,m1,m2,m3 = m.shape
n0,n1,n2,n3 = n.shape
m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)
运行时测试-
将输入数组缩放到 10x
形状:
In [85]: m = np.random.rand(20,20,20,20)
In [86]: n = np.random.rand(20,30,20,20)
# @Daniel F's soln with einsum
In [87]: %timeit np.einsum('ijkl,jmln->imkn', m, n)
10 loops, best of 3: 136 ms per loop
In [126]: %timeit np.tensordot(m,n, axes=((1,3),(0,2))).swapaxes(1,2)
100 loops, best of 3: 2.31 ms per loop
In [127]: %timeit np.tensordot(n,m, axes=((0,2),(1,3))).transpose(2,0,3,1)
100 loops, best of 3: 2.37 ms per loop
In [128]: %%timeit
...: m0,m1,m2,m3 = m.shape
...: n0,n1,n2,n3 = n.shape
...: m2D = m.swapaxes(1,2).reshape(-1,m1*m3)
...: n2D = n.swapaxes(1,2).reshape(n0*n2,-1)
...: out = m2D.dot(n2D).reshape(m0,m2,n1,n3).swapaxes(1,2)
100 loops, best of 3: 2.36 ms per loop
关于python - 4D numpy 数组上的矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47752324/