我的目标是有效地计算以下嵌套循环,
Ab = np.random.randn(1000, 100)
Tb = np.zeros((100, 100, 100))
for i in range(d):
for j in range(d):
for k in range(d):
Tb[i, j, k] = np.sum(Ab[:, i] * Ab[:, j] * Ab[:, k])
我找到了一种更快的方法来通过仅循环组合来执行嵌套循环:
for i,j,k in itertools.combinations_with_replacement(np.arange(100), 3):
Abijk = np.sum(Ab[:, i] * Ab[:, j] * Ab[:, k])
Tb[i, j, k] = Abijk
Tb[i, k, j] = Abijk
Tb[j, i, k] = Abijk
Tb[j, k, i] = Abijk
Tb[k, j, i] = Abijk
Tb[k, i, j] = Abijk
有没有更有效的方法来做到这一点?
我希望有一种方法可以利用 Numpy 的 Blas、Numba 的 JIT 或 Pytorch GPU 实现。
最佳答案
方法#1
我们可以直接使用迭代器 einsum
string notation与 NumPy 的内置 np.einsum
。因此,解决方案是使用单个 einsum
调用 -
Tb = np.einsum('ai,aj,ak->ijk',Ab,Ab,Ab)
方法#2
我们可以使用广播元素乘法
的组合,然后使用np.tensordot
或np.matmul
来处理所有总和减少
。
因此,再次使用einsum
或显式维度扩展和广播
来获取广播的元素乘法 -
parte1 = np.einsum('ai,aj->aij',Ab,Ab)
parte1 = (Ab[:,None,:]*Ab[:,:,None]
然后,tensordot
或 np.matmul
-
Tb = np.tensordot(parte1,Ab,axes=((0),(0)))
Tb = np.matmul(parte1.T, Ab) # Or parte1.T @ Ab on Python 3.x
因此,第二种方法总共有四种可能的变体。
运行时测试
In [140]: d = 100
...: m = 1000
...: Ab = np.random.randn(m,d)
In [148]: %%timeit # original faster method
...: d = 100
...: Tb = np.zeros((d,d,d))
...: for i,j,k in itertools.combinations_with_replacement(np.arange(100), 3):
...: Abijk = np.sum(Ab[:, i] * Ab[:, j] * Ab[:, k])
...:
...: Tb[i, j, k] = Abijk
...: Tb[i, k, j] = Abijk
...:
...: Tb[j, i, k] = Abijk
...: Tb[j, k, i] = Abijk
...:
...: Tb[k, j, i] = Abijk
...: Tb[k, i, j] = Abijk
1 loop, best of 3: 2.08 s per loop
In [141]: %timeit np.einsum('ai,aj,ak->ijk',Ab,Ab,Ab)
1 loop, best of 3: 3.08 s per loop
In [142]: %timeit np.tensordot(np.einsum('ai,aj->aij',Ab,Ab),Ab,axes=((0),(0)))
...: %timeit np.tensordot(Ab[:,None,:]*Ab[:,:,None],Ab,axes=((0),(0)))
...: %timeit np.matmul(np.einsum('ai,aj->ija',Ab,Ab), Ab)
...: %timeit np.matmul(Ab.T[None,:,:]*Ab.T[:,None,:], Ab)
10 loops, best of 3: 56.8 ms per loop
10 loops, best of 3: 59.2 ms per loop
1 loop, best of 3: 673 ms per loop
1 loop, best of 3: 670 ms per loop
最快的似乎是基于 tensordot
的。因此,与基于更快的单循环 itertools
的方法相比,获得了 35x+
加速。
关于python - Numpy:更快地计算涉及求和的三重嵌套循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47620881/