python - 使用 cython 加速 numpy 数组的类(class)速度

标签 python numpy vectorization cython

我有以下代码:

class _Particles:
    def __init__(self, num_particle, dim, fun, lower_bound, upper_bound):
        self.lower_bound = lower_bound   # np.array of shape (dim,)
        self.upper_bound = upper_bound   # np.array of shape (dim,)
        self.num_particle = num_particle   # a scalar
        self.dim = dim   # dimension, a scalar
        self.fun = fun   # a function

        self.pos = np.empty((num_particle,dim))
        self.val = np.empty(num_particle)
        self.randomize()


    def randomize(self):
        self.pos = np.random.rand(self.num_particle, self.dim)*(self.upper_bound\
                -self.lower_bound)+self.lower_bound
        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]


    def move(self, displacement, idx='all', check_bound=True):
        if idx is 'all':
            self.pos += displacement
        elif isinstance(idx,(tuple,list,np.ndarray)):
            self.pos[idx] += displacement
        else:
            raise TypeError('Check the type of idx!',type(idx))

        self.pos = np.maximum(self.pos, self.lower_bound[np.newaxis,:])
        self.pos = np.minimum(self.pos, self.upper_bound[np.newaxis,:])
        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]

我想看看是否可以加快上述代码的速度,并且我正在考虑使用 cython,但我不确定是否可能,因为它主要使用 numpy 数组,并且大多数执行都是通过矢量化完成的。我尝试这样的事情:

# the .pyx file that will be compiled
cdef class _Particles(object):
    cdef int num_particle
    cdef int dim
    cdef fun
    cdef np.ndarray lower_bound
    cdef np.ndarray upper_bound
    cdef np.ndarray pos
    cdef np.ndarray val
    cdef int best_idx
    cdef double best_val
    cdef np.ndarray[np.float64_t, ndim=1] best_pos

    def __init__(self, int num_particle, int dim, fun,
                 np.ndarray lower_bound, np.ndarray upper_bound):
        self.num_particle = num_particle
        self.dim = dim
        self.fun = fun
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

        self.pos = np.empty((num_particle,dim))
        self.val = np.empty(num_particle)
        self.randomize()

    def randomize(self):
        self.pos = npr.rand(self.num_particle,self.dim)*(self.upper_bound\
                -self.lower_bound)+self.lower_bound

        self.val = self.fun(np.transpose(self.pos))
        self.best_idx = np.argmin(self.val)
        self.best_val = self.val[self.best_idx]
        self.best_pos = self.pos[self.best_idx]

它的速度更快,但只快了一点,这在意料之中,因为它仍然主要是 python 代码。那么有没有什么方法可以使用 cython 加速上述代码(或者向我指出其他一些完全方法)?特别是如何加快self.fun(self.pos)np.argmin(self.val)等代码?

谢谢。

最佳答案

实际上,上面的代码恐怕没有太多需要优化的地方。 为了使 argmin 更快,我建议您获取(或以其他方式自行编译)具有多线程支持的 NumPy(或者您可以自己重新实现一些多线程 argmin)。

就 Cython 而言,当您开始使用 C 类型时,您会得到真正的好处,但我不会看到您发布的代码有很大的改进。 这主要是粘合代码,不涉及数字运算。

我希望数字运算发生在函数fun中,这可能是实际手动优化可能产生影响的唯一地方,只要它不那么容易矢量化(阅读:有一个 for 或其他手动循环)。然后,我将从 numba 开始,如果它有效的话,这是一个更简单的直接加速代码的方法。如果没有,那么开始研究 Cython 可能是合适的。

关于python - 使用 cython 加速 numpy 数组的类(class)速度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52114107/

相关文章:

python - 我的 Flask-Admin ModelView 的 __init__ 没有应用程序上下文——它通常什么时候获得一个?

php - Python SOAP 客户端、授权

python - 无法解析网页中所有可用的 asin

Python 比较数组为零比 np.any(array) 更快

python - Sublime Text 3、Python 3 和 UTF-8 彼此不喜欢

python - Numpy where 返回空数组

python - 锐化图像以检测纸上标记为 "X"的对象中的边缘/线条

python - 使用 numpy 数组将值分配给另一个数组

python - 如何使用 pandas fillna 快速填写大量数据?

r - 将 data.frames(n x 2 data.frames)列表 reshape 为单个 data.frame(n x 3 列)