我试图了解 JAX vmap 的行为,所以我编写了以下代码:
import jax.numpy as jnp
from jax import vmap
def what(a,b,c):
z = jnp.dot(a,b)
return z + c
v_what = vmap(what, in_axes=(None,0,None))
a = jnp.array([1,1,3])
b = jnp.array([2,2])
c = 1.0
v_what(a,b,c)
输出是:
DeviceArray([[3., 3., 7.],
[3., 3., 7.]], dtype=float32)
我知道唯一被更改的输入是 b
,但是有人可以解释一下为什么会出现这样的结果吗?对函数进行向量化后,点积的表现如何?
最佳答案
您已指定转换后的函数应映射到 b
的第一个轴上,并且不映射到 a
的任何轴上或c
。粗略地说,您已经创建了一个执行此操作的映射函数:
def v_what(a, b, c):
return jnp.stack([what(a, b_i, c) for b_i in b], axis=0)
对于您的输入,每行内的点积看起来像 jnp.dot(a, 2)
,结果相当于a * 2
.
关于python - JAX vmap 行为,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66548897/