python - Python 中每个线程使用 Numba 并行时间

标签 python multithreading parallel-processing numba

当我使用 numba 中的 njit 并行运行该程序时,我注意到使用许多线程并没有什么区别。事实上,从 1-5 个线程开始,时间会更快(这是预期的),但之后时间会变慢。为什么会发生这种情况?

from numba import njit,prange,set_num_threads,get_num_threads
import numpy as np
@njit(parallel=True)
def test(x,y):
    z=np.empty((x.shape[0],x.shape[0]),dtype=np.float64)
    for i in prange(x.shape[0]):
        for j in range(x.shape[0]):
            z[i,j]=x[i,j]*y[i,j]
    return z
x=np.random.rand(10000,10000)
y=np.random.rand(10000,10000)
for i in range(16):   
    set_num_threads(i+1)
    print("Number of threads :",get_num_threads())
    %timeit -r 1 -n 10 test(x,y)
Number of threads : 1
234 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 2
178 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 3
168 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 4
161 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 5
148 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 6
152 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 7
152 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 8
153 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 9
154 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 10
156 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 11
158 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 12
157 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 13
158 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 14
160 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 15
160 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)
Number of threads : 16
161 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 10 loops each)

我在 Jupyter Notebook (anaconda) 的 8 核 16 线程 CPU 中对此进行了测试。

最佳答案

代码受内存限制,因此 RAM 在只有少数核心的情况下就饱和了

事实上,z[i,j]=x[i,j]*y[i,j] 导致两次 8 字节的内存加载、一次 8 字节的存储和一次额外的加载由于 x86-64 处理器上的写入分配缓存策略,因此需要 8 个字节(在这种情况下必须读取已写入的缓存行)。这意味着每次循环迭代加载/存储 32 个字节,而只需要执行 1 次乘法。现代主流 (x86-64) 处理器可以执行 2x4 double FP 乘法/周期,并在 3-5 GHz 下运行(事实上,英特尔服务器处理器可以执行 2x8 DP FP 乘法/周期)。而好的主流PC只能达到40-60 GiB/s,高性能服务器只能达到200-350 GiB/s。

在 Numba 中没有办法加速像这样的内存绑定(bind)代码。 C/C++ 代码可以通过避免写入分配来稍微改进这一点(最多快 1.33 倍)。最好的解决方案是尽可能在较小的 block 上进行操作,并合并计算步骤,以便每步应用更多的 FP 操作。

事实上,众所周知,与处理器的计算能力相比,RAM 的速度增长缓慢。这个问题几十年前就已经被发现,而且随着时间的推移,两者之间的差距仍然越来越大。这个问题被称为“内存墙”。 future 情况不会更好(至少不太可能是这种情况)。

关于python - Python 中每个线程使用 Numba 并行时间,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72756574/

相关文章:

python - 在新进程中执行 Bash 命令

parallel-processing - Julia - 模块和并行

r - 并行处理的负载平衡

python - 在 Python 中使用类的骰子生成器

python - 如何返回python给定文件路径中的错误行?

python - 允许用户在 Heroku 上为 Django 应用程序使用自定义域

python - 尝试修复tkinter GUI卡住(使用线程)

multithreading - 如何使用安全异常库捕获异步异常?

python - 使用占位符创建字符串

Java:类完全在第二个线程/IllegalMonitorStateException 中运行