python - Numpy aggregate function shims, typing, and np.sort() in Numba

Tags: python numpy sorting types numba

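I am using Numba (0.44) with Numpy in nopython mode. Currently, Numba does not support Numpy aggregate functions across an arbitrary axis; it only supports computing these aggregates over the whole array. Given that, I decided to have a go at writing some shims.

In code: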

np.min(array) # This works with Numba 0.44
np.min(array, axis = 0) # This does not work with Numba 0.44 (no axis argument allowed)

Here is an example shim intended to reproduce np.min(array):

import numpy as np
import numba

@numba.jit(nopython = True)
def npmin (X, axis = -1):
    """
    Shim for broadcastable np.min(). 
    Allows np.min(array), np.min(array, axis = 0), and np.min(array, axis = 1)
    Note that the argument axis = -1 computes on the entire array.
    """
    if axis == 0:
        _min = np.sort(X.transpose())[:,0]
    elif axis == 1:
        _min = np.sort(X)[:,0]
    else:
        _min = np.sort(np.sort(X)[:,0])[0]
    return _min

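Without Numba, the shim works as expected and generalizes the behavior of np.min() to 2D arrays. Note that I use axis = -1 as a way to compute over the entire array, similar to calling np.min(array) without an axis argument.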
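To make the "works without Numba" claim concrete, here is a small check of the undecorated shim against np.min on a 2D array. The name `npmin_py` and the sample array are illustrative assumptions, not part of the original post. It also shows that the type of the result depends on `axis`: an array for axis 0 or 1, a scalar otherwise.

```python
import numpy as np

def npmin_py(X, axis=-1):
    # Same body as the shim above, minus the Numba decorator
    # (npmin_py is a hypothetical name used only for this check).
    if axis == 0:
        _min = np.sort(X.transpose())[:, 0]
    elif axis == 1:
        _min = np.sort(X)[:, 0]
    else:
        _min = np.sort(np.sort(X)[:, 0])[0]
    return _min

a = np.array([[3, 1, 2],
              [0, 5, 4]])

r0 = npmin_py(a, axis=0)   # per-column minimum
r1 = npmin_py(a, axis=1)   # per-row minimum
rall = npmin_py(a)         # minimum of the whole array

# Values agree with NumPy's own reduction in plain Python.
print(np.array_equal(r0, np.min(a, axis=0)))
print(np.array_equal(r1, np.min(a, axis=1)))
print(rall == np.min(a))

# The return type is not the same for every axis value.
print(type(r0).__name__, type(rall).__name__)
```

In plain Python this mix of return types is harmless, but it is worth keeping in mind when the same function has to be compiled to a single signature.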

Unfortunately, as soon as I add Numba to the mix, I get an error. Here is the trace:

Traceback (most recent call last):
  File "shims.py", line 81, in <module>
    _min = npmin(a)
  File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 348, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 315, in error_rewrite
    reraise(type(e), e, None)
  File "/usr/local/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
    raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function sort at 0x10abd5ea0>) with argument(s) of type(s): (array(int64, 2d, F))
 * parameterized
In definition 0:
    All templates rejected
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function sort at 0x10abd5ea0>)
[2] During: typing of call at shims.py (27)


File "shims.py", line 27:
def npmin (X, axis = -1):
    <source elided>
    if axis == 0:
        _min = np.sort(X.transpose())[:,0]
        ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile

If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new

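I have verified that all of the functions I am using, with their respective arguments, are supported by Numba 0.44. The stack trace does point to my call to np.sort(array) as the problem, but I suspect this may be a typing issue, since the function can return either a scalar (with no axis argument) or an array (with an axis argument).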

With that said, I have a few questions:

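  • Is there a problem with my implementation? Can anyone pinpoint the unsupported feature I am using, as the stack trace suggests?
  • Or does this look like a bug in Numba?
  • More generally, are shims of this kind currently possible with Numba (0.44)?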

Best answer

Here is an alternative shim for 2D arrays:

@numba.jit(nopython=True)
def npmin2(X, axis=0):
    if axis == 0:
        _min = np.empty(X.shape[1])
        for i in range(X.shape[1]):
            _min[i] = np.min(X[:,i])
    elif axis == 1:
        _min = np.empty(X.shape[0])
        for i in range(X.shape[0]):
            _min[i] = np.min(X[i,:])

    return _min

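You will still have to work out a workaround for the axis=-1 case, though, since that would return a scalar while the other arguments return arrays, and Numba will not be able to "unify" the return type into one consistent type.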
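One possible workaround for the axis=-1 case, offered as a sketch rather than a definitive answer, is to make every branch return the same kind of value: a 1-D float64 array, with the whole-array minimum wrapped in a one-element array. The name `npmin_unified` is hypothetical, and the `try`/`except` fallback is an assumption added only so the snippet also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: fall back to a no-op decorator
    # so the example still runs as plain Python without Numba.
    def njit(func):
        return func

@njit
def npmin_unified(X, axis):
    # Every branch assigns a 1-D float64 array to `out`,
    # so there is a single return type for all axis values.
    if axis == 0:
        out = np.empty(X.shape[1])
        for i in range(X.shape[1]):
            out[i] = np.min(X[:, i])
    elif axis == 1:
        out = np.empty(X.shape[0])
        for i in range(X.shape[0]):
            out[i] = np.min(X[i, :])
    else:
        # Whole-array minimum, wrapped in a one-element array.
        out = np.empty(1)
        out[0] = np.min(X)
    return out

X = np.array([[3.0, 1.0, 2.0],
              [0.0, 5.0, 4.0]])

col_min = npmin_unified(X, 0)
row_min = npmin_unified(X, 1)
all_min = npmin_unified(X, -1)[0]  # unwrap the scalar on the caller's side
```

The trade-off is that the caller has to index `[0]` to get the scalar, and that integer input comes back as float64 because of `np.empty`. Another option is simply to keep a separate function for the whole-array case, which already works as plain np.min(array).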

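On my machine at least, performance seems roughly comparable to just calling the equivalent np.min: sometimes np.min is faster and sometimes npmin2 wins, depending on the input array size and the axis.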
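A comparison like this can be reproduced with a small timing harness along the following lines. This is only a sketch: the array size, the repeat count, and the no-op fallback decorator are assumptions, and the `elif axis == 1` branch is written as a plain `else` here for brevity. No timings are quoted because they depend on the machine, the array shape, and the Numba version. If Numba is not installed, the fallback times plain Python loops, which says nothing about the compiled shim.

```python
import timeit
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: no-op decorator when Numba is missing.
    def njit(func):
        return func

@njit
def npmin2(X, axis=0):
    # Loop-based shim adapted from the answer above.
    if axis == 0:
        _min = np.empty(X.shape[1])
        for i in range(X.shape[1]):
            _min[i] = np.min(X[:, i])
    else:
        _min = np.empty(X.shape[0])
        for i in range(X.shape[0]):
            _min[i] = np.min(X[i, :])
    return _min

np.random.seed(0)
X = np.random.rand(500, 400)  # illustrative size; vary it to see the effect

for axis in (0, 1):
    npmin2(X, axis)  # warm-up call so JIT compilation is not timed
    t_shim = timeit.timeit(lambda: npmin2(X, axis), number=20)
    t_numpy = timeit.timeit(lambda: np.min(X, axis=axis), number=20)
    print("axis", axis, "npmin2:", t_shim, "np.min:", t_numpy)
```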

Regarding python - Numpy aggregate function shims, typing, and np.sort() in Numba, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/56827566/
