python - 为什么 numba 不提高我的背包功能的速度?

标签 python python-3.x performance jit numba

我尝试用 numba 加速我的代码,但它似乎不起作用。该程序与 @jit@njit 或纯 Python 程序花费相同的时间(大约 10 秒)。不过我使用了 numpy 而不是 list 或 dict。

这是我的代码:

import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)

@njit
def knapSack(W, wt, val, n):
    K = np.full((n+1,W+1),0)
    N =  np.full((n+1,W+1,W+1),0)
    M =  np.full((n+1,W+1),0)

    for i in range(n+1):
        for w in range(W+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif wt[i-1] <= w:
                if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                    K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                    c = N[i-1][w-wt[i-1]]
                    c[i] = i
                    N[i][w] = c
                else:
                    K[i][w] = K[i-1][w]
                    N[i][w] = N[i-1][w]
            else:
                K[i][w] = K[i-1][w]
    N[n][W][0] = K[n][W]
    return N[n][W]

@profile
def main():

    size = 1000
    val = [random.randint(1, size) for i in range(0, size)]
    wt = [random.randint(1, size) for i in range(0, size)]
    W = 1000
    n = len(val)
    a = knapSack(W, wt, val, n)
main()

最佳答案

事实上,如果不改变方法本身,就不可能真正提高当前算法的性能。

您的 N 数组包含大约 10 亿个对象 ( 1001 * 1001 * 1001 )。您需要设置每个元素,因此您至少有十亿次操作。为了获得下限,我们假设设置一个数组元素需要一纳秒(实际上需要更多时间)。 10亿次操作,每次需要1纳秒,意味着需要1秒才能完成。正如我所说,每次操作可能需要比 1 纳秒长一点的时间,因此我们假设需要 10 纳秒(可能有点高,但比 1 纳秒更现实),这意味着算法总共需要 10 秒。

因此,您输入的预期运行时间将在 1 秒到 10 秒之间。因此,如果您的 Python 版本需要 10 秒,那么它可能已经达到了您选择的方法所能达到的极限,并且没有任何工具可以(显着)改进该运行时间。


可以使它更快一点的一件事是使用 np.zeros 而不是 np.full :

K = np.zeros((n+1, W+1), dtype=int)
N = np.zeros((n+1, W+1, W+1), dtype=int)

并且不要创建 M,因为您不会使用它。


由于您已经使用了 line-profiler,我决定看一下,并得到了这个结果:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def knapSack(W, wt, val, n):
     4         1      19137.0  19137.0      0.0      K = np.full((n+1,W+1),0)
     5         1   19408592.0 19408592.0     28.1      N = np.full((n+1,W+1,W+1),0)
     6                                           
     7      1002       6412.0      6.4      0.0      for i in range(n+1):
     8   1003002    4186311.0      4.2      6.1          for w in range(W+1):
     9   1002001    4644031.0      4.6      6.7              if i==0 or w==0:
    10      2001      19663.0      9.8      0.0                  K[i][w] = 0
    11   1000000    5474080.0      5.5      7.9              elif wt[i-1] <= w:
    12    498365    9616406.0     19.3     13.9                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
    13     52596     902030.0     17.2      1.3                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
    14     52596     578740.0     11.0      0.8                      c = N[i-1][w-wt[i-1]]
    15     52596     295980.0      5.6      0.4                      c[i] = i
    16     52596    1239792.0     23.6      1.8                      N[i][w] = c
    17                                                           else:
    18    445769    5100917.0     11.4      7.4                      K[i][w] = K[i-1][w]
    19    445769   11677683.0     26.2     16.9                      N[i][w] = N[i-1][w]
    20                                                       else:
    21    501635    5801328.0     11.6      8.4                  K[i][w] = K[i-1][w]
    22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
    23         1         14.0     14.0      0.0      return N[n][W]

这表明瓶颈是 np.fullN[i][w] = N[i-1][w]if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]) 。 Numba 不会改进前两个,因为它们已经使用了高度优化的 NumPy 代码,对于这些代码,numba 更有可能会变慢。 Numba 可能可以改进 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]),但这可能不会被注意到。

如果 np.fullnp.zeros 替换,配置文件会略有变化:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def knapSack(W, wt, val, n):
     4         1        747.0    747.0      0.0      K = np.zeros((n+1, W+1),dtype=int)
     5         1     109592.0 109592.0      0.2      N = np.zeros((n+1, W+1, W+1),dtype=int)
     6                                           
     7      1002       4230.0      4.2      0.0      for i in range(n+1):
     8   1003002    4414071.0      4.4      7.0          for w in range(W+1):
     9   1002001    4836807.0      4.8      7.7              if i==0 or w==0:
    10      2001      22282.0     11.1      0.0                  K[i][w] = 0
    11   1000000    5646859.0      5.6      8.9              elif wt[i-1] <= w:
    12    521222   10389581.0     19.9     16.5                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
    13     47579     784563.0     16.5      1.2                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
    14     47579     509056.0     10.7      0.8                      c = N[i-1][w-wt[i-1]]
    15     47579     362796.0      7.6      0.6                      c[i] = i
    16     47579    1975916.0     41.5      3.1                      N[i][w] = c
    17                                                           else:
    18    473643    5579823.0     11.8      8.8                      K[i][w] = K[i-1][w]
    19    473643   22805846.0     48.1     36.1                      N[i][w] = N[i-1][w]
    20                                                       else:
    21    478778    5664271.0     11.8      9.0                  K[i][w] = K[i-1][w]
    22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
    23         1         10.0     10.0      0.0      return N[n][W]

但主要瓶颈仍然是 N[i][w] = N[i-1][w],使用 numba 可能比使用纯 NumPy 慢。因此,使用 numba 对代码的其他部分所获得的改进可能(再次)是不明显的。


对于第一个配置文件,我使用了您的此版本的代码(第二个配置文件只是将 np.full 更改为 np.zeros ):

import numpy as np

def knapSack(W, wt, val, n):
    K = np.full((n+1,W+1),0)
    N = np.full((n+1,W+1,W+1),0)

    for i in range(n+1):
        for w in range(W+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif wt[i-1] <= w:
                if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                    K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                    c = N[i-1][w-wt[i-1]]
                    c[i] = i
                    N[i][w] = c
                else:
                    K[i][w] = K[i-1][w]
                    N[i][w] = N[i-1][w]
            else:
                K[i][w] = K[i-1][w]
    N[n][W][0] = K[n][W]
    return N[n][W]

import random
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)

%lprun -f knapSack knapSack(W, wt, val, n)

关于python - 为什么 numba 不提高我的背包功能的速度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59018318/

相关文章:

python - 从命令行运行 python 脚本,然后进入交互模式

python - 该平台缺少功能性的 sem_open 实现,因此,需要所需的同步原语

Python/Bash - 获取带有转义字符的文件名

python-3.x - 在 Python 中使用 kwargs 添加参数

java - 变量性能 - java

java - 测试对象是否为空的有效方法

python - 如何解析字符串查找特定单词/数字并在找到时显示它们

python - invoke.context.Context 因缺少位置参数而出现奇怪的错误

python - Python Bug 中的合并排序

javascript - Electron 中 app.ready() 之前调用了哪个方法