python - 努巴。如何使用元组轴参数编写 np.sum ?

标签 python numpy numba

此代码:

@jit(nopython=True)
def foo(x):
    return x.sum(axis=(1,2))

x=np.linspace(0,1)
x=x.reshape(5,5,-1)
print(foo(x))

返回此错误:

NotImplementedError: No definition for lowering array.sum(array(float64, 3d, C), Tuple(Literal[int](1), Literal[int](2))) -> array(float64, 2d, C)

axis 参数只是一个整数,而不是整数元组( https://numpy.org/doc/stable/reference/generated/numpy.ndarray.sum.html#numpy.ndarray.sum )时,numba 似乎支持 np.sum 。 因此,我使用此解决方法 return x.sum(axis=1).sum(axis=1) 但如果我考虑代码优化,这不是一个好的解决方案。

是否存在其他解决方案或者我必须等待 future 的 numba 版本?

最佳答案

只需将数组重新调整为比所需结果多一维即可。在您的示例中:

x.sum(axis=(1,2))

可以替换为:

x.reshape(5,-1).sum(axis=1)

它产生相同的结果并且可以由 Numba 执行。

关于python - 努巴。如何使用元组轴参数编写 np.sum ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70410031/

相关文章:

python - 重命名 pandas Dataframe 列及其下的数据

opencv - 如何找到图像像素值的众数(统计数据)?

python - numba @jit 比纯 python 慢?

python - Django Python 与 Gspread : 'choices' must be an iterable containing (actual value, 人类可读名称)元组

python - 如何在web2py中创建用户定义的类

python - 我的 python opencv 绑定(bind)中似乎缺少 Stitcher API

python - Numpy 矩阵减法混淆

python - 使用 Numpy 进行外减法

python - 为什么 numba 不提高我的背包功能的速度?

python - 列表理解用 0 替换数组中的前 n 项