我想在下一批维度为 [batch, 4,4] = [2,4,4] 的采样数组中找到 '1' 的位置。
import jax
import jax.numpy as jnp
a = jnp.array([[[0., 0., 0., 1.],
[0., 0., 0., 0.],
[0., 1., 0., 1.],
[0., 0., 1., 1.]],
[[1., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 1., 0., 1.]]])
我尝试遍历批处理的维度(使用 vmap)并使用 jax 函数查找坐标
b = jax.vmap(jnp.where)(a)
print('b', b)
但我收到一个错误,但我不知道如何修复:
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This Tracer was created on line /home/imi/Desktop/Backflow/backflow/src/debug.py:17 (<module>)
我期望以下输出:
b = [[[0,3], [2,1],[2,3],[3,2],[3,3]],
[[0,0],[0,2],[1,0],[3,1],[3,3]]
第一行 [x,y] 坐标对应于第一批中存在“1”的位置,第二行对应于第二批中存在“1”的位置。
最佳答案
像 vmap
这样的 JAX 转换需要静态调整数组大小,因此无法精确执行您想要的计算(因为 1
条目的数量,因此输出数组的大小取决于数据)。
但是,如果您先验知道每批有五个条目,您可以执行以下操作:
from functools import partial
indices = jax.vmap(partial(jnp.where, size=5))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
[2 1]
[2 3]
[3 2]
[3 3]]
[[0 0]
[0 2]
[1 0]
[3 1]
[3 3]]]
如果您先验不知道有多少个1
条目,那么您有几个选择:一是避免JAX转换并调用un-在每个批处理上转换jnp.where
:
result = [jnp.column_stack(jnp.where(b)) for b in a]
print(result)
[DeviceArray([[0, 3],
[2, 1],
[2, 3],
[3, 2],
[3, 3]], dtype=int32), DeviceArray([[0, 0],
[0, 2],
[1, 0],
[3, 1],
[3, 3]], dtype=int32)]
请注意,对于这种情况,通常不可能将结果存储在单个数组中,因为每个批处理中可能有不同数量的 1
条目,并且 JAX 不支持不规则数组.
另一个选项是将大小
设置为某个最大值,并输出填充结果:
max_size = a[0].size # size of slice is the upper bound
fill_value = a[0].shape # fill with out-of-bound indices
indices = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
[2 1]
[2 3]
[3 2]
[3 3]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]]
[[0 0]
[0 2]
[1 0]
[3 1]
[3 3]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]
[4 4]]]
通过填充结果,您可以编写其余代码来预测这些填充值。
关于python - 给定批处理的样本找到 x,y 位置 - python, jax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/75134947/