python - 在一个太大的数组内部进行乘法和加法

标签 python numpy

我有一个形状为 (M,N) 的数组 A,现在我想进行运算

R = (A[:,newaxis,:] * A[newaxis,:,:]).sum(2)

这应该产生一个 (MxM) 数组。现在的问题是数组非常大,我收到内存错误,因为 MxMxN 数组放不下内存。

完成这项工作的最佳策略是什么? C? map ()?还是有专门的功能?

谢谢你 大卫

最佳答案

我不确定你的数组有多大,但下面是等价的:

R = np.einsum('ij,kj',A,A)

而且速度会快很多,内存占用也少得多:

In [7]: A = np.random.random(size=(500,400))

In [8]: %timeit R = (A[:,np.newaxis,:] * A[np.newaxis,:,:]).sum(2)
1 loops, best of 3: 1.21 s per loop

In [9]: %timeit R = np.einsum('ij,kj',A,A)
10 loops, best of 3: 54 ms per loop

如果我将 A 的大小增加到 (500,4000)np.einsum 会在大约 2 秒内完成计算,而原始公式由于必须创建的临时数组的大小而使我的机器停止运行。

更新:

正如@Jaime 在评论中指出的那样,np.dot(A,A.T) 也是该问题的等效表述,甚至可以比 np.einsum解决方案。完全感谢他指出这一点,但如果他没有将其作为正式解决方案发布,我想将其拉到主要答案中。

关于python - 在一个太大的数组内部进行乘法和加法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/16986317/

相关文章:

python - discord.py 发送机器人消息然后固定(重写)

python - PyTorch Softmax 尺寸错误

python - 使用 numba jit,Python 的段间距离

python - 计算一段时间内值增加和减少的最大总计

具有更新值的python for循环

android - 如何从android接收和返回webapp2中的字符串?

python - 如果消息是回复,机器人应该获取用户的 ID

python-3.x - 这种 numpy.where() 方法可以适应两种情况而不是一种吗?

python - pandas 中的并行 read_table

python - 如何在 Python/SQLAlchemy 中使用 "class"作为枚举值?