重要说明:我需要这里的所有内容都与 jit 兼容,否则我的问题就微不足道了:)
我有一个 jax numpy 数组,例如:
a = jnp.array([1,5,3,4,5,6,7,2,9])
首先,我根据一个值对其进行过滤,假设我只保留 < 5
的值a = jnp.where((a < 5), x=a, y=jnp.nan)
# a is now [ 1. nan 3. 4. nan nan nan 2. nan]
我只想保留非 nan 值:[ 1. 3. 4. 2.]
然后我将使用此数组进行其他操作。
但更重要的是,在我的程序执行期间,这段代码将执行多次,阈值会发生变化(即它不会总是 5)。
因此,最终数组的形状也会发生变化。这是我的 jit 编译问题,我不知道如何使其与 jit 兼容,因为形状取决于有多少元素符合阈值条件。
最佳答案
JAX 的 JIT 目前与动态(数据相关)形状的数组不兼容,因此无法完成您的问题。
有一些关于在 JAX 转换(如 JIT)中处理动态形状的实验性工作正在进行中(请参阅 https://github.com/google/jax/pull/9335),但我不确定它何时可以使用。
通常的解决方法是根据具有填充值的静态形状数组重新表达您的计算;例如,你可以使用这样的东西:
a = jnp.where((a < 5), size=len(a), fill_value=np.nan)
这将创建一个长度与 a
相同的数组,在前面有非 nan 值,并在末尾填充 nan
值。
关于python - 处理 jax numpy 数组中的不同形状(jit 兼容),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71692885/