python - JAX:避免对沿一个轴使用不同数量的元素评估的函数进行即时重新编译

标签 python python-3.x jit jax

当 JIT 函数的输入结构基本保持不变(除了一个轴具有不同数量的元素之外)时,是否可以避免重新编译 JIT 函数?

import jax

@jax.jit
def f(x):
    print('recompiling')
    return (x + 10) * 100

a = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling
b = f(jax.numpy.arange(300000000).reshape((-1, 2, 2)).block_until_ready())
c = f(jax.numpy.arange(450000000).reshape((-1, 2, 2)).block_until_ready()) # recompiling. It would be nice if it weren't

要求:pip安装jax、jaxlib

最佳答案

不,当您调用具有不同形状数组的函数时,无法避免重新编译。从根本上说,JAX 为静态形状的输入和输出编译函数,并且使用新形状的数组调用 JIT 编译的函数总是会触发重新编译。

目前正在进行一些放宽此要求的工作(在 JAX 的 github 存储库中搜索“动态形状”),但目前没有此类 API 可用。

关于python - JAX:避免对沿一个轴使用不同数量的元素评估的函数进行即时重新编译,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70126391/

相关文章:

python - 在 Python 中生成数字列表及其负数

java - 处理来自 Throwable catch 的 NullPointerException 的最佳方法? (安卓)

c# - JIT 编译器的 IL 优化

python - 获取带有 token 的用户 Django Rest Framework

python - Hashlib Python 模块方法更新中的最大字节数限制

python - 从列表中删除具有重复键元素的元组

python - 在同一轴python上绘制for循环内生成的多个图

interpreter - llvm/工具 : lli REPL compared to LuaJIT

python - 从编辑 View 自定义(覆盖)Flask-Admin 的提交方法

python - PyQt:事件未触发,我的代码有什么问题?