multithreading - Prange 减慢 Cython 循环速度

标签 multithreading openmp cython

考虑两种计​​算随机数的方法,一种是单线程,一种是使用 cython prange 和 openmp 的多线程:

def rnd_test(long size1):
    cdef long i
    for i in range(size1):
        rand()
    return 1

def rnd_test_par(long size1):
    cdef long i
    with nogil, parallel():
        for i in prange(size1, schedule='static'):
             rand()
    return 1

函数 rnd_test 首先使用以下 setup.py 进行编译

from distutils.core import setup
from Cython.Build import cythonize

setup(
  name = 'Hello world app',
  ext_modules = cythonize("cython_test.pyx"),
)

rnd_test(100_000_000) 运行时间为 0.7 秒。

然后,使用以下 setup.py 编译 rnd_test_par

from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize

ext_modules = [
    Extension(
        "cython_test_openmp",
        ["cython_test_openmp.pyx"],
        extra_compile_args=["-O3", '-fopenmp'],
        extra_link_args=['-fopenmp'],
    )

]

setup(
    name='hello-parallel-world',
    ext_modules=cythonize(ext_modules),
)

rnd_test_par(100_000_000) 运行时间为 10 秒!!!

在 ipython 中使用 cython 获得类似的结果:

%%cython
import cython
from cython.parallel cimport parallel, prange
from libc.stdlib cimport rand

def rnd_test(long size1):
    cdef long i
    for i in range(size1):
        rand()
    return 1

%%timeit
rnd_test(100_000_000)

1 个循环,3 次最佳:每个循环 1.5 秒

%%cython --compile-args=-fopenmp --link-args=-fopenmp --force
import cython
from cython.parallel cimport parallel, prange
from libc.stdlib cimport rand

def rnd_test_par(long size1):
    cdef long i
    with nogil, parallel():
        for i in prange(size1, schedule='static'):
                rand()
    return 1

%%timeit
rnd_test_par(100_000_000)

1 次循环,3 次最佳:每次循环 8.42 秒

我做错了什么?我对 cython 完全陌生,这是我第二次使用它。我上次有很好的经验,所以我决定用于一个带有蒙特卡罗模拟的项目(因此使用兰特)。

这是预期的吗?阅读了所有文档后,我认为 prange 在像这样的令人尴尬的并行情况下应该可以很好地工作。我不明白为什么这无法加快循环速度,甚至使它变得如此慢。

一些附加信息:

  • 我正在运行 python 3.6、cython 0.26。
  • gcc版本为“gcc (Ubuntu 5.4.0-6ubuntu1~16.04.4) 5.4.0 20160609”
  • CPU 使用情况确认并行版本实际上使用了许多内核 (系列案例的 90% 与 25%)

感谢您提供的任何帮助。我首先尝试使用 numba,它确实加快了计算速度,但它还有其他问题让我想避免它。我希望 Cython 在这种情况下工作。

谢谢!!!

最佳答案

通过 DavidW 的有用反馈和链接,我有了一个用于随机数生成的多线程解决方案。 然而,与单线程(向量化)Numpy 解决方案相比,节省的时间并没有那么大。 numpy 方法在 1.2 秒内生成 1 亿个数字(内存为 5GB),而多线程方法则为 0.7 秒。考虑到复杂性的增加(例如使用 c++ 库),我想知道这是否值得。也许我会将随机数生成保留为单线程,并致力于并行执行此步骤之后的计算。 然而,这个练习对于理解随机数生成器的问题非常有用。最终,我希望拥有可以在分布式环境中工作的框架,而且我现在可以看到,由于随机数生成器本质上具有不可忽略的状态,因此随机数生成器面临的挑战会更大。

%%cython --compile-args=-fopenmp --link-args=-fopenmp --force
# distutils: language = c++
# distutils: extra_compile_args = -std=c++11
import cython
cimport numpy as np
import numpy as np
from cython.parallel cimport parallel, prange, threadid
cimport openmp

cdef extern from "<random>" namespace "std" nogil:
    cdef cppclass mt19937:
        mt19937() # we need to define this constructor to stack allocate classes in Cython
        mt19937(unsigned int seed) # not worrying about matching the exact int type for seed

    cdef cppclass uniform_real_distribution[T]:
        uniform_real_distribution()
        uniform_real_distribution(T a, T b)
        T operator()(mt19937 gen) # ignore the possibility of using other classes for "gen"

@cython.boundscheck(False)
@cython.wraparound(False)        
def test_rnd_par(long size):
    cdef:
        mt19937 gen
        uniform_real_distribution[double] dist = uniform_real_distribution[double](0.0,1.0)
        narr = np.empty(size, dtype=np.dtype("double"))
        double [:] narr_view = narr
        long i

    with nogil, parallel():
        gen = mt19937(openmp.omp_get_thread_num())
        for i in prange(size, schedule='static'):
            narr_view[i] = dist(gen)
    return narr

关于multithreading - Prange 减慢 Cython 循环速度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46257321/

相关文章:

cython - 在 Cython 中获取结构元素

c++ - 需要帮助来理解 "ABA"问题

c# - UWP CreateFileAsync 线程在 Release模式下编译并选中 Compile with .NET Native 工具链时退出

具有共享对象的 C++11 多线程

java - ContentProvider insert() 总是在 UI 线程上运行?

c - OpenMP 循环并行化

c - 在 Visual Studio 2010 (OpenMP) 中并行化 for 循环

c - OpenMP 并行区域线程关联

cython - 诗歌+狮身人面像+Cython

python - 从具有可变列数的 ASCII 文件中读取浮点值