python - JAX vmap 行为

标签 python vectorization jax

我试图了解 JAX vmap 的行为,所以我编写了以下代码:

import jax.numpy as jnp
from jax import vmap

def what(a,b,c):
  z = jnp.dot(a,b)
  return z + c

v_what = vmap(what, in_axes=(None,0,None))

a = jnp.array([1,1,3])
b = jnp.array([2,2])
c = 1.0

v_what(a,b,c)

输出是:

DeviceArray([[3., 3., 7.],
             [3., 3., 7.]], dtype=float32)

我知道唯一被更改的输入是 b,但是有人可以解释一下为什么会出现这样的结果吗?对函数进行向量化后,点积的表现如何?

最佳答案

您已指定转换后的函数应映射到 b 的第一个轴上,并且不映射到 a 的任何轴上或c 。粗略地说,您已经创建了一个执行此操作的映射函数:

def v_what(a, b, c):
  return jnp.stack([what(a, b_i, c) for b_i in b], axis=0)

对于您的输入,每行内的点积看起来像 jnp.dot(a, 2) ,结果相当于a * 2 .

关于python - JAX vmap 行为,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66548897/

相关文章:

python - 将 C 方法与 python 脚本连接起来

python - 使用索引以递归方式快速获取目录中的所有文件

python - 加快深度的 numpy 整数数组索引

python - 过滤图像中补丁位置的最优化方法

python - JAX:jit 函数的时间随着函数访问的内存而超线性增长

python - 给定批处理的样本找到 x,y 位置 - python, jax

python - 如果列表中没有则显示/隐藏字段

python - 使用 Python(Raspbian 操作系统)发送电子邮件时出现 SSL 错误

r - 何时在 R 中使用 for 循环

python - vmap 遍历 jax 中的列表