当我重新创建 JAX np 数组并将其重新分配给相同的变量名称时,由于某种原因,GPU 内存几乎使第一次重新创建时翻倍,然后在后续重新创建/重新分配时保持稳定。
为什么会发生这种情况?这是 JAX 数组通常预期的行为吗?
完全可运行的最小示例:https://colab.research.google.com/drive/1piUvyVylRBKm1xb1WsocsSVXJzvn5bdI?usp=sharing .
对于后代,以防 Colab 宕机:
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import jax
from jax import numpy as jnp
from jax import random
# First creation of jnp array
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage from the first call is 618 MB
# Second creation of jnp array, reassigning it to the same variable name
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage is now 1130 MB - almost double!
# Third creation of jnp array, reassigning it to the same variable name
x = jnp.ones(shape=(int(1e8),), dtype=float)
get_gpu_memory() # the memory usage is stable at 1130 MB.
谢谢!
最佳答案
这种行为的原因来自于几个事物的相互作用:
如果不进行预分配,GPU 内存使用量将根据需要增加,但在删除缓冲区时不会收缩。
当你重新分配一个Python变量时,旧值仍然存在于内存中,直到Python垃圾收集器注意到它不再被引用,并将其删除。这将需要少量时间在后台发生(您可以调用
import gc; gc.collect()
强制在任何时候发生)。JAX 异步向 GPU 发送指令,这意味着一旦 Python 垃圾收集了 GPU 支持的值,Python 脚本可能会继续运行一小段时间,然后实际从设备中删除相应的缓冲区。
所有这些意味着在取消分配先前的 x
值和释放设备上的内存之间存在一些延迟,如果您立即分配新值,设备可能会扩展其内存在删除旧数组之前分配内存以适应新数组。
那么为什么第三次调用时内存使用量保持不变?好吧,此时第一个分配已被删除,因此已经有空间用于第三个分配,而不会增加内存占用。
考虑到这些事情,您可以通过在删除旧值和创建新值之间设置延迟来保持分配不变;即替换它:
x = jnp.ones(shape=(int(1e8),), dtype=float)
这样:
del x
time.sleep(1)
x = jnp.ones(shape=(int(1e8),), dtype=float)
当我以这种方式运行时,我看到内存使用量恒定为 618MiB。
关于python - 为什么重新创建 JAX numpy 数组并将其重新分配给相同的变量名称时 GPU 内存会增加?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74628777/