python - JAX with JIT and custom differentiation

Tags: python jax numpyro

I am using JAX via numpyro. In particular, I want to use B-spline functions (e.g., as implemented in scipy.interpolate.BSpline) to transform various points into a spline, where the inputs depend on some of the model's parameters. Thus, I need to be able to differentiate the B-spline in JAX (only with respect to the input argument, and of course not with respect to the knots or the integer order!).

I can easily use jax.custom_vjp, but not when JIT is used, as it is in numpyro. I have looked at the following:

  1. https://github.com/google/jax/issues/1142
  2. https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

It seems the best hope is to use callbacks, but I cannot quite figure out how they work. See https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-jax-function-on-another-device-with-reverse-mode-autodiff-support

The TensorFlow example with reverse-mode autodiff support does not appear to use JIT.

Example

Here is Python code that works without JIT (see the b_spline_basis() function):

from scipy.interpolate import BSpline
import numpy as np
from numpy import typing as npt
from functools import partial
import jax

doubleArray = npt.NDArray[np.double]

# see
#   https://stackoverflow.com/q/74699053/5861244
#   https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray:  # type: ignore[no-any-unimported]
    out = np.zeros((deriv_basis.shape[0], deriv_basis.shape[1] - 1))

    for col_index in range(out.shape[1] - 1):
        scale = spline.t[col_index + spline.k + 1] - spline.t[col_index + 1]
        if scale != 0:
            out[:, col_index] = -deriv_basis[:, col_index + 1] / scale

    for col_index in range(1, out.shape[1]):
        scale = spline.t[col_index + spline.k] - spline.t[col_index]
        if scale != 0:
            out[:, col_index] += deriv_basis[:, col_index] / scale

    return float(spline.k) * out


def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray:  # type: ignore[no-any-unimported]
    if deriv == 0:
        return spline.design_matrix(x=x, t=spline.t, k=spline.k).todense()
    elif spline.k <= 0:
        return np.zeros((x.shape[0], spline.t.shape[0] - spline.k - 1))

    return _b_spline_deriv_inner(
        spline=spline,
        deriv_basis=_b_spline_eval(
            BSpline(t=spline.t, k=spline.k - 1, c=np.zeros(spline.c.shape[0] + 1)), x=x, deriv=deriv - 1
        ),
    )


# knots, order, and deriv are non-differentiable (static) arguments; only x carries gradients
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    return _b_spline_eval(spline=BSpline(t=knots, k=order, c=np.zeros((order + knots.shape[0] - 1))), x=x, deriv=deriv)[
        :, 1:
    ]


def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    # Return (primal output, residual); the residual is the next derivative w.r.t. x.
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    return (
        _b_spline_eval(spline=spline, x=x, deriv=deriv)[:, 1:],
        _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)[:, 1:],
    )


def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    # Chain rule: x[i] only affects row i of the basis, so contract over the columns.
    return (jax.numpy.sum(partials * grad, axis=1),)


b_spline_basis.defvjp(b_spline_basis_fwd, b_spline_basis_bwd)

if __name__ == "__main__":
    # tests

    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            return jax.numpy.sum(jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=deriv, x=x), weights))  # type: ignore[no-any-return]

        assert np.allclose(test_func(x), np.sum(np.dot(basis, weights)))
        assert np.allclose(jax.grad(test_func)(x), np.dot(partials, weights))

    deriv0 = np.transpose(
        np.array(
            [
                0.684,
                0.166666666666667,
                0.00133333333333333,
                0.096,
                0.444444444444444,
                0.0355555555555555,
                0.004,
                0.351851851851852,
                0.312148148148148,
                0,
                0.037037037037037,
                0.650962962962963,
            ]
        ).reshape(-1, 3)
    )

    deriv1 = np.transpose(
        np.array(
            [
                2.52,
                -1,
                -0.04,
                1.68,
                -0.666666666666667,
                -0.666666666666667,
                0.12,
                1.22222222222222,
                -2.29777777777778,
                0,
                0.444444444444444,
                3.00444444444444,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv0, deriv1, deriv=0)

    deriv2 = np.transpose(
        np.array(
            [
                -69.6,
                4,
                0.8,
                9.6,
                -5.33333333333333,
                5.33333333333333,
                2.4,
                -2.22222222222222,
                -15.3777777777778,
                0,
                3.55555555555556,
                9.24444444444445,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv1, deriv2, deriv=1)

    deriv3 = np.transpose(
        np.array(
            [
                504,
                -8,
                -8,
                -144,
                26.6666666666667,
                26.6666666666667,
                24,
                -32.8888888888889,
                -32.8888888888889,
                0,
                14.2222222222222,
                14.2222222222222,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv2, deriv3, deriv=2)
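
For reference, the failure under JIT can be reproduced directly from the example above. Under jax.jit the input x becomes an abstract tracer, and the scipy call inside b_spline_basis tries to convert it to a concrete numpy array. A minimal sketch, reusing knots, order, x, and b_spline_basis from above (the exact exception type may vary across JAX versions):

def weighted_sum(x: doubleArray) -> doubleArray:
    weights = jax.numpy.arange(1, 5)  # 4 basis columns for these knots and this order
    return jax.numpy.sum(jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=0, x=x), weights))

jax.grad(weighted_sum)(x)  # works: without JIT, the custom_vjp rules see concrete arrays
try:
    jax.jit(jax.grad(weighted_sum))(x)  # fails: scipy receives an abstract tracer
except jax.errors.TracerArrayConversionError as err:
    print("JIT fails:", err)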

Best Answer

The best way to achieve this is probably a combination of custom_jvp and jax.pure_callback.

Unfortunately, pure_callback is relatively new and does not have great documentation yet, but you can find examples of its use in the JAX user forums (e.g. here).

Copied here for posterity, here is an example of computing the sine and cosine via numpy callbacks in jit-compatible code, with automatic differentiation handled via custom JVP rules.

import jax
import numpy as np
jax.config.update('jax_enable_x64', True)

@jax.custom_jvp
def np_sin(x):
  # Compute the sine by calling-back to np.sin on the host.
  return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(np.shape(x), np.float64), x)

@np_sin.defjvp
def _np_sin_jvp(primals, tangents):
  x, = primals
  dx, = tangents
  return np_sin(x), np_cos(x) * dx  # d sin(x) = cos(x) dx

@jax.custom_jvp
def np_cos(x):
  # Compute the cosine by calling-back to np.cos on the host.
  return jax.pure_callback(np.cos, jax.ShapeDtypeStruct(np.shape(x), np.float64), x)

@np_cos.defjvp
def _np_cos_jvp(primals, tangents):
  x, = primals
  dx, = tangents
  return np_cos(x), -np_sin(x) * dx  # d cos(x) = -sin(x) dx


print(np_sin(1.0))
# 0.8414709848078965
print(np_cos(1.0))
# 0.5403023058681398
print(jax.jit(jax.grad(np_sin))(1.0))
# 0.5403023058681398
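
Because the two JVP rules are defined in terms of each other, higher-order derivatives compose under jit as well; a quick check (the expected value is -sin(1.0)):

print(jax.jit(jax.grad(jax.grad(np_sin)))(1.0))
# -0.8414709848078965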

Note that because pure_callback works by sending data back to the host, it will generally incur a lot of overhead on accelerators such as GPUs and TPUs, although in a single-CPU setting this approach can perform well.
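
Applied to the B-spline code from the question, the same pattern amounts to routing the scipy evaluation through jax.pure_callback while keeping the custom derivative rules. The following is an untested sketch, not a definitive implementation: the *_cb names are hypothetical, it keeps the question's custom_vjp (rather than custom_jvp), assumes the _b_spline_eval helper defined above, and assumes jax_enable_x64 so outputs are float64. With clamped knots of length m and order k, the basis after dropping the first column has m - k - 2 columns, which fixes the callback's output shape.

from functools import partial

from scipy.interpolate import BSpline
import jax
import numpy as np


@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis_cb(knots, order, deriv, x):
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    # m - k - 2 basis columns after dropping the first one; x.shape is static under jit.
    out_shape = jax.ShapeDtypeStruct((x.shape[0], knots.shape[0] - order - 2), np.float64)
    # The callback receives a concrete numpy array even inside jit-compiled code.
    return jax.pure_callback(
        lambda xv: np.asarray(_b_spline_eval(spline=spline, x=np.asarray(xv), deriv=deriv))[:, 1:],
        out_shape,
        x,
    )


def b_spline_basis_cb_fwd(knots, order, deriv, x):
    # Residual: the basis differentiated once more w.r.t. x, used by the backward pass.
    return b_spline_basis_cb(knots, order, deriv, x), b_spline_basis_cb(knots, order, deriv + 1, x)


def b_spline_basis_cb_bwd(knots, order, deriv, partials, grad):
    return (jax.numpy.sum(partials * grad, axis=1),)


b_spline_basis_cb.defvjp(b_spline_basis_cb_fwd, b_spline_basis_cb_bwd)

# Usage sketch: jax.jit(jax.grad(lambda x: jax.numpy.sum(b_spline_basis_cb(knots, 3, 0, x))))(x)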

Regarding python - JAX with JIT and custom differentiation, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/74719636/
