python - Cython 函数指针解引用时间(与直接调用函数相比)

标签 python function numpy pointers cython

我有一些 Cython 代码,涉及对以下形式的 Numpy 数组(表示 BGR 图像)进行极其重复的像素操作:

ctypedef double (*blend_type)(double, double) # function pointer
@cython.boundscheck(False)  # Deactivate bounds checking
@cython.wraparound(False)   # Deactivate negative indexing.
cdef cnp.ndarray[cnp.float_t, ndim=3] blend_it(const double[:, :, :] array_1, const double[:, :, :] array_2, const blend_type blendfunc, const double opacity):
  # the base layer is a (array_1)
  # the blend layer is b (array_2)
  # base layer is below blend layer
  cdef Py_ssize_t y_len = array_1.shape[0]
  cdef Py_ssize_t x_len = array_1.shape[1]
  cdef Py_ssize_t a_channels = array_1.shape[2]
  cdef Py_ssize_t b_channels = array_2.shape[2]
  cdef cnp.ndarray[cnp.float_t, ndim=3] result = np.zeros((y_len, x_len, a_channels), dtype = np.float_)
  cdef double[:, :, :] result_view = result
  cdef Py_ssize_t x, y, c

  for y in range(y_len):
    for x in range(x_len):
      for c in range(3): # iterate over BGR channels first
        # calculate channel values via blend mode
        a = array_1[y, x, c]
        b = array_2[y, x, c]
        result_view[y, x, c] = blendfunc(a, b)
        # many other operations involving result_view...
  return result;

其中blendfunc指的是另一个cython函数,例如下面的overlay_pix:

cdef double overlay_pix(double a, double b):
  if a < 0.5:
    return 2*a*b
  else:
    return 1 - 2*(1 - a)*(1 - b)

使用函数指针的目的是避免为每种混合模式(其中有很多)一遍又一遍地重写大量困惑的重复代码。因此,我为每种混合模式创建了一个这样的界面,省去了麻烦:

def overlay(double[:, :, :] array_1, double[:, :, :] array_2, double opacity = 1.0):
  return blend_it(array_1, array_2, overlay_pix, opacity)

但是,这似乎花费了我一些时间!我注意到,对于非常大的图像(例如 8K 图像及更大图像),在 blend_it 函数中使用 blendfunc 而不是直接调用会浪费大量时间到 blend_it 中的 overlay_pix。我认为这是因为 blend_it 必须在迭代中每次都取消引用函数指针,而不是让函数立即可用,但我不确定。

时间损失并不理想,但我当然不想为每种混合模式一遍又一遍地重写blend_it。有什么办法可以避免时间损失吗?有没有办法将函数指针转换为循环外的本地函数,然后在循环内更快地访问它?

最佳答案

@ead's answer说了两件有趣的事情:

  1. C 也许能够将其优化为直接调用。我认为除了相当简单的情况之外,这通常是不正确的,并且对于 OP 正在使用的编译器和代码来说似乎也不正确。
  2. 在 C++ 中,您可以使用模板 - 这绝对是正确的,因为模板类型在编译时始终已知,优化通常很容易。

Cython 和 C++ 模板有点困惑,所以我认为您不想在这里使用它们。然而 Cython 确实有一个类似模板的功能,称为 fused types 。您可以使用融合类型来获得编译时优化,如下所示。代码的大致轮廓是:

  1. 定义一个 cdef 类,其中包含用于执行您想要执行的所有操作的 staticmethod cdef 函数。
  2. 定义一个包含所有cdef 类的融合类型。 (这是这种方法的局限性 - 它不容易扩展,因此如果您想添加操作,则必须编辑代码)
  3. 定义一个采用融合类型的虚拟参数的函数。使用此类型调用静态方法
  4. 定义包装函数 - 您需要使用显式的 [type] 语法才能使其正常工作。

代码:

import cython

cdef class Plus:
    @staticmethod
    cdef double func(double x):
        return x+1    

cdef class Minus:
    @staticmethod
    cdef double func(double x):
        return x-1

ctypedef fused pick_func:
    Plus
    Minus

cdef run_func(double [::1] x, pick_func dummy):
    cdef int i
    with cython.boundscheck(False), cython.wraparound(False):
        for i in range(x.shape[0]):
            x[i] = cython.typeof(dummy).func(x[i])
    return x.base

def run_func_plus(x):
    return run_func[Plus](x,Plus())

def run_func_minus(x):
    return run_func[Minus](x,Minus())

为了比较,使用函数指针的等效代码是

cdef double add_one(double x):
    return x+1

cdef double minus_one(double x):
    return x-1

cdef run_func_ptr(double [::1] x, double (*f)(double)):
    cdef int i
    with cython.boundscheck(False), cython.wraparound(False):
        for i in range(x.shape[0]):
            x[i] = f(x[i])
    return x.base

def run_func_ptr_plus(x):
    return run_func_ptr(x,add_one)

def run_func_ptr_minus(x):
    return run_func_ptr(x,minus_one)

使用 timeit 与使用函数指针相比​​,我获得了大约 2.5 倍的加速。这表明函数指针没有针对我进行优化(但是我还没有尝试更改编译器设置来尝试改进)

import numpy as np
import example

# show the two methods give the same answer
print(example.run_func_plus(np.ones((10,))))
print(example.run_func_minus(np.ones((10,))))

print(example.run_func_ptr_plus(np.ones((10,))))
print(example.run_func_ptr_minus(np.ones((10,))))

from timeit import timeit

# timing comparison
print(timeit("""run_func_plus(x)""",
             """from example import run_func_plus
from numpy import zeros
x = zeros((10000,))
""",number=10000))

print(timeit("""run_func_ptr_plus(x)""",
             """from example import run_func_ptr_plus
from numpy import zeros
x = zeros((10000,))
""",number=10000))

关于python - Cython 函数指针解引用时间(与直接调用函数相比),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54965864/

相关文章:

python - 如何为星形参数设置默认值?

python - 我可以使用 try/except 重新启动我的 python 代码吗?

Python:如何从列表中删除价低于特定值的值

c - arm 链接器在函数调用中使用的 'veneer' 是什么?

python - Python 中的语句和函数有什么区别?

python - 为什么子进程中的简单回显不起作用

sql - PostgreSQL 函数返回表

python - 如何按面积排序opencvconnectedComponentewithStat?

python - 为什么这个函数返回全零

python - 从 2 个法线创建旋转矩阵