python - 处理 jax numpy 数组中的不同形状(jit 兼容)

标签 python arrays shapes jit jax

重要说明:我需要这里的所有内容都与 jit 兼容,否则我的问题就微不足道了:)

我有一个 jax numpy 数组,例如:

a = jnp.array([1,5,3,4,5,6,7,2,9])

首先,我根据一个值对其进行过滤,假设我只保留 < 5

的值
a = jnp.where((a < 5), x=a, y=jnp.nan)
# a is now [ 1. nan  3.  4. nan nan nan  2. nan]

我只想保留非 nan 值:[ 1. 3. 4. 2.] 然后我将使用此数组进行其他操作。

但更重要的是,在我的程序执行期间,这段代码将执行多次,阈值会发生变化(即它不会总是 5)。

因此,最终数组的形状也会发生变化。这是我的 jit 编译问题,我不知道如何使其与 jit 兼容,因为形状取决于有多少元素符合阈值条件。

最佳答案

JAX 的 JIT 目前与动态(数据相关)形状的数组不兼容,因此无法完成您的问题。

有一些关于在 JAX 转换(如 JIT)中处理动态形状的实验性工作正在进行中(请参阅 https://github.com/google/jax/pull/9335),但我不确定它何时可以使用。

通常的解决方法是根据具有填充值的静态形状数组重新表达您的计算;例如,你可以使用这样的东西:

a = jnp.where((a < 5), size=len(a), fill_value=np.nan)

这将创建一个长度与 a 相同的数组,在前面有非 nan 值,并在末尾填充 nan 值。

关于python - 处理 jax numpy 数组中的不同形状(jit 兼容),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71692885/

相关文章:

python - 如何获取或初始化我想要的变量?

python - 创建运算矩阵张量

java - 在java中反转数组并将其分配给新数组

C# for 组合框中的循环

python - 从 DF 中删除包含重复单词的短语(Pandas、Python3)

python - 解析 Redis MONITOR 消息

javascript - 如何使用 angularjs ng-repeat 将数组中的 html 字符串转换为 html?

java - 如何使用 fillPolygon 在 java swing 中绘制星星

java - java中的三角形

java - 使用 relocate 在 javafx 中移动形状后获取元素坐标