I am trying to use Numba to parallelize a Python function that takes two numpy ndarrays, alpha and beta, as arguments. They have shapes (a,m,n) and (b,m,n) respectively, and can therefore be broadcast over the trailing dimensions. The function computes the matrix dot product (Frobenius product) of the 2D slices of the arguments and, for each slice of alpha, finds the slice of beta that maximizes that product. In code:
@njit(parallel=True)
def parallel_value(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices
This runs fine without the njit decorator, but the Numba compiler complains:

No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))

The offending line is apparently values[i]=dot[index]. I don't see why this should be a problem. What is causing it, and how can I fix it? Also, is there any benefit to adding nogil=True to the arguments of @njit?
Best Answer

I managed to reproduce your problem. When running the code:
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices

a, b, m, n = 6, 5, 4, 3
parallel_value(np.random.rand(a, m, n), np.random.rand(b, m, n))
I get the error message:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(array(float64, 1d, C), int64, array(float64, 1d, C))
There are 16 candidate implementations:
- Of which 16 did not match due to:
Overload of function 'setitem': File: <numerous>: Line N/A.
With argument(s): '(array(float64, 1d, C), int64, array(float64, 1d, C))':
No match.
During: typing of setitem at <ipython-input-41-44518cf5219f> (11)
File "<ipython-input-41-44518cf5219f>", line 11:
def parallel_value(alpha,beta):
<source elided>
index=np.argmax(dot)
values[i]=dot[index]
^
According to this issue on the Numba GitHub page, the dot computation here (np.sum over a tuple of axes) may be problematic in Numba. When I rewrote the code with explicit loops, it seems to work:
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_value_numba(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in prange(alpha.shape[0]):
        dot = np.zeros(beta.shape[0])
        for j in prange(beta.shape[0]):
            for k in prange(beta.shape[1]):
                for l in prange(beta.shape[2]):
                    dot[j] += alpha[i, k, l] * beta[j, k, l]
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices

def parallel_value_nonumba(alpha, beta):
    values = np.empty(alpha.shape[0])
    indices = np.empty(alpha.shape[0])
    for i in range(alpha.shape[0]):
        dot = np.sum(alpha[i, :, :] * beta, axis=(1, 2))
        index = np.argmax(dot)
        values[i] = dot[index]
        indices[i] = index
    return values, indices
a, b, m, n = 6, 5, 4, 3
np.random.seed(42)
A = np.random.rand(a, m, n)
B = np.random.rand(b, m, n)
res_num = parallel_value_numba(A, B)
res_nonum = parallel_value_nonumba(A, B)
print(f'res_num = {res_num}')
print(f'res_nonum = {res_nonum}')
Output:

res_num = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1., 3., 1., 1., 1., 1.]))
res_nonum = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
3.43988156]), array([1., 3., 1., 1., 1., 1.]))
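The explicit loops compute the same reduction as the tuple-axis np.sum; this equivalence can be checked in plain NumPy, without Numba (a quick sketch; the array shapes below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha = rng.random((6, 4, 3))
beta = rng.random((5, 4, 3))

# Tuple-axis reduction, as in the original function (for the slice i = 0):
dot_sum = np.sum(alpha[0, :, :] * beta, axis=(1, 2))

# Explicit accumulation, as in the Numba-friendly rewrite:
dot_loop = np.zeros(beta.shape[0])
for j in range(beta.shape[0]):
    for k in range(beta.shape[1]):
        for l in range(beta.shape[2]):
            dot_loop[j] += alpha[0, k, l] * beta[j, k, l]

print(np.allclose(dot_sum, dot_loop))  # → True
```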
As far as I can tell, the explicit loops do not seem to hurt performance. Since this is Numba, I can't compare against the same compiled code without them, but my guess is that it makes little difference:

%timeit res_num = parallel_value_numba(A, B)
%timeit res_nonum = parallel_value_nonumba(A, B)
Output:

The slowest run took 1472.03 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 4.92 µs per loop
10000 loops, best of 5: 76.9 µs per loop
Finally, you can use numpy more effectively by vectorizing the code you already have. It is almost as fast as the Numba version with explicit loops, and it has no initial compilation cost. You can do it like this:

def parallel_value_np(alpha, beta):
    alpha = alpha.reshape(alpha.shape[0], 1, alpha.shape[1], alpha.shape[2])
    beta = beta.reshape(1, beta.shape[0], beta.shape[1], beta.shape[2])
    dot = np.sum(alpha * beta, axis=(2, 3))
    indices = np.argmax(dot, axis=1)
    values = dot[np.arange(len(indices)), indices]
    return values, indices

res_np = parallel_value_np(A, B)
print(f'res_np = {res_np}')

%timeit res_np = parallel_value_np(A, B)

Output:

res_np = (array([3.52775653, 2.49947515, 3.33824146, 2.9669794 , 3.78968905,
       3.43988156]), array([1, 3, 1, 1, 1, 1]))
The slowest run took 5.46 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 5: 16.1 µs per loop
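As an aside, the same all-pairs Frobenius products can also be written with np.einsum, which avoids materializing the broadcast (a, b, m, n) intermediate array; a sketch under the same shapes as above (parallel_value_einsum is an illustrative name, not from the answer):

```python
import numpy as np

def parallel_value_einsum(alpha, beta):
    # dot[i, j] = sum over (m, n) of alpha[i, m, n] * beta[j, m, n]
    dot = np.einsum('imn,jmn->ij', alpha, beta)
    indices = np.argmax(dot, axis=1)
    values = dot[np.arange(dot.shape[0]), indices]
    return values, indices

rng = np.random.default_rng(42)
A = rng.random((6, 4, 3))
B = rng.random((5, 4, 3))
values, indices = parallel_value_einsum(A, B)
```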
Regarding python - Parallelizing a maximum over an nd-array using Numba, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/68106434/