问题的后续:
Is it possible to improve python performance for this code?
当使用已接受答案中的函数时,无论是否有 jax(bar 或 jit_bar):
T = np.random.rand(5000, 566, 3)
@jax.jit
def jit_bar(Y):
u, v = jnp.triu_indices(Y.shape[0], 1)
return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))
msd = jit_bar(T)
向函数发送一个 (10000x566x3) 数组,使用 python3.6 可以稳定使用 1.5 GB 的内存,而使用 python = 3.8、3.9、3.10、3.11 时,内存会飙升至 +50 GB。
编辑:
经过一些试验,它似乎仅与 jax 有关,此代码可以正常运行:
python3.6、jax(0.2.17)、jaxlib(0.1.68)、numpy(1.19.2)
但不包括:
python3.11、jax(0.4.1)、jaxlib(0.4.1)、numpy(1.24.1)
最佳答案
如果 Y
的形状为 (10000, 566, 3)
则 triu_indices
返回长度为 (10000 * 10001) 的数组/2
,因此 Y[u]
和 Y[v]
的大小均为 (50005000, 566, 3)
。如果它们是 float32 值,则每个大小约为 316 GB。我不希望这段代码在任何地方都运行良好!
我怀疑旧的 JAX 版本可能有一些额外的优化,但在更高的版本中被删除了;考虑到计算的形式,唯一可能的是平方差的因式分解,以避免实例化完整的矩阵和,我隐约记得这是以前的 XLA 优化,但由于数值不稳定而被删除。
但是如果您愿意,您可以手动进行此类优化;这是一种似乎有效的方法,它为原始输入生成的最大中间数组的形状为 [10000, 10000]
,在 float32 中大约约 380MB:
@jax.jit
def jit_bar2(Y):
u, v = jnp.triu_indices(Y.shape[0], 1)
Y = Y.reshape(Y.shape[0], -1)
Y2m = (Y ** 2).mean(-1)
YYTm = (Y @ Y.T) / Y.shape[1]
return jnp.sqrt(3 * (Y2m[u] + Y2m[v] - 2 * YYTm[u, v]))
T = np.random.rand(50, 6, 3) # test with a smaller input
np.testing.assert_allclose(jit_bar(T), jit_bar2(T), atol=1E-5)
关于python - JAX 0.2.17 和 JAX 0.4.1 之间巨大的内存需求差异,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/75471289/