python - JAX 0.2.17 和 JAX 0.4.1 之间巨大的内存需求差异

标签 python numpy jax

问题的后续:

Is it possible to improve python performance for this code?

当使用已接受答案中的函数时,无论是否有 jax(bar 或 jit_bar):

T = np.random.rand(5000, 566, 3)
@jax.jit
def jit_bar(Y):
   u, v = jnp.triu_indices(Y.shape[0], 1)
   return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))
msd = jit_bar(T)

向函数发送一个 (10000x566x3) 数组,使用 python3.6 可以稳定使用 1.5 GB 的内存,而使用 python = 3.8、3.9、3.10、3.11 时,内存会飙升至 +50 GB。

编辑:

经过一些试验,它似乎仅与 jax 有关,此代码可以正常运行:

python3.6、jax(0.2.17)、jaxlib(0.1.68)、numpy(1.19.2)

但不包括:

python3.11、jax(0.4.1)、jaxlib(0.4.1)、numpy(1.24.1)

最佳答案

如果 Y 的形状为 (10000, 566, 3)triu_indices 返回长度为 (10000 * 10001) 的数组/2,因此 Y[u]Y[v] 的大小均为 (50005000, 566, 3) 。如果它们是 float32 值,则每个大小约为 316 GB。我不希望这段代码在任何地方都运行良好!

我怀疑旧的 JAX 版本可能有一些额外的优化,但在更高的版本中被删除了;考虑到计算的形式,唯一可能的是平方差的因式分解,以避免实例化完整的矩阵和,我隐约记得这是以前的 XLA 优化,但由于数值不稳定而被删除。

但是如果您愿意,您可以手动进行此类优化;这是一种似乎有效的方法,它为原始输入生成的最大中间数组的形状为 [10000, 10000],在 float32 中大约约 380MB:

@jax.jit
def jit_bar2(Y):
   u, v = jnp.triu_indices(Y.shape[0], 1)
   Y = Y.reshape(Y.shape[0], -1)
   Y2m = (Y ** 2).mean(-1)
   YYTm = (Y @ Y.T) / Y.shape[1]
   return jnp.sqrt(3 * (Y2m[u] + Y2m[v] - 2 * YYTm[u, v]))

T = np.random.rand(50, 6, 3)  # test with a smaller input
np.testing.assert_allclose(jit_bar(T), jit_bar2(T), atol=1E-5)

关于python - JAX 0.2.17 和 JAX 0.4.1 之间巨大的内存需求差异,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/75471289/

相关文章:

python - etree xml解析和删除

python - 仅针对特定轴的克罗内克乘法

python - 获取 numpy ndarray 的一部分(对于任意维度)

installation - 无法安装特定的 JAX jaxlib GPU 版本

tensorflow - 基于秩的计算的自动微分

python - 在给定值之后屏蔽 numpy 数组

python - Pyramid 页面加载动画

python - 从文件中读取数据,将其拆分为列表,然后获取该数据并将其放入函数中

python - 根据另一个 Pandas 数据框的值填充一个 Pandas 数据框的最快方法是什么?

python - 如何清除 Tkinter ListBox Python