python - 为什么这个函数在 JAX 和 numpy 中变慢?

标签 python performance numpy optimization jax

我有以下 numpy 函数,如下所示,我正在尝试使用 JAX 进行优化,但无论出于何种原因,它都变慢了。
有人可以指出我可以做些什么来提高这里的性能吗?我怀疑这与 Cg_new 发生的列表理解有关,但将其分开并不会在 JAX 中产生任何进一步的性能提升。

import numpy as np 

def testFunction_numpy(C, Mi, C_new, Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = np.zeros((1, len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)

    Wg_new = np.dot(invertCsensor_new, Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
这是 JAX 等效项:
import jax.numpy as jnp
import numpy as np
import jax

def testFunction_JAX(C, Mi, C_new, Mi_new):
    Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = jnp.zeros((1, len(Mi[0])))
    invertCsensor_new = jnp.linalg.inv(C_new)

    Wg_new = jnp.dot(invertCsensor_new, Mi_new)
    Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)

jitter = jax.jit(testFunction_JAX) 

%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop

最佳答案

当 JAX jit 编译遇到 Python 控制流(包括列表推导式)时,它会有效地展平循环并分阶段执行完整的操作序列。这会导致 jit 编译时间变慢和代码不理想。幸运的是,您的函数中的列表理解很容易用原生 numpy 广播来表达。此外,您还可以进行其他两项改进:

  • 无需转发申报Wg_newCg_new在计算它们之前
  • 计算时 dot(inv(A), B) ,使用 np.linalg.solve 更加高效和精确而不是显式计算逆。

  • 对 numpy 和 JAX 版本进行这三项改进会导致以下结果:
    def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
        Wg_new = np.linalg.solve(C_new, Mi_new)
        Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
        return C_new, Mi_new, Wg_new, Cg_new
    
    @jax.jit
    def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
        Wg_new = jnp.linalg.solve(C_new, Mi_new)
        Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
        return C_new, Mi_new, Wg_new, Cg_new
    
    %timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
    # 1000 loops, best of 3: 1.11 ms per loop
    %timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
    # 1000 loops, best of 3: 1.35 ms per loop
    
    由于改进的实现,这两个函数都比以前快了一点。但是,您会注意到,这里的 JAX 仍然比 numpy 慢;这有点在意料之中,因为对于这种简单程度的函数,JAX 和 numpy 都有效地生成了在 CPU 架构上执行的相同的短系列 BLAS 和 LAPACK 调用。与 numpy 的引用实现相比,根本没有太大的改进空间,而且对于如此小的数组,JAX 的开销是显而易见的。

    关于python - 为什么这个函数在 JAX 和 numpy 中变慢?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64517793/

    相关文章:

    python - 将 abc.abstractmethod 与其他装饰器相结合

    MySQL如何优化子查询性能

    python - 通过全批量训练将字母图像训练到神经网络

    Python - 在没有 python 解释器的情况下运行 numpy

    Python、redis : How do I set multiple key-value pairs at once

    python - 终止从已结束的进程启动的进程 - Python

    python - 在Python中读取带有空行的文本

    java - 检查对象是否为空的最快方法

    python - 如何加速 numpy 代码

    python - 将字符串列表转换为 pandas 中的 float 列表