python - 使用内联其他函数编译带有函数的 Numba 模块时出错

标签 python numpy numba

Numba documentation specifies that other compiled functions can be inlined and called from other compiled functions.编译时这似乎不是真的 ahead of time .

例如:这里有两个函数计算两个向量数组之间的内点积,其中一个计算实际乘积,另一个在循环内进行内联调用:

# Module test.py
import numpy as np
from numba import njit, float64

@njit(float64(float64[:], float64[:]))
def product(a, b):
    prod = 0
    for i in range(a.size):
        prod += a[i] * b[i]
    return prod

@njit(float64[:](float64[:,:], float64[:,:]))
def n_inner1d(a, b):
    prod = np.empty(a.shape[0])    
    for i in range(a.shape[0]):
        prod[i] = product(a[i], b[i])

    return prod

照原样,我可以执行 import test 并使用 test.n_inner1d 非常好。现在让我们做一些修改,以便可以将其编译为 .pyd

# Module test.py
import numpy as np
from numba import float64
from numba.pycc import CC

cc = CC('test')
cc.verbose = True

@cc.export('product','float64(float64[:], float64[:])')
def product(a, b):
    prod = 0
    for i in range(a.size):
        prod += a[i] * b[i]
    return prod

@cc.export('n_inner1d','float64[:](float64[:,:], float64[:,:])')
def n_inner1d(a, b):
    prod = np.empty(a.shape[0])    
    for i in range(a.shape[0]):
        prod[i] = product(a[i], b[i])

    return prod

if __name__ == "__main__":
    cc.compile()

尝试编译时,出现以下错误:

# python test.py
Failed at nopython (nopython frontend)
Untyped global name 'product': cannot determine Numba type of <type 'function'>
File "test.py", line 20

问题

对于编译的模块 ahead of time , 内部定义的函数是否可以相互调用并内联使用?

最佳答案

我联系了 numba 开发人员,他们友好地回答说,在 @cc.export 之后添加 @njit 装饰器将使函数调用类型解析工作并解析。

例如:

@cc.export('product','float64(float64[:], float64[:])')
@njit
def product(a, b):
    prod = 0
    for i in range(a.size):
        prod += a[i] * b[i]
    return prod

将使其他人可以使用product 功能。需要注意的是,在某些情况下,内联函数完全有可能以与声明的 AOT 不同的类型签名结束。

关于python - 使用内联其他函数编译带有函数的 Numba 模块时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49326937/

相关文章:

Python 求和并组织 csv 数据

python - 使 while 解析函数递归

python - NATS WEBSOCKET Python WebSocket 连接

python - 如何使用 Python 的 ast 创建类型化返回? - `def f() -> bool: return True`

python - 使用 scipy.signal.lti 从状态矩阵在 Python 中创建 LTI 系统

python - 生成二维数组的梯度图

python - Numba 通过影响就地损坏数据

python - 遍历 3D 数组时 Numba 降低错误

python - 值错误 : Failed to convert a NumPy array to a Tensor (Unsupported object type numpy. ndarray)。试图预测特斯拉股票

python - numba 中的 jit 和 autojit 有什么区别?