python - Numba 中的生成器参数

标签 python generator numba

此问题的跟进:function types in numba .

我正在编写一个需要将生成器作为其参数之一的函数。粘贴到这里太复杂了,所以考虑这个玩具示例:

def take_and_sum(gen):
    @numba.jit(nopython=False)
    def inner(n):
        s = 0
        for _ in range(n):
            s += next(gen)
        return s
    return inner

它返回生成器的前 n 个元素的总和。用法示例:

@numba.njit()
def odd_numbers():
    n = 1
    while True:
        yield n
        n += 2

take_and_sum(odd_numbers())(3) # prints 9

它是柯里化(Currying)的,因为我想用 nopython=True 编译,然后我不能将 gen(一个 pyobject)作为一个争论。不幸的是,使用 nopython=True 我得到一个错误:

TypingError: Failed at nopython (nopython frontend)
Untyped global name 'gen'

即使我 nopython 编译了我的生成器。

真正令人困惑的是硬编码输入有效:

def take_and_sum():
    @numba.njit()
    def inner(n):
        gen = odd_numbers()
        s = 0.0
        for _ in range(n):
            s += next(gen)
        return s
    return inner

take_and_sum()(3)

我还尝试将我的生成器变成一个类:

@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1
    def next(self):
        n = self.n
        self.n += 2
        return n

同样,这在对象模式下有效,但在 nopython 模式下我得到了不可搜索的:

LoweringError: Failed at nopython (nopython mode backend)
Internal error:
NotImplementedError: instance.jitclass.Odd#4aa9758<n:uint64> as constant unsupported

最佳答案

我实际上无法解决您的问题,因为据我所知根本不可能。我只是强调一些方面(适用于 numba 0.30):

不能创建一个 numba-jitclass 生成器:

import numba

@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1

    def __iter__(self):
        return self

    def __next__(self):
        n = self.n
        self.n += 2
        return n

试试看:

>>> next(Odd())
TypeError: 'Odd' object is not an iterator

当您删除 numba.jitclass 时,它会起作用:

>>> next(Odd())
1

您使用硬编码生成器的示例并不等效。您最初的尝试创建了一个生成器对象,将其传递给 numba 函数并修改了生成器。您可能希望它更新生成器的状态

>>> t = odd_numbers()
>>> take_and_sum(t)(3)
9
>>> next(t)   # State has been updated, unfortunatly that requires nopython=False!
7

但这对于 numba 来说根本不可能(目前)。

第二个示例不同,因为每次调用函数时都会创建生成器,因此函数外部没有需要更新的状态:

>>> take_and_sum()(3) # using your hardcoded version
9.0
>>> take_and_sum()(3) # no updated state so this returns the same:
9.0

绝对可以更改它,但不能选择使用任意函数:

@numba.jitclass({'n': numba.uint})
class Odd:
    def __init__(self):
        self.n = 1

    def calculate(self, n):
        s = 0.0
        for _ in range(n):
            s += self.n
            self.n += 2
        return s

>>> x = Odd()
>>> x.calculate(3)
9.0
>>> x.calculate(3)
27.0

我知道这不是你想要的,但至少它在某种程度上是有效的:-)

关于python - Numba 中的生成器参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41860989/

相关文章:

python - 使用 flow_from_directory 将图像增强拟合到训练数据

python - 应用窗函数后的 FFT 意外频移

python - 通量变异性分析仅适用于隔间之间的传输 react ?

javascript - 如何使用 Redux Saga 测试 API 请求失败?

javascript - 如何返回到随机文本生成器函数的开头?

python - 混合数值和分类数据观测值之间成对距离计算的有效实现

python - NumPy 数组下三角区域中 n 个最大值的索引

javascript - 在 JS 中的生成器上调用 join()

performance - 使用条件语句加速 Python 嵌套循环