python - Can the types of numba locals be inferred or hinted?

Tags: python numba

I'm trying to use numba as a replacement for some of our Cython code.

Here:

@jit
def unique_in_sorted(arr):
    result = np.empty_like(arr)
    result[0] = arr[0]
    last_out = arr[0]
    count_out = 1
    for i in range(len(arr)):
        a = arr[i]
        if a<last_out:
            raise Exception("Input not sorted: {} .. {}".format(last_out, a))
        if last_out!=a:
            result[count_out] = a
            last_out = a

    return result.resize(count_out)
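As written, the loop never increments count_out. Also, ndarray.resize works in place and returns None, so `return result.resize(count_out)` does not return the array. A reference is useful for checking any compiled version, so below is a vectorised NumPy sketch of what the loop appears intended to compute: the first element of each run of equal values, plus the sortedness check. The name `unique_in_sorted_np` is hypothetical. This swaps the loop for array comparisons instead of reproducing it. For sorted input the result matches `np.unique(arr)`.

```python
import numpy as np

def unique_in_sorted_np(arr):
    """Hypothetical reference: first element of each run in a sorted 1-D array."""
    arr = np.asarray(arr)
    if arr.size == 0:
        return arr.copy()
    # Reject unsorted input, mirroring the check inside the loop version.
    if np.any(arr[1:] < arr[:-1]):
        raise ValueError("Input not sorted")
    # Keep index 0 plus every index whose value differs from its predecessor.
    keep = np.concatenate(([True], arr[1:] != arr[:-1]))
    return arr[keep]

print(unique_in_sorted_np(np.array([1, 1, 2, 3, 3, 3, 7])))  # [1 2 3 7]
```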

...and I find that the above runs at about the same speed as vanilla Python:

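def test_unique_in_sorted(self):
    # randint's upper bound is exclusive, so 201 gives the same 0..200 range
    # as the deprecated random_integers(0, 200, ...)
    unsorted = np.random.randint(0, 201, 1000000)
    a = np.sort(unsorted)
    for x in range(10):
        with timed("unique_in_sorted"):  # the asker's own timing helper
            r = unique_in_sorted(a)
    unique_in_sorted.inspect_types()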
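The `timed` helper used in the test is not shown. A minimal stand-in is sketched below. It assumes the helper only needs to report wall-clock time for the `with` block; yielding a dict with the elapsed time is an extra added here for convenience.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    """Hypothetical stand-in: print wall-clock seconds spent in the with-block."""
    record = {"label": label}
    start = time.perf_counter()
    try:
        yield record
    finally:
        record["elapsed"] = time.perf_counter() - start
        print("{}: {:.6f} s".format(label, record["elapsed"]))

with timed("demo") as rec:
    sum(range(100000))
```

With a lazily compiled `@jit` function, the first timed call also includes compilation time. That is one reason to time several runs, as the loop in the test does.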

Looking at the output of inspect_types, it seems that all the local variables are Python objects rather than typed C values. For example, the inspect_types output contains:

#   $0.12 = getitem(index=$const0.11, value=arr)  :: pyobject
#   last_out = $0.12  :: pyobject

last_out = arr[0]

This makes me think that last_out is being treated as a pyobject rather than as the int64 it holds from arr.

Is there a way to optimize the above so that it runs at a speed similar to an equivalent Cython implementation?

Here is the full output of inspect_types:

unique_in_sorted (array(int64, 1d, C),)
--------------------------------------------------------------------------------
# File: /home/user1/py/fast_ops.py
# --- LINE 142 --- 
# label 0
#   del $0.1
#   del $0.2
#   del $0.4
#   del $const0.6
#   del $const0.9
#   del $0.7
#   del $const0.11
#   del $0.12
#   del $const0.13

@jit

# --- LINE 143 --- 

def unique_in_sorted(arr):

    # --- LINE 144 --- 
    #   arr = arg(0, name=arr)  :: pyobject
    #   $0.1 = global(np: <module 'numpy' from '/usr/local/lib/python2.7/dist-packages/numpy/__init__.pyc'>)  :: pyobject
    #   $0.2 = getattr(attr=empty_like, value=$0.1)  :: pyobject
    #   $0.4 = call $0.2(arr, kws=[], args=[Var(arr, /home/user1/py/fast_ops.py (144))], func=$0.2, vararg=None)  :: pyobject
    #   result = $0.4  :: pyobject

    result = np.empty_like(arr)

    # --- LINE 145 --- 
    #   $const0.6 = const(int, 0)  :: pyobject
    #   $0.7 = getitem(index=$const0.6, value=arr)  :: pyobject
    #   $const0.9 = const(int, 0)  :: pyobject
    #   result[$const0.9] = $0.7  :: pyobject

    result[0] = arr[0]

    # --- LINE 146 --- 
    #   $const0.11 = const(int, 0)  :: pyobject
    #   $0.12 = getitem(index=$const0.11, value=arr)  :: pyobject
    #   last_out = $0.12  :: pyobject

    last_out = arr[0]

    # --- LINE 147 --- 
    #   $const0.13 = const(int, 1)  :: pyobject
    #   count_out = $const0.13  :: pyobject
    #   jump 45
    # label 45

    count_out = 1

    # --- LINE 148 --- 
    #   jump 48
    # label 48
    #   $48.1 = global(range: <built-in function range>)  :: pyobject
    #   $48.2 = global(len: <built-in function len>)  :: pyobject
    #   $48.4 = call $48.2(arr, kws=[], args=[Var(arr, /home/user1/py/fast_ops.py (144))], func=$48.2, vararg=None)  :: pyobject
    #   del $48.2
    #   $48.5 = call $48.1($48.4, kws=[], args=[Var($48.4, /home/user1/py/fast_ops.py (148))], func=$48.1, vararg=None)  :: pyobject
    #   del $48.4
    #   del $48.1
    #   $48.6 = getiter(value=$48.5)  :: pyobject
    #   del $48.5
    #   $phi64.1 = $48.6  :: pyobject
    #   del $48.6
    #   jump 64
    # label 64
    #   $64.2 = iternext(value=$phi64.1)  :: pyobject
    #   $64.3 = pair_first(value=$64.2)  :: pyobject
    #   $64.4 = pair_second(value=$64.2)  :: pyobject
    #   del $64.2
    #   $phi153.2 = $phi64.1  :: pyobject
    #   del $phi153.2
    #   $phi153.1 = $64.3  :: pyobject
    #   del $phi153.1
    #   $phi67.1 = $64.3  :: pyobject
    #   del $64.3
    #   branch $64.4, 67, 153
    # label 67
    #   del $64.4
    #   i = $phi67.1  :: pyobject
    #   del $phi67.1
    #   del i
    #   del $67.4
    # label 155
    #   del a
    #   del $119.3
    #   jump 64

    for i in range(len(arr)):

        # --- LINE 149 --- 
        #   $67.4 = getitem(index=i, value=arr)  :: pyobject
        #   a = $67.4  :: pyobject

        a = arr[i]

        # --- LINE 150 --- 
        #   $67.7 = a < last_out  :: pyobject
        #   branch $67.7, 92, 119
        # label 92
        #   del result
        #   del count_out
        #   del arr
        #   del $phi64.1
        #   del $67.7
        #   del $const92.2
        #   del last_out
        #   del a
        #   del $92.3
        #   del $92.6
        #   del $92.1

        if a<last_out:

            # --- LINE 151 --- 
            #   $92.1 = global(Exception: <type 'exceptions.Exception'>)  :: pyobject
            #   $const92.2 = const(str, Input not sorted: {} .. {})  :: pyobject
            #   $92.3 = getattr(attr=format, value=$const92.2)  :: pyobject
            #   $92.6 = call $92.3(last_out, a, kws=[], args=[Var(last_out, /home/user1/py/fast_ops.py (146)), Var(a, /home/user1/py/fast_ops.py (149))], func=$92.3, vararg=None)  :: pyobject
            #   $92.7 = call $92.1($92.6, kws=[], args=[Var($92.6, /home/user1/py/fast_ops.py (151))], func=$92.1, vararg=None)  :: pyobject
            #   raise $92.7
            # label 119
            #   del $67.7

            raise Exception("Input not sorted: {} .. {}".format(last_out, a))

        # --- LINE 152 --- 
        #   $119.3 = last_out != a  :: pyobject
        #   branch $119.3, 131, 155
        # label 131
        #   del $119.3
        #   del a

        if last_out!=a:

            # --- LINE 153 --- 
            #   result[count_out] = a  :: pyobject

            result[count_out] = a

            # --- LINE 154 --- 
            #   last_out = a  :: pyobject
            #   jump 155
            # label 153
            #   del last_out
            #   del arr
            #   del $phi67.1
            #   del $phi64.1
            #   del $64.4
            #   jump 154
            # label 154
            #   del result
            #   del count_out
            #   del $154.2
            #   del $154.4

            last_out = a

        # --- LINE 155 --- 



    # --- LINE 156 --- 
    #   $154.2 = getattr(attr=resize, value=result)  :: pyobject
    #   $154.4 = call $154.2(count_out, kws=[], args=[Var(count_out, /home/user1/py/fast_ops.py (147))], func=$154.2, vararg=None)  :: pyobject
    #   $154.5 = cast(value=$154.4)  :: pyobject
    #   return $154.5

    return result.resize(count_out)

Best answer

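It turns out that the exception handling (raising an exception whose message is built at run time with str.format) was forcing the function into object mode.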

Here is the fix:

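import numpy as np
from numba import jit

# Plain Python wrapper: ndarray.resize is not supported in nopython mode,
# so allocation and trimming stay outside the compiled function.
def unique_in_sorted(arr):
    result = np.empty_like(arr)
    count_out = _unique_in_sorted(arr, result)
    result.resize(count_out)  # resizes in place and returns None
    return result

@jit(nopython=True)
def _unique_in_sorted(arr, result):
    result[0] = arr[0]
    last_out = arr[0]
    count_out = 1
    for i in range(len(arr)):
        a = arr[i]
        # Raising with a message formatted at run time forced object mode:
        # if a<last_out:
        #     raise Exception("Input not sorted: {} .. {}".format(last_out, a))
        if last_out != a:
            result[count_out] = a
            last_out = a
            count_out += 1  # advance the write position for the next new value
    return count_out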

Regarding "python - Can the types of numba locals be inferred or hinted?", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/49057801/
