python - 使用 numba 索引 numpy 数组时出现类型错误

标签 python numba

我需要根据另一个包含类成员资格信息的数组(标签)对一维 numpy 数组(如下:data)中的元素求和>)。我在下面的代码中使用 numba 来加快速度。但是,如果我没有在 ret[int(find(labels, g))] += y 行中显式使用 int() 进行转换,我会收到一条错误消息:

类型错误:不支持的数组索引类型?int64

是否有比显式转换更好的解决方法?

import numpy as np
from numba import jit

labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)

@jit(nopython=True)
def find(seq, value):
    for ct, x in enumerate(seq):
        if x == value:
            return ct

@jit(nopython=True)
def subsumNumba(data, groups, labels):
    ret = np.zeros(len(labels))
    for y, g in zip(data, groups):
        # not working without casting with int()
        ret[int(find(labels, g))] += y
    return ret

最佳答案

问题是 find 可以返回一个 intNone 如果它没有找到任何东西,因此我认为 >?int64 错误。为了避免强制转换,当 find 退出而没有找到所需的值时,您需要提供一个 int 返回值,然后在调用者中处理它。

关于python - 使用 numba 索引 numpy 数组时出现类型错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39316939/

相关文章:

Python 类型错误 : expected string or buffer

python - SQLAlchemy 连接挂起

python - 如何修复 vim 以正确缩进包含 Python 注释行的折叠?

python - 过滤二维 numpy 数组的最快方法

python - Numba - 字符串类型

python - 使用 Numba 对每一行应用多个函数

python - Numba 通过影响就地损坏数据

Python - 在分组后将行转换为列并为不匹配的行填充零

python - 如何在决策树中对数据集的连续变量列进行类别划分?

python - 这段代码有什么问题? Module 类型的未知属性 'array'(<module 'numpy' from filename __init__.py'>