python - Numba 函数与类型参数的使用无效

标签 python numpy numba

我正在使用 Numba 非 python 模式和一些 NumPy 函数。

@njit
def invert(W, copy=True):
    '''
    Inverts elementwise the weights in an input connection matrix.
    In other words, change the from the matrix of internode strengths to the
    matrix of internode distances.

    If copy is not set, this function will *modify W in place.*

    Parameters
    ----------
    W : np.ndarray
        weighted connectivity matrix
    copy : bool

    Returns
    -------
    W : np.ndarray
        inverted connectivity matrix
    '''

    if copy:
        W = W.copy()
    E = np.where(W)
    W[E] = 1. / W[E]
    return W

在此函数中,W 是一个矩阵。但我收到以下错误。它可能与W[E] = 1./W[E]行有关。

File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)
  File "/Users/xxx/anaconda3/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
    raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 2d, A), tuple(array(int64, 1d, C) x 2))

那么使用 NumPy 和 Numba 的正确方法是什么?我知道 NumPy 在矩阵计算方面做得很好。在这种情况下,NumPy 是否足够快以至于 Numba 不再提供加速?

最佳答案

正如 FBruzzesi 在评论中提到的,代码无法编译的原因是您使用了“花式索引”,因为 W[E] 中的 Enp.where 的输出,是一个数组元组。 (这解释了有点神秘的错误消息:Numba 不知道如何使用 getitem,即,当输入之一是元组时,它不知道如何在括号中查找某些内容。)

Numba actually supports fancy indexing (also called "advanced indexing") on a single dimension ,只是不是多维。在您的情况下,这允许进行简单的修改:首先使用 ravel 几乎无成本地将数组变为一维,然后应用转换,然后进行廉价的 reshape 。

@njit
def invert2(W, copy=True):
    if copy:
        W = W.copy()
    Z = W.ravel()
    E = np.where(Z)
    Z[E] = 1. / Z[E]
    return Z.reshape(W.shape)

但这仍然比需要的慢,因为计算通过不必要的中间数组传递,而不是在遇到非零值时立即修改数组。简单地执行循环会更快:

@njit 
def invert3(W, copy=True): 
    if copy: 
        W = W.copy() 
    Z = W.ravel() 
    for i in range(len(Z)): 
        if Z[i] != 0: 
            Z[i] = 1/Z[i] 
    return Z.reshape(W.shape) 

无论 W 的尺寸如何,此代码都有效。如果我们知道 W 是二维的,那么我们可以直接迭代这两个维度,但由于两者具有相似的性能,我将采用更通用的路线。

在我的计算机上,计时,假设有一个 300×300 数组 W,其中大约一半的条目是 0,并且其中 invert 是您的原始函数,没有Numba 编译,有:

In [80]: %timeit invert(W)                                                                   
2.67 ms ± 49.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [81]: %timeit invert2(W)                                                                  
519 µs ± 24.5 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [82]: %timeit invert3(W)                                                                  
186 µs ± 11.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

因此,Numba 为我们提供了相当大的加速(在已经运行一次以消除编译时间之后),特别是在以 Numba 可以利用的高效循环风格重写代码之后。

关于python - Numba 函数与类型参数的使用无效,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61885520/

相关文章:

python - 在 Python 中更正用于 re.sub 替换的正则表达式语法

python - Cython 并行计算中可变数量的 worker

python-3.x - 用 NAN 逐行替换 pandas 数据帧中的最后 2 个数值

python - Numba 中的生成器参数

python - python 中的 numba CUDA 非常慢

python - 与 CPython 相比,Numba 和 Cython 没有显着提高性能,也许我使用不正确?

python - ndb 有没有 list 属性

Python Numpy 数组相等失败

python - 查询集上的函数出现类型错误

python - skimage.io.imread 与 cv2.imread