python - 如何沿数组维度映射克罗内克乘积?

标签 python arrays numpy jax kronecker-product

给定两个具有相同维度的张量 A 和 B (d>=2)和形状[A_{1},...,A_{d-2},A_{d-1},A_{d}][A_{1},...,A_{d-2},B_{d-1},B_{d}] (前 d-2 维度的形状相同)。

有没有办法计算最后两个维度的克罗内克积? 形状my_kron(A,B)应该是[A_{1},...,A_{d-2},A_{d-1}*B_{d-1},A_{d}*B_{d}] 。 例如 d=3 ,

A.shape=[2,3,3]
B.shape=[2,4,4]
C=my_kron(A,B)

C[0,...]应该是 A[0,...] 的克罗内克积和B[0,...]C[1,...] A[1,...]的克罗内克积和B[1,...] .

对于 d=2,这就是jnp.kron (或 np.kron )函数可以。

对于 d=3,这可以通过 jax.vmap 来实现。 jax.vmap(lambda x, y: jnp.kron(x[0, :], y[0, :]))(A, B)

但我无法找到一般(未知)维度的解决方案。 有什么建议吗?

最佳答案

numpy术语来说,我认为这就是你正在做的事情:

In [104]: A = np.arange(2*3*3).reshape(2,3,3)
In [105]: B = np.arange(2*4*4).reshape(2,4,4)

In [106]: C = np.array([np.kron(a,b) for a,b in zip(A,B)])
In [107]: C.shape
Out[107]: (2, 12, 12)

它将初始维度 2 视为一个批处理。一个明显的概括是 reshape 数组,将较高维度减少到 1,例如reshape(-1,3,3) 等。然后,将 C reshape 回所需的 n 维。

np.kron 确实接受 3d(及更高),但它在共享的 2 维度上进行某种外部:

In [108]: np.kron(A,B).shape
Out[108]: (4, 12, 12)

将 4 个维度可视化为 (2,2),我可以采用对角线并得到您的C:

In [109]: np.allclose(np.kron(A,B)[[0,3]], C)
Out[109]: True

完整的 kron 执行的计算比需要的更多,但仍然更快:

In [110]: timeit C = np.array([np.kron(a,b) for a,b in zip(A,B)])
108 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [111]: timeit np.kron(A,B)[[0,3]]
76.4 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

我确信可以以更直接的方式进行计算,但这样做需要更好地理解 kron 的工作原理。快速浏览一下 np.kron 代码表明它执行了 outer(A,B)

In [114]: np.outer(A,B).shape
Out[114]: (18, 32)

它具有相同数量的元素,但它随后 reshape 连接以生成kron布局。

但凭直觉,我发现这相当于你想要的:

In [123]: D = A[:,:,None,:,None]*B[:,None,:,None,:]
In [124]: np.allclose(D.reshape(2,12,12),C)
Out[124]: True
In [125]: timeit np.reshape(A[:,:,None,:,None]*B[:,None,:,None,:],(2,12,12))
14.3 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

这很容易推广到更多主要维度。

def my_kron(A,B):
   D = A[...,:,None,:,None]*B[...,None,:,None,:]
   ds = D.shape
   newshape = (*ds[:-4],ds[-4]*ds[-3],ds[-2]*ds[-1])
   return D.reshape(newshape)

In [137]: my_kron(A.reshape(1,2,1,3,3),B.reshape(1,2,1,4,4)).shape
Out[137]: (1, 2, 1, 12, 12)

关于python - 如何沿数组维度映射克罗内克乘积?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73673599/

相关文章:

python - 使用 pydrive 转换为 Google 驱动程序更新文件给出 : ApiRequestError: HttpError 500 Error

python - 显示开始和结束之间的年份

java - 检测哪个复选框被触摸?数组 Java Android Studio

Javascript:将对象从 foreach 推送到数组

python - 有没有办法通过使用 numpy 或其他资源来优化 Python 中的三重循环?

python - 我可以在 python 中为多处理创建一个共享的多数组或列表对象列表吗?

python - 我如何让 ipython 知道对自写模块所做的更改?

php - 如何将 mysql 连接结果包含到 PHP 多维数组中

python - 在 Python 2.7.9 中安装包

python - numpy 中概率被忽略