Based on the explanation provided here [1], I am trying to apply the same idea to speed up the following integral:
import numpy as np
import scipy.integrate as si
from scipy.optimize import fsolve

def integrand(t, *args):
    a = args[0]
    # c depends only on a: root of a*x**2 - exp(-x**2/a) = 0, starting from x0 = 1
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0])
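A side observation, not part of the original question: c depends only on a, not on t, so the integrand above re-solves the same equation at every point where quad evaluates it. Assuming c is the positive root that fsolve finds from the initial guess 1, the substitution s = t/(ac) gives a closed form that can serve as a reference value when checking any accelerated version:

```latex
a c^2 = e^{-c^2/a}, \qquad c > 0

I(a) = \int_0^1 c\, e^{-(t/(ac))^2}\, dt
     \;\overset{s = t/(ac)}{=}\; a c^2 \int_0^{1/(ac)} e^{-s^2}\, ds
     = \frac{\sqrt{\pi}}{2}\, a c^2\, \operatorname{erf}\!\left(\frac{1}{ac}\right)
```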
Following the reference above, I tried numba/jit and modified the previous block as follows:
import numpy as np
import scipy.integrate as si
from scipy.optimize import fsolve
import numba
from numba import cfunc
from numba.types import intc, CPointer, float64
from scipy import LowLevelCallable

def jit_integrand_function(integrand_function):
    jitted_function = numba.jit(integrand_function, nopython=True)

    # quad calls a LowLevelCallable as double f(int n, double *xx):
    # xx[0] is t, xx[1] is the first extra argument (a)
    @cfunc(float64(intc, CPointer(float64)))
    def wrapped(n, xx):
        return jitted_function(xx[0], xx[1])
    return LowLevelCallable(wrapped.ctypes)

@jit_integrand_function
def integrand(t, *args):
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

do_integrate(integrand, 2.)
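The cfunc wrapper in jit_integrand_function exists because quad accepts a LowLevelCallable with the C signature `double f(int n, double *xx)`, where `xx[0]` is the integration variable and `xx[1]` onward are the values passed through `args`. As a stdlib-only illustration of that calling convention (not the numba mechanism itself), a ctypes callback with the same signature can be built and called by hand. Passing `c` as a second extra argument at `xx[2]` is a hypothetical addition for illustration. Note that a ctypes callback still executes Python code on every call, so it would not give the speedup a numba cfunc does.

```python
import ctypes
import math

# The C signature quad accepts for a LowLevelCallable: double f(int n, double *xx)
QuadCallback = ctypes.CFUNCTYPE(
    ctypes.c_double, ctypes.c_int, ctypes.POINTER(ctypes.c_double)
)


def _integrand(n, xx):
    # n is the length of xx; unpack the flat double array like `wrapped` does
    t = xx[0]  # integration variable
    a = xx[1]  # first value from args=(...)
    c = xx[2]  # hypothetical: c passed in as a second extra argument
    return c * math.exp(-((t / (a * c)) ** 2))


# Keep a reference to the callback object so it is not garbage-collected
integrand_cb = QuadCallback(_integrand)

# Simulate one evaluation by quad: pack t first, then the extra arguments
xx = (ctypes.c_double * 3)(0.5, 2.0, 0.64)
print(integrand_cb(3, xx))
```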
However, this implementation gives me the following error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'a' in a function that will escape.

File "<ipython-input-16-3d98286a4be7>", line 20:
def integrand(t, *args):
    <source elided>
    a = args[0]
    c = fsolve(lambda x: a * x**2 - np.exp(-x**2 / a), 1)[0]
    ^

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)

During: resolving callee type: type(CPUDispatcher(<function integrand at 0x11a949d08>))
During: typing of call at <ipython-input-16-3d98286a4be7> (14)
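The error comes from calling fsolve from scipy.optimize inside the integrand: the lambda passed to fsolve captures the non-constant variable a, which numba's nopython mode refuses to do for a function that escapes (is handed to another function), and fsolve itself is a Python-level SciPy routine that numba cannot compile anyway.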
I would like to know whether there is a way around this error, and whether scipy.optimize.fsolve can be used with numba in this context at all.
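If fsolve itself is not required, one way around both problems is a hand-written root finder that receives a as an explicit argument instead of capturing it in a lambda. The sketch below swaps in plain Newton's method for fsolve's MINPACK algorithm; `residual`, `residual_dx` and `solve_c` are hypothetical names. It is plain Python so it runs without numba; the untested assumption is that decorating each function with `numba.njit` would let it compile in nopython mode, since it uses only scalar floats, a loop and `math.exp`.

```python
import math


def residual(x, a):
    # Same equation the question hands to fsolve: a*x**2 - exp(-x**2/a) = 0
    return a * x * x - math.exp(-x * x / a)


def residual_dx(x, a):
    # Analytic derivative: 2*a*x + (2*x/a) * exp(-x**2/a)
    return 2.0 * a * x + (2.0 * x / a) * math.exp(-x * x / a)


def solve_c(a, x0=1.0, tol=1e-12, max_iter=50):
    # Newton iteration; `a` is an explicit argument, so nothing is captured
    x = x0
    for _ in range(max_iter):
        step = residual(x, a) / residual_dx(x, a)
        x -= step
        if abs(step) < tol:
            break
    return x


def integrand(t, a):
    # c does not depend on t; it is recomputed here only to mirror the question
    c = solve_c(a)
    return c * math.exp(-((t / (a * c)) ** 2))


c = solve_c(2.0)
print(c, residual(c, 2.0))
```

Because c depends only on a, a faster variant would call `solve_c` once outside the integrand and pass c through `args`.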
Best answer
I wrote a small Python wrapper for MINPACK, called NumbaMinpack, that can be called from within numba-compiled functions: https://github.com/Nicholaswogan/NumbaMinpack. You can use it to @njit the integrand:
import scipy.integrate as si
from NumbaMinpack import hybrd, minpack_sig
from numba import njit, cfunc
import numpy as np

# MINPACK-style residual: writes f(x) into fvec instead of returning it;
# the parameter a arrives through the args array, so no closure is needed
@cfunc(minpack_sig)
def f(x, fvec, args):
    a = args[0]
    fvec[0] = a * x[0]**2.0 - np.exp(-x[0]**2.0 / a)

funcptr = f.address  # pointer to function

@njit
def integrand(t, *args):
    a = args[0]
    args_ = np.array(args)
    x_init = np.array([1.0])
    sol = hybrd(funcptr, x_init, args_)
    c = sol[0][0]  # sol[0] is the solution array; take its single entry
    return c * np.exp(- (t / (a * c))**2)

def do_integrate(func, a):
    return si.quad(func, 0, 1, args=(a,))

print(do_integrate(integrand, 2.)[0])
On my machine the code above runs in 87 µs, versus 2920 µs for the pure Python version.
Source: Stack Overflow question "Is there a way to use scipy.optimize.fsolve with jit_integrand_function and scipy.integrate.quad?" (https://stackoverflow.com/questions/62924588/)