python - vmap 遍历 jax 中的列表

标签 python jax

使用 jax,我尝试计算每个样本的梯度,对其进行处理,然后将它们采用标准形式来计算标准参数更新。 我的工作代码看起来像

differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)

# some code

gradients_summed_over_samples = []
    for layer in gradients:
        (dw, db) = layer
        (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
        gradients_summed_over_samples.append((dw, db))

其中 gradients 的形式为 list(tuple(DeviceArray(...), DeviceArray(...)), ...)

现在我尝试将循环重写为vmap(不确定它最终是否会带来加速)

def sum_samples(layer):
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))

vmap(sum_samples)(gradients)

但是 sum_samples 仅调用一次,而不是针对列表中的每个元素。

是列表有问题还是我理解有其他错误?

最佳答案

jax.vmap 将仅映射到 jax 数组输入,而不是数组或元组列表的输入。此外,vmapped 函数无法就地修改输入;函数应该返回一个值,并且该返回值将与其他返回值堆叠以构造输出

例如,您可以修改您定义的函数并按如下方式使用它:

import jax.numpy as np
from jax import random

def sum_samples(layer):
    (dw, db) = layer
    (dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
    return np.array([dw, db])

key = random.PRNGKey(1701)
data = random.uniform(key, (10, 2, 20))

result = vmap(sum_samples)(data)
print(result.shape)
# (10, 2)

旁注:如果您使用这种方法,上面的 vmapped 函数可以更简洁地表示为:

def sum_samples(layer):
    return layer.sum(1)

关于python - vmap 遍历 jax 中的列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61786831/

相关文章:

python - 将变量分配给Python对象

python - 如何在日期时间索引上加入两个数据帧,使用 nan 自动填充不匹配的行

python - Django 模型 : Get elements considering their presence as a foreign key in another table

python - 在给定值之后屏蔽 numpy 数组

python - 将 tf.data.Dataset 转换为 jax.numpy 迭代器

python - 如何让 ibm_db 或 PyDB2 python 模块与 Mac OS X 10.7 Lion 中的 DB2 配合使用?

python - 在 mac os 上为 python3 安装 mysqlclient for mariadb

python - 如何在 Jax 中将函数有效地映射到多个参数(元组/列表),每个参数的形状不一致

python - JAX 中 Pytree 的加权和