我正在使用 NumPy 处理一些大型数据矩阵(大小约为 50GB)。我运行这段代码的机器有 128GB 的 RAM,所以做这种量级的简单线性运算在内存方面应该不是问题。
但是,当我在 Python 中计算以下代码时,我目睹了巨大的内存增长(超过 100GB):
import numpy as np
# memory allocations (everything works fine)
a = np.zeros((1192953, 192, 32), dtype='f8')
b = np.zeros((1192953, 192), dtype='f8')
c = np.zeros((192, 32), dtype='f8')
a[:] = b[:, :, np.newaxis] - c[np.newaxis, :, :] # memory explodes here
请注意,初始内存分配没有任何问题。但是,当我尝试通过广播执行减法运算时,内存增长到 100GB 以上。我一直认为广播可以避免额外的内存分配,但现在我不确定是否总是这样。
因此,有人可以详细说明为什么会发生这种内存增长,以及如何使用更高效的内存结构重写以下代码?
我在 IPython Notebook 中运行 Python 2.7 中的代码。
最佳答案
@rth 建议小批量进行操作是一个很好的建议。您也可以尝试使用函数 np.subtract
并将其指定为目标数组以避免创建额外的临时数组。我还认为您不需要将 c
索引为 c[np.newaxis, :, :]
,因为它已经是一个 3-d 数组。
所以不是
a[:] = b[:, :, np.newaxis] - c[np.newaxis, :, :] # memory explodes here
试试
np.subtract(b[:, :, np.newaxis], c, a)
np.subtract
的第三个参数是目标数组。
关于python - NumPy 中广播操作的内存增长,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/31536504/