scipy - 有没有办法将 scipy.optimize.fsolve 与 jit_integrand_function 和 scipy.integrate.quad 一起使用?

标签 scipy jit numba quad

基于此处提供的解释 1 ,我正在尝试使用相同的想法来加速以下积分:

import scipy.integrate as si
from scipy.optimize import root, fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2) 

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0]) 

根据前面的引用资料,我尝试使用numba/jit,并按照以下方式修改前面的 block :

import numpy as np
import scipy.integrate as si
from scipy.optimize import root
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def jit_integrand_function(integrand_function):
    jitted_function = numba.jit(integrand_function, nopython=True)  
    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1])
    return LowLevelCallable(wrapped.ctypes)

@jit_integrand_function
def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

do_integrate(integrand, 2.)

但是,这个实现给了我错误


TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'a' in a function that will escape.

File "<ipython-input-16-3d98286a4be7>", line 20:
def integrand(t, *args):
    <source elided>
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    ^

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)

错误是因为我在被积函数中使用了 scipy.optimize 中的 fsolve。

我想知道是否有解决此错误的方法,以及是否可以在此上下文中将 scipy.optimize.fsolve 与 numba 一起使用。

最佳答案

我为 Minpack 编写了一个小的 python 包装器,称为 NumbaMinpack,它可以在 numba 编译函数中调用:https://github.com/Nicholaswogan/NumbaMinpack .你可以用它来 @njit 被积函数:

import scipy.integrate as si
from NumbaMinpack import hybrd, minpack_sig
from numba import njit, cfunc
import numpy as np

@cfunc(minpack_sig)
def f(x, fvec, args):
    a = args[0]
    fvec[0] = a * x[0]**2.0 - np.exp(-x[0]**2.0 / a)

funcptr = f.address # pointer to function  

@njit
def integrand(t, *args):
    a = args[0]
    args_ = np.array(args)
    x_init = np.array([1.0])
    sol = hybrd(funcptr,x_init,args_)
    c = sol[0][0]
    return c * np.exp(- (t / (a * c))**2) 

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0]) 

在我的电脑上,上述代码耗时 87 µs,而纯 python 版本耗时 2920 µs

关于scipy - 有没有办法将 scipy.optimize.fsolve 与 jit_integrand_function 和 scipy.integrate.quad 一起使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62924588/

相关文章:

python - 处理 jax numpy 数组中的不同形状(jit 兼容)

matlab - (Matlab)奇怪的精度损失,同时将复杂矩阵分配给局部变量

rust - 如何解析基于 LLVM MCJIT 的 JIT 中的当前进程符号?

python - Numba - 如何并行填充二维数组

python - Numba 的 prange 给出了错误的结果

python - 导入 scipy.stats 错误

python - 从安装了新 ArcGIS 10.1 的 scipy- 新计算机导入统计信息时出错

Python Pandas : how to vectorize this function

python - 使用 scipy.optimize.curve_fit 执行加权线性拟合

python - scipy curve_fit 和局部最小值 : get to global minima as fast as possible