给定两个具有相同维度的张量 A 和 B (d>=2)
和形状[A_{1},...,A_{d-2},A_{d-1},A_{d}]
和[A_{1},...,A_{d-2},B_{d-1},B_{d}]
(前 d-2 维度的形状相同)。
有没有办法计算最后两个维度的克罗内克积?
形状my_kron(A,B)
应该是[A_{1},...,A_{d-2},A_{d-1}*B_{d-1},A_{d}*B_{d}]
。
例如 d=3
,
A.shape=[2,3,3]
B.shape=[2,4,4]
C=my_kron(A,B)
C[0,...]
应该是 A[0,...]
的克罗内克积和B[0,...]
和C[1,...]
A[1,...]
的克罗内克积和B[1,...]
.
对于 d=2,这就是jnp.kron
(或 np.kron
)函数可以。
对于 d=3,这可以通过 jax.vmap
来实现。
jax.vmap(lambda x, y: jnp.kron(x[0, :], y[0, :]))(A, B)
但我无法找到一般(未知)维度的解决方案。 有什么建议吗?
最佳答案
用numpy
术语来说,我认为这就是你正在做的事情:
In [104]: A = np.arange(2*3*3).reshape(2,3,3)
In [105]: B = np.arange(2*4*4).reshape(2,4,4)
In [106]: C = np.array([np.kron(a,b) for a,b in zip(A,B)])
In [107]: C.shape
Out[107]: (2, 12, 12)
它将初始维度 2 视为一个批处理
。一个明显的概括是 reshape 数组,将较高维度减少到 1,例如reshape(-1,3,3)
等。然后,将 C
reshape 回所需的 n 维。
np.kron
确实接受 3d(及更高),但它在共享的 2 维度上进行某种外部
:
In [108]: np.kron(A,B).shape
Out[108]: (4, 12, 12)
将 4 个维度可视化为 (2,2),我可以采用对角线
并得到您的C
:
In [109]: np.allclose(np.kron(A,B)[[0,3]], C)
Out[109]: True
完整的 kron
执行的计算比需要的更多,但仍然更快:
In [110]: timeit C = np.array([np.kron(a,b) for a,b in zip(A,B)])
108 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [111]: timeit np.kron(A,B)[[0,3]]
76.4 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
我确信可以以更直接的方式进行计算,但这样做需要更好地理解 kron
的工作原理。快速浏览一下 np.kron
代码表明它执行了 outer(A,B)
In [114]: np.outer(A,B).shape
Out[114]: (18, 32)
它具有相同数量的元素,但它随后 reshape
并连接
以生成kron
布局。
但凭直觉,我发现这相当于你想要的:
In [123]: D = A[:,:,None,:,None]*B[:,None,:,None,:]
In [124]: np.allclose(D.reshape(2,12,12),C)
Out[124]: True
In [125]: timeit np.reshape(A[:,:,None,:,None]*B[:,None,:,None,:],(2,12,12))
14.3 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
这很容易推广到更多主要维度。
def my_kron(A,B):
D = A[...,:,None,:,None]*B[...,None,:,None,:]
ds = D.shape
newshape = (*ds[:-4],ds[-4]*ds[-3],ds[-2]*ds[-1])
return D.reshape(newshape)
In [137]: my_kron(A.reshape(1,2,1,3,3),B.reshape(1,2,1,4,4)).shape
Out[137]: (1, 2, 1, 12, 12)
关于python - 如何沿数组维度映射克罗内克乘积?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73673599/