python - 您如何实现可从 Numba 调用的 C 以与 nquad 高效集成?

标签 python numpy scipy numba

我需要在 python 中进行 6D 数值积分。因为 scipy.integrate.nquad 函数很慢,我目前正在尝试通过将被积函数定义为带有 Numba 的 scipy.LowLevelCallable 来加快速度。

通过复制给定的示例,我能够使用 scipy.integrate.quad 在 1D 中执行此操作 here :

import numpy as np
from numba import cfunc
from scipy import integrate

def integrand(t):
    return np.exp(-t) / t**2

nb_integrand = cfunc("float64(float64)")(integrand)

# regular integration
%timeit integrate.quad(integrand, 1, np.inf)

10000 次循环,最好的 3 次:每次循环 128 微秒

# integration with compiled function
%timeit integrate.quad(nb_integrand.ctypes, 1, np.inf)

100000 次循环,最好的 3 次:每次循环 7.08 微秒

当我现在想用 nquad 执行此操作时,nquad 文档说:

If the user desires improved integration performance, then f may be a scipy.LowLevelCallable with one of the signatures:

double func(int n, double *xx)
double func(int n, double *xx, void *user_data)

where n is the number of extra parameters and args is an array of doubles of the additional parameters, the xx array contains the coordinates. The user_data is the data contained in the scipy.LowLevelCallable.

但是下面的代码给我一个错误:

from numba import cfunc
import ctypes

def func(n_arg,x):
    xe = x[0]
    xh = x[1]
    return np.sin(2*np.pi*xe)*np.sin(2*np.pi*xh)

nb_func = cfunc("float64(int64,CPointer(float64))")(func)

integrate.nquad(nb_func.ctypes, [[0,1],[0,1]], full_output=True)

错误:quad:第一个参数是签名不正确的 ctypes 函数指针

是否可以直接在代码中使用 numba 编译一个可以与 nquad 一起使用的函数,而无需在外部文件中定义该函数?

非常感谢您!

最佳答案

将函数包装在 scipy.LowLevelCallable 中让 nquad 快乐:

si.nquad(sp.LowLevelCallable(nb_func.ctypes), [[0,1],[0,1]], full_output=True)
# (-2.3958561404687756e-19, 7.002641250699693e-15, {'neval': 1323})

关于python - 您如何实现可从 Numba 调用的 C 以与 nquad 高效集成?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45823212/

相关文章:

python - 将数据拟合到所有可能的分布并返回最佳拟合

python - 获取 numpy 数组中元素对的总和

Python SciPy 使用 pip install scipy 给出错误

python - 为什么即使协方差是半正定的,bivariate_normal 也会返回 NaN?

python - 为什么 Ubuntu 上的 PhantomJS 会被 Google map 注册为触摸设备?

python - 具有多个主键的 SQLAlchemy 不会自动设置任何

python - 在列 block 中展平或分组数组 - NumPy/Python

numpy - 将图像格式从32FC1转换为16UC1

python - 遇到从 Dataflow 管道向 BigQuery 进行缓慢流式写入的问题?

python - 如果文件中的行等于用户输入