python - 给定批处理的样本找到 x,y 位置 - python, jax

标签 python numpy coordinates jax

我想在下一批维度为 [batch, 4,4] = [2,4,4] 的采样数组中找到 '1' 的位置。

import jax
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
             
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

我尝试遍历批处理的维度(使用 vmap)并使用 jax 函数查找坐标

b = jax.vmap(jnp.where)(a)
print('b', b)

但我收到一个错误,但我不知道如何修复:

The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This Tracer was created on line /home/imi/Desktop/Backflow/backflow/src/debug.py:17 (<module>)

我期望以下输出:

b = [[[0,3], [2,1],[2,3],[3,2],[3,3]],

     [[0,0],[0,2],[1,0],[3,1],[3,3]]

第一行 [x,y] 坐标对应于第一批中存在“1”的位置,第二行对应于第二批中存在“1”的位置。

最佳答案

vmap 这样的 JAX 转换需要静态调整数组大小,因此无法精确执行您想要的计算(因为 1 条目的数量,因此输出数组的大小取决于数据)。

但是,如果您先验知道每批有五个条目,您可以执行以下操作:

from functools import partial
indices = jax.vmap(partial(jnp.where, size=5))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
  [2 1]
  [2 3]
  [3 2]
  [3 3]]

 [[0 0]
  [0 2]
  [1 0]
  [3 1]
  [3 3]]]

如果您先验不知道有多少个1条目,那么您有几个选择:一是避免JAX转换并调用un-在每个批处理上转换jnp.where:

result = [jnp.column_stack(jnp.where(b)) for b in a]
print(result)
[DeviceArray([[0, 3],
             [2, 1],
             [2, 3],
             [3, 2],
             [3, 3]], dtype=int32), DeviceArray([[0, 0],
             [0, 2],
             [1, 0],
             [3, 1],
             [3, 3]], dtype=int32)]

请注意,对于这种情况,通常不可能将结果存储在单个数组中,因为每个批处理中可能有不同数量的 1 条目,并且 JAX 不支持不规则数组.

另一个选项是将大小设置为某个最大值,并输出填充结果:

max_size = a[0].size  # size of slice is the upper bound
fill_value = a[0].shape  # fill with out-of-bound indices
indices = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
  [2 1]
  [2 3]
  [3 2]
  [3 3]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]]

 [[0 0]
  [0 2]
  [1 0]
  [3 1]
  [3 3]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]]]

通过填充结果,您可以编写其余代码来预测这些填充值。

关于python - 给定批处理的样本找到 x,y 位置 - python, jax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/75134947/

相关文章:

python - Python中的字符串如何格式化 boolean 值?

python - 在 plone 中,如何使上传的文件不可下载?

javascript - 如何从发出警报的页面恢复 "text",尤其是在人类用户在页面的 **警报** 上 *单击* 后?

python - numpy 数组和 pymssql

python - 混合类型的 NumPy 数组/矩阵

python - 来自单独数组中纬度/经度坐标的(纬度,经度)的所有可能点

Android OpenGL ES 触摸添加对象

ios - 是否可以将颜色分配给目标坐标 x 到 y,而无需上传 UIimage

python - Python/Jython 中低效的随机掷骰子

python - 寻找二阶多项式的回归 python