python - 加速将函数作为 numba 参数的函数

标签 python python-3.x numba

我正在尝试使用 numba 来加速将另一个函数作为参数的函数。一个最小的例子如下:

import numba as nb

def f(x):
    return x*x

@nb.jit(nopython=True)
def call_func(func,x):
    return func(x)

if __name__ == '__main__':
    print(call_func(f,5))

然而,这不起作用,因为显然 numba 不知道如何处理该函数参数。回溯很长:

Traceback (most recent call last):
  File "numba_function.py", line 15, in <module>
    print(call_func(f,5))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
    infer.propagate()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
    raise errors[0]
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
    constraint(typeinfer)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
    raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>

有办法解决这个问题吗?

最佳答案

这取决于您传递给call_funcfunc 是否可以在nopython 模式下编译。

如果它不能在 nopython 模式下编译,那么它是不可能的,因为 numba 不支持 nopython 函数内的 python 调用(这就是它被称为 nopython 的原因)。

然而,如果它可以在 nopython 模式下编译,您可以使用闭包:

import numba as nb

def f(x):
    return x*x

def call_func(func, x):
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner(x)

if __name__ == '__main__':
    print(call_func(f,5))

这种方法有一些明显的缺点,因为它需要在每次调用 call_func 时编译 funcinner。这意味着它只有在通过编译函数的加速大于编译成本时才可行。如果多次使用相同的函数调用 call_func,则可以减轻这种开销:

import numba as nb

def f(x):
    return x*x

def call_func(func):  # only take func
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner  # return the closure

if __name__ == '__main__':
    call_func_with_f = call_func(f)   # compile once
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version

只是一般说明:我不会创建带有函数参数的 numba 函数。如果您不能对函数进行硬编码,则 numba 无法生成真正快速的函数,并且如果您还包括闭包的编译成本,那基本上是不值得的。

关于python - 加速将函数作为 numba 参数的函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45976662/

相关文章:

python - wxPython - Ubuntu 中的空白框

python - numba @vectorize 目标 ='parallel' TypeError

python - CUDA-Python : How to launch CUDA kernel in Python (Numba 0. 25)?

python - 在pygame中射击子弹

python - Numba 并行代码比顺序代码慢

python - 如何计算元素 x[i+1] 和 x[i-1] 之间的差异?

python : write to COM in ISO 8859-1

python - 与 0.13.1 相比,Pandas 0.15 中的索引速度非常慢

python - 如何确定 struct.unpack 的格式(因为我没有用 Python 打包)?

python - 如何让敌人跟随玩家? pygame