python - 与函数相比,类方法的 cython 优化令人沮丧

标签 python cython

我正在尝试在 cython 中实现一个简单的滚动平均值。作为算法模拟的一部分,我正在通过非常大的数据集运行它,因此使用 pandas.rolling 等不是一个选项。

但是,我面临着一个非常令人烦恼的情况,cython 在类方法上确实表现不佳。以下是代码。

class RollingAverage:
    def __init__(self, length):
        self.current = 0
        self.ma = 0
        self.window_length = length
        self.window = np.zeros(length, dtype=np.float32)

    def mean(self):
        return self.ma

    @cython.boundscheck(False)
    @cython.wraparound(False)
    def update(self, value):
        self.ma += (value - self.window[self.current % self.window_length]) / self.window_length
        self.window[self.current % self.window_length] = value
        self.current += 1
        return self.ma

    def update2(self,value):
        self.ma = __update_impl( self.ma,
                                 self.current,
                                 self.window_length,
                                 self.window,
                                 value )
        self.current += 1
        return self.ma

@cython.boundscheck(False)
@cython.wraparound(False)
def __update_impl(ma, current, window_length, window, value):
    ma += (value - window[current % window_length]) / window_length
    window[current % window_length] = value
    return ma

我还有一个 pxd 文件,其中定义了以下内容:

cdef class RollingAverage:
    cdef int current
    cdef float ma
    cdef int window_length
    cdef np.ndarray window
    cpdef update(self, float value)
    cpdef update2(self, float value)
    cpdef mean(self)


cdef float __update_impl(float ma,
                  int current,
                  int window_length,
                  np.ndarray[float] window,
                  float value)

编辑:这是 update__update_impl 中热代码的 cython 注释差异:

def update(self, value):

+0127:      self.ma += (value - self.window[self.current % self.window_length]) / self.window_length

  __pyx_t_1 = PyFloat_FromDouble(__pyx_v_self->ma); if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_1);
  __pyx_t_2 = PyFloat_FromDouble(__pyx_v_value); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_2);
  if (unlikely(__pyx_v_self->window_length == 0)) {
    PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero");
    __PYX_ERR(0, 127, __pyx_L1_error)
  }
  __pyx_t_8 = __Pyx_mod_int(__pyx_v_self->current, __pyx_v_self->window_length);
  __pyx_t_4 = __Pyx_GetItemInt(((PyObject *)__pyx_v_self->window), __pyx_t_8, int, 1, __Pyx_PyInt_From_int, 0, 0, 0); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_4);
  __pyx_t_6 = PyNumber_Subtract(__pyx_t_2, __pyx_t_4); if (unlikely(!__pyx_t_6)) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_6);
  __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
  __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
  __pyx_t_4 = __Pyx_PyInt_From_int(__pyx_v_self->window_length); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_4);
  __pyx_t_2 = __Pyx_PyNumber_Divide(__pyx_t_6, __pyx_t_4); if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_2);
  __Pyx_DECREF(__pyx_t_6); __pyx_t_6 = 0;
  __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
  __pyx_t_4 = PyNumber_InPlaceAdd(__pyx_t_1, __pyx_t_2); if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_GOTREF(__pyx_t_4);
  __Pyx_DECREF(__pyx_t_1); __pyx_t_1 = 0;
  __Pyx_DECREF(__pyx_t_2); __pyx_t_2 = 0;
  __pyx_t_7 = __pyx_PyFloat_AsFloat(__pyx_t_4); if (unlikely((__pyx_t_7 == (float)-1) && PyErr_Occurred())) __PYX_ERR(0, 127, __pyx_L1_error)
  __Pyx_DECREF(__pyx_t_4); __pyx_t_4 = 0;
  __pyx_v_self->ma = __pyx_t_7;

+0129:      self.current += 1
+0130:      return self.ma


def __update_impl(ma, current, window_length, window, value):

+0146:  ma += (value - window[current % window_length]) / window_length

  if (unlikely(__pyx_v_window_length == 0)) {
    PyErr_SetString(PyExc_ZeroDivisionError, "integer division or modulo by zero");
    __PYX_ERR(0, 146, __pyx_L1_error)
  }
  __pyx_t_1 = __Pyx_mod_int(__pyx_v_current, __pyx_v_window_length);
  __pyx_t_2 = (__pyx_v_value - (*__Pyx_BufPtrStrided1d(float *, __pyx_pybuffernd_window.rcbuffer->pybuffer.buf, __pyx_t_1, __pyx_pybuffernd_window.diminfo[0].strides)));
  if (unlikely(__pyx_v_window_length == 0)) {
    PyErr_SetString(PyExc_ZeroDivisionError, "float division");
    __PYX_ERR(0, 146, __pyx_L1_error)
  }
  __pyx_v_ma = (__pyx_v_ma + (__pyx_t_2 / ((float)__pyx_v_window_length)));

+0147:  window[current % window_length] = value
+0148:  return ma

方法update几乎比update2慢一个数量级。

%%timeit
ma.update(1000)
The slowest run took 11.68 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3 µs per loop

%%timeit
ma.update2(1000)
The slowest run took 22.90 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 416 ns per loop

此外,当我在源代码上运行“cython -a”时,我在方法中得到了一些严重虚假的内容,而该函数中没有一丝黄色,几乎完全编译为 C 代码。

我还尝试将所有 self. 变量替换为我在 cython 中声明的局部变量,但无济于事。由于某种原因,在方法中包含代码会强制进行大量测试和转换。

我错过了什么?

最佳答案

主要区别是ndarray[float]在快速版本中 vs ndarray在慢速版本中。如果不知道数组数据类型,Cython 就不可能对访问进行任何真正的优化。

您必须这样做,因为 fully typed ndarrays aren't allowed as class members 。最好的解决方案是使用类型化内存 View float [:]其作用大致相同。 (例如,如果您特别需要访问 ndarray 来调用其方法之一,那么您可以通过内存 View 的 .base 属性获取它)

关于python - 与函数相比,类方法的 cython 优化令人沮丧,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45397362/

相关文章:

python - 在 Python 中将 click.progressbar 与多处理一起使用

python - 为什么 Cython Decorator 版本比 Cython Pyx 版本慢?

python - 字符串的快速串联

python-2.7 - 使用 cython 加速数以千计的集合操作

python - 如何使用 Eclipse 调试 pserve?

python - 有没有办法让一个程序暂停一段时间,但仍然有另一段代码在运行?

python - 在 Python 中查找多个子字符串之一的最有效方法是什么?

python - pickle Cython 修饰函数导致 PicklingError

python - 简化语句 '.' .join( string.split ('.' )[0 :3] )

python - 获取列与某个值匹配的行的索引