python - 为什么重新创建 JAX numpy 数组并将其重新分配给相同的变量名称时 GPU 内存会增加?

标签 python memory gpu nvidia jax

当我重新创建 JAX np 数组并将其重新分配给相同的变量名称时,由于某种原因,GPU 内存几乎使第一次重新创建时翻倍,然后在后续重新创建/重新分配时保持稳定。

为什么会发生这种情况?这是 JAX 数组通常预期的行为吗?

完全可运行的最小示例:https://colab.research.google.com/drive/1piUvyVylRBKm1xb1WsocsSVXJzvn5bdI?usp=sharing .

对于后代,以防 Colab 宕机:

%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import jax
from jax import numpy as jnp
from jax import random

# First creation of jnp array
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage from the first call is 618 MB

# Second creation of jnp array, reassigning it to the same variable name
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage is now 1130 MB - almost double!

# Third creation of jnp array, reassigning it to the same variable name
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage is stable at 1130 MB.

谢谢!

最佳答案

这种行为的原因来自于几个事物的相互作用:

  1. 如果不进行预分配,GPU 内存使用量将根据需要增加,但在删除缓冲区时不会收缩。

  2. 当你重新分配一个Python变量时,旧值仍然存在于内存中,直到Python垃圾收集器注意到它不再被引用,并将其删除。这将需要少量时间在后台发生(您可以调用 import gc; gc.collect() 强制在任何时候发生)。

  3. JAX 异步向 GPU 发送指令,这意味着一旦 Python 垃圾收集了 GPU 支持的值,Python 脚本可能会继续运行一小段时间,然后实际从设备中删除相应的缓冲区。

所有这些意味着在取消分配先前的 x 值和释放设备上的内存之间存在一些延迟,如果您立即分配新值,设备可能会扩展其内存在删除旧数组之前分配内存以适应新数组。

那么为什么第三次调用时内存使用量保持不变?好吧,此时第一个分配已被删除,因此已经有空间用于第三个分配,而不会增加内存占用。

考虑到这些事情,您可以通过在删除旧值和创建新值之间设置延迟来保持分配不变;即替换它:

x = jnp.ones(shape=(int(1e8),), dtype=float)

这样:

del x
time.sleep(1)
x = jnp.ones(shape=(int(1e8),), dtype=float)

当我以这种方式运行时,我看到内存使用量恒定为 618MiB。

关于python - 为什么重新创建 JAX numpy 数组并将其重新分配给相同的变量名称时 GPU 内存会增加?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74628777/

相关文章:

c++ - 执行三个嵌套 for 循环的最快方法是什么?

OpenGL 和多 GPU

python - AWS 上的 Hadoop 流 - 情绪分析示例

python - 正则表达式选择和替换双括号内的空格

python - 在 Python 或 SQL 中像求解器一样使用 Excel

Python 与 SQL 外连接给出不同的结果。为什么?

c++ - 使用内存映射文件时如何释放物理内存?

python - 在不填充内存的情况下读取 Python 中的特定文件行

c++ - C++中的对象大小是不可预测的

c++ - 获取或生成 "C++"中的系统信息