python - 如何将 jax vmap 用于嵌套循环?

标签 python performance vectorization jit jax

我想使用 vmap 来矢量化此代码以提高性能。

def matrix(dataA, dataB):
    return jnp.array([[func(a, b) for b in dataB] for a in dataA])
matrix(data, data)
我试过这个:
def f(x, y):
    return func(x, y)
mapped = jax.vmap(f)
mapped(data, data)
但这仅给出对角线条目。
基本上我有一个向量 data = [1,2,3,4,5] (示例),我想得到一个矩阵,使得每个条目 (i, j)矩阵是 f(data[i], data[j]) .因此,生成的矩阵形状将为 (len(data), len(data)) .

最佳答案

jax.vmap一次映射一组轴。如果要映射两个独立的轴集,可以通过嵌套两个 vmap 来实现。转换:

mapped = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
result = mapped(data, data)

关于python - 如何将 jax vmap 用于嵌套循环?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69429846/

相关文章:

python - 如何将文件从 Nodejs 发送到 Flask Python?

java - 在 Python 中复制 Java 密码哈希代码 (PBKDF2WithHmacSHA1)

python - 从套接字客户端向套接字服务器发送PIL图像对象

python - 移动矩阵行以使最大值位于中间

numpy - 如何向量化 for 循环,就像下面在 numpy 中提到的那样?

Python 到 Json 的转换

performance - 更改 Julia DEPOT_PATH 时包加载时间会急剧增加

mysql - 如何优化运行速度太慢的查询(LEFT JOIN)

performance - R: tm Textmining 包:文档级元数据生成速度慢

python - Numpy:将行值广播到 channel