python - 使用 Cython 优化 NumPy

标签 python numpy cython matrix-multiplication optimization

我目前正在尝试优化我用纯 Python 编写的代码。此代码使用 NumPy非常重,因为我正在使用 NumPy 数组。下面你可以看到我转换为 Cython 的最简单的类.它只做两个 Numpy 数组的乘法。这里:

bendingForces = self.matrixPrefactor * membraneHeight

我的问题是,我是否以及如何优化它,因为当我查看“cython -a”生成的 C 代码时,它有很多 NumPy 调用,看起来效率不高。

import numpy as np
cimport numpy as np
ctypedef np.float64_t dtype_t
ctypedef np.complex128_t cplxtype_t
ctypedef Py_ssize_t index_t

    cdef class bendingForcesClass( object ):
        cdef dtype_t bendingRigidity
        cdef np.ndarray matrixPrefactor
        cdef np.ndarray bendingForces

        def __init__( self, dtype_t bendingRigidity, np.ndarray[dtype_t, ndim=2] waveNumbersNorm ):
            self.bendingRigidity = bendingRigidity
            self.matrixPrefactor = -self.bendingRigidity * waveNumbersNorm**2

        cpdef np.ndarray calculate( self, np.ndarray membraneHeight ) :
            cdef np.ndarray bendingForces
            bendingForces = self.matrixPrefactor * membraneHeight
            return bendingForces

我的想法是使用两个 for 循环并遍历数组的条目。也许我可以使用编译器通过 SIMD 操作来优化它?!我试过了,我可以编译它,但它给出了奇怪的结果并且花了很长时间。这是替代函数的代码:

cpdef np.ndarray calculate( self, np.ndarray membraneHeight ) :

    cdef index_t index1, index2 # corresponds to: cdef Py_ssize_t index1, index2
    for index1 in range( self.matrixSize ):
        for index2 in range( self.matrixSize ):
            self.bendingForces[ index1, index2 ] = self.matrixPrefactor.data[ index1, index2 ] * membraneHeight.data[ index1, index2 ]
    return self.bendingForces

然而,正如我所说,这段代码确实很慢,并且没有按预期运行。那我做错了什么?优化它并删除 NumPy 调用操作的最佳方法是什么?

最佳答案

对于简单的矩阵乘法,NumPy 代码已经在本地只进行循环和乘法运算,因此在 Cython 中很难超越它。 Cython 非常适合将 Python 中的循环替换为 Cython 中的循环的情况。您的代码比 NumPy 慢的原因之一是因为每次您在数组中进行索引查找时,

self.bendingForces[ index1, index2 ] = self.matrixPrefactor.data[ index1, index2 ] * membraneHeight.data[ index1, index2 ]

它会进行更多计算,例如边界检查(索引有效)。如果将索引转换为无符号整数,则可以在函数之前使用装饰器 @cython.boundscheck(False)

查看此 tutorial有关加速 Cython 代码的更多详细信息。

关于python - 使用 Cython 优化 NumPy,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/5331275/

相关文章:

python - 根据特定列的条件将一组行的数据帧值分配给另一组行

python - Elasticsearch 滚动(扫描)到 Pandas DataFrame

python - N 个最高值的 Torch argmax

python - 如何使用 numpy 加载没有固定列大小的数据

python - 如何使 Cython 可见我的 Python 模块?

c++ - 如何通过 Cython 处理 C++ 包装中的双指针

带有分割定义的 Cython 编译

python - 为什么 __debug__ 是关键字时不在关键字列表中?

python - 连接python中的两个数组,交替numpy中的列

python - 为什么需要 Visual C++ Installer 来为 Python 安装 numpy 包?