我有一个形状为 (M,N) 的数组 A,现在我想进行运算
R = (A[:,newaxis,:] * A[newaxis,:,:]).sum(2)
这应该产生一个 (MxM) 数组。现在的问题是数组非常大,我收到内存错误,因为 MxMxN 数组放不下内存。
完成这项工作的最佳策略是什么? C? map ()?还是有专门的功能?
谢谢你 大卫
最佳答案
我不确定你的数组有多大,但下面是等价的:
R = np.einsum('ij,kj',A,A)
而且速度会快很多,内存占用也少得多:
In [7]: A = np.random.random(size=(500,400))
In [8]: %timeit R = (A[:,np.newaxis,:] * A[np.newaxis,:,:]).sum(2)
1 loops, best of 3: 1.21 s per loop
In [9]: %timeit R = np.einsum('ij,kj',A,A)
10 loops, best of 3: 54 ms per loop
如果我将 A
的大小增加到 (500,4000)
,np.einsum
会在大约 2 秒内完成计算,而原始公式由于必须创建的临时数组的大小而使我的机器停止运行。
更新:
正如@Jaime 在评论中指出的那样,np.dot(A,A.T)
也是该问题的等效表述,甚至可以比 np.einsum
解决方案。完全感谢他指出这一点,但如果他没有将其作为正式解决方案发布,我想将其拉到主要答案中。
关于python - 在一个太大的数组内部进行乘法和加法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/16986317/