我尝试用 numba 加速我的代码,但它似乎不起作用。该程序与 @jit
、@njit
或纯 Python 程序花费相同的时间(大约 10 秒)。不过我使用了 numpy 而不是 list 或 dict。
这是我的代码:
import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)
@njit
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
M = np.full((n+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
@profile
def main():
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
a = knapSack(W, wt, val, n)
main()
最佳答案
事实上,如果不改变方法本身,就不可能真正提高当前算法的性能。
您的 N
数组包含大约 10 亿个对象 ( 1001 * 1001 * 1001
)。您需要设置每个元素,因此您至少有十亿次操作。为了获得下限,我们假设设置一个数组元素需要一纳秒(实际上需要更多时间)。 10亿次操作,每次需要1纳秒,意味着需要1秒才能完成。正如我所说,每次操作可能需要比 1 纳秒长一点的时间,因此我们假设需要 10 纳秒(可能有点高,但比 1 纳秒更现实),这意味着算法总共需要 10 秒。
因此,您输入的预期运行时间将在 1 秒到 10 秒之间。因此,如果您的 Python 版本需要 10 秒,那么它可能已经达到了您选择的方法所能达到的极限,并且没有任何工具可以(显着)改进该运行时间。
可以使它更快一点的一件事是使用 np.zeros
而不是 np.full
:
K = np.zeros((n+1, W+1), dtype=int)
N = np.zeros((n+1, W+1, W+1), dtype=int)
并且不要创建 M
,因为您不会使用它。
由于您已经使用了 line-profiler,我决定看一下,并得到了这个结果:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 19137.0 19137.0 0.0 K = np.full((n+1,W+1),0)
5 1 19408592.0 19408592.0 28.1 N = np.full((n+1,W+1,W+1),0)
6
7 1002 6412.0 6.4 0.0 for i in range(n+1):
8 1003002 4186311.0 4.2 6.1 for w in range(W+1):
9 1002001 4644031.0 4.6 6.7 if i==0 or w==0:
10 2001 19663.0 9.8 0.0 K[i][w] = 0
11 1000000 5474080.0 5.5 7.9 elif wt[i-1] <= w:
12 498365 9616406.0 19.3 13.9 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 52596 902030.0 17.2 1.3 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 52596 578740.0 11.0 0.8 c = N[i-1][w-wt[i-1]]
15 52596 295980.0 5.6 0.4 c[i] = i
16 52596 1239792.0 23.6 1.8 N[i][w] = c
17 else:
18 445769 5100917.0 11.4 7.4 K[i][w] = K[i-1][w]
19 445769 11677683.0 26.2 16.9 N[i][w] = N[i-1][w]
20 else:
21 501635 5801328.0 11.6 8.4 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 14.0 14.0 0.0 return N[n][W]
这表明瓶颈是 np.full
、 N[i][w] = N[i-1][w]
和 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
。 Numba 不会改进前两个,因为它们已经使用了高度优化的 NumPy 代码,对于这些代码,numba 更有可能会变慢。 Numba 可能可以改进 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
,但这可能不会被注意到。
如果 np.full
被 np.zeros
替换,配置文件会略有变化:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 747.0 747.0 0.0 K = np.zeros((n+1, W+1),dtype=int)
5 1 109592.0 109592.0 0.2 N = np.zeros((n+1, W+1, W+1),dtype=int)
6
7 1002 4230.0 4.2 0.0 for i in range(n+1):
8 1003002 4414071.0 4.4 7.0 for w in range(W+1):
9 1002001 4836807.0 4.8 7.7 if i==0 or w==0:
10 2001 22282.0 11.1 0.0 K[i][w] = 0
11 1000000 5646859.0 5.6 8.9 elif wt[i-1] <= w:
12 521222 10389581.0 19.9 16.5 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 47579 784563.0 16.5 1.2 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 47579 509056.0 10.7 0.8 c = N[i-1][w-wt[i-1]]
15 47579 362796.0 7.6 0.6 c[i] = i
16 47579 1975916.0 41.5 3.1 N[i][w] = c
17 else:
18 473643 5579823.0 11.8 8.8 K[i][w] = K[i-1][w]
19 473643 22805846.0 48.1 36.1 N[i][w] = N[i-1][w]
20 else:
21 478778 5664271.0 11.8 9.0 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 10.0 10.0 0.0 return N[n][W]
但主要瓶颈仍然是 N[i][w] = N[i-1][w]
,使用 numba 可能比使用纯 NumPy 慢。因此,使用 numba 对代码的其他部分所获得的改进可能(再次)是不明显的。
对于第一个配置文件,我使用了您的此版本的代码(第二个配置文件只是将 np.full
更改为 np.zeros
):
import numpy as np
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
import random
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
%lprun -f knapSack knapSack(W, wt, val, n)
关于python - 为什么 numba 不提高我的背包功能的速度?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59018318/