I am using JAX through numpyro. In particular, I want to use B-splines (e.g. as implemented in scipy.interpolate.BSpline) to transform various points into a spline, where the input depends on some of the parameters of the model. Thus, I need to be able to differentiate the B-splines in JAX (only in the input argument, not in the knots or in the integer order, of course!).
I can easily use jax.custom_vjp, but not when JIT is used, as it is in numpyro (a minimal sketch of the failure follows the tests below). I have looked at the following:
- https://github.com/google/jax/issues/1142
- https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
It seems like the best hope is to use callbacks, but I cannot quite figure out how that would work. The TensorFlow example with reverse-mode autodiff support at https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-jax-function-on-another-device-with-reverse-mode-autodiff-support does not seem to use JIT.
Example
Here is Python code that works without JIT (see the b_spline_basis() function):
from scipy.interpolate import BSpline
import numpy as np
from numpy import typing as npt
from functools import partial
import jax
doubleArray = npt.NDArray[np.double]
# see
# https://stackoverflow.com/q/74699053/5861244
# https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray:  # type: ignore[no-any-unimported]
    out = np.zeros((deriv_basis.shape[0], deriv_basis.shape[1] - 1))

    for col_index in range(out.shape[1] - 1):
        scale = spline.t[col_index + spline.k + 1] - spline.t[col_index + 1]
        if scale != 0:
            out[:, col_index] = -deriv_basis[:, col_index + 1] / scale

    for col_index in range(1, out.shape[1]):
        scale = spline.t[col_index + spline.k] - spline.t[col_index]
        if scale != 0:
            out[:, col_index] += deriv_basis[:, col_index] / scale

    return float(spline.k) * out
def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray:  # type: ignore[no-any-unimported]
    if deriv == 0:
        return spline.design_matrix(x=x, t=spline.t, k=spline.k).todense()
    elif spline.k <= 0:
        return np.zeros((x.shape[0], spline.t.shape[0] - spline.k - 1))
    return _b_spline_deriv_inner(
        spline=spline,
        deriv_basis=_b_spline_eval(
            BSpline(t=spline.t, k=spline.k - 1, c=np.zeros(spline.c.shape[0] + 1)), x=x, deriv=deriv - 1
        ),
    )
@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    return _b_spline_eval(spline=BSpline(t=knots, k=order, c=np.zeros((order + knots.shape[0] - 1))), x=x, deriv=deriv)[
        :, 1:
    ]
def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    # The residual is the next-higher derivative of each basis column, consumed by the backward pass.
    return (
        _b_spline_eval(spline=spline, x=x, deriv=deriv)[:, 1:],
        _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)[:, 1:],
    )
def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    return (jax.numpy.sum(partials * grad, axis=1),)

b_spline_basis.defvjp(b_spline_basis_fwd, b_spline_basis_bwd)
if __name__ == "__main__":
    # tests
    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            return jax.numpy.sum(jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=deriv, x=x), weights))  # type: ignore[no-any-return]

        assert np.allclose(test_func(x), np.sum(np.dot(basis, weights)))
        assert np.allclose(jax.grad(test_func)(x), np.dot(partials, weights))
    deriv0 = np.transpose(
        np.array(
            [
                0.684, 0.166666666666667, 0.00133333333333333,
                0.096, 0.444444444444444, 0.0355555555555555,
                0.004, 0.351851851851852, 0.312148148148148,
                0, 0.037037037037037, 0.650962962962963,
            ]
        ).reshape(-1, 3)
    )
    deriv1 = np.transpose(
        np.array(
            [
                2.52, -1, -0.04,
                1.68, -0.666666666666667, -0.666666666666667,
                0.12, 1.22222222222222, -2.29777777777778,
                0, 0.444444444444444, 3.00444444444444,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv0, deriv1, deriv=0)

    deriv2 = np.transpose(
        np.array(
            [
                -69.6, 4, 0.8,
                9.6, -5.33333333333333, 5.33333333333333,
                2.4, -2.22222222222222, -15.3777777777778,
                0, 3.55555555555556, 9.24444444444445,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv1, deriv2, deriv=1)

    deriv3 = np.transpose(
        np.array(
            [
                504, -8, -8,
                -144, 26.6666666666667, 26.6666666666667,
                24, -32.8888888888889, -32.8888888888889,
                0, 14.2222222222222, 14.2222222222222,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv2, deriv3, deriv=2)
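For reference, here is a minimal sketch of the failure under JIT (the jitted_func name is mine, and the exact exception depends on the JAX version); the root cause is that scipy's BSpline receives an abstract tracer rather than a concrete array:

    # Minimal failure sketch: under jax.jit, x is an abstract tracer, and
    # BSpline.design_matrix cannot convert it to a concrete numpy array.
    def jitted_func(x: doubleArray) -> doubleArray:
        weights = jax.numpy.arange(1, 5)
        return jax.numpy.sum(jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=0, x=x), weights))

    # jax.jit(jitted_func)(x)  # raises TracerArrayConversionError (or similar, depending on the JAX version)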
Best answer
The best way to accomplish this is probably a combination of custom_jvp and jax.pure_callback.

Unfortunately, pure_callback is relatively new and does not yet have good documentation, but you can find examples of its use in the JAX user forums (e.g. here).

Copied here for posterity, below is an example of computing the sine and cosine via numpy callbacks in jit-compatible code, with custom JVP rules defined for autodiff.
import jax
import numpy as np
jax.config.update('jax_enable_x64', True)
@jax.custom_jvp
def np_sin(x):
    # Compute the sine by calling back to np.sin on the host.
    return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(np.shape(x), np.float64), x)

@np_sin.defjvp
def _np_sin_jvp(primals, tangents):
    x, = primals
    dx, = tangents
    return np_sin(x), np_cos(x) * dx  # d sin(x) = cos(x) dx

@jax.custom_jvp
def np_cos(x):
    # Compute the cosine by calling back to np.cos on the host.
    return jax.pure_callback(np.cos, jax.ShapeDtypeStruct(np.shape(x), np.float64), x)

@np_cos.defjvp
def _np_cos_jvp(primals, tangents):
    x, = primals
    dx, = tangents
    return np_cos(x), -np_sin(x) * dx  # d cos(x) = -sin(x) dx
print(np_sin(1.0))
# 0.8414709848078965
print(np_cos(1.0))
# 0.5403023058681398
print(jax.jit(jax.grad(np_sin))(1.0))
# 0.5403023058681398
Note that because pure_callback operates by sending data back to the host, it generally incurs significant overhead on accelerators like GPUs and TPUs, although in single-CPU settings this approach can perform well.
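Applying the same pattern to the b_spline_basis function from the question, an untested sketch (reusing np, BSpline, doubleArray, and _b_spline_eval from the question's code; it assumes x is a 1-D float64 array and that the knots, order, and derivative order stay static) might look like this:

from functools import partial

# np, BSpline, doubleArray, and _b_spline_eval are taken from the question's code above.

@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    # Declare the callback's output shape/dtype up front, then evaluate the scipy basis on the host.
    result_shape = jax.ShapeDtypeStruct((x.shape[0], knots.shape[0] - order - 2), np.float64)
    return jax.pure_callback(
        lambda xs: _b_spline_eval(
            spline=BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1)),
            x=np.asarray(xs),
            deriv=deriv,
        )[:, 1:],
        result_shape,
        x,
    )

@b_spline_basis.defjvp
def _b_spline_basis_jvp(knots, order, deriv, primals, tangents):
    (x,), (dx,) = primals, tangents
    # The tangent of each basis column is its next-higher-order derivative times dx.
    return b_spline_basis(knots, order, deriv, x), b_spline_basis(knots, order, deriv + 1, x) * dx[:, None]

As in the sine/cosine example, the JVP rule is defined in terms of the callback itself, so jax.jit and higher-order derivatives compose automatically; only x is differentiated, matching the question's requirements.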
Source: "python - JAX with JIT and custom differentiation" on Stack Overflow: https://stackoverflow.com/questions/74719636/