我对使用 numba
比较陌生,我想用它来使我的数组计算尽可能高效。该函数是 numba 文档中几个概念的组合。
我正在使用 Scipy 库中的单一函数
scipy.special.eval_laguerre(n, x, out=None) = <ufunc 'eval_laguerre'>
计算点 n 处的拉盖尔多项式 L_n(x)。
问题 1:Numba 文档清楚地说明了如何使用装饰器 @vectorize
优化用户编写的 ufunc。 http://numba.pydata.org/numba-doc/0.12/ufuncs.html#generalized-ufuncs
是否有标准程序可以使用 python 库提供的 ufunc 来执行此操作?
问题 2:我想针对数组中包含 n 个值的数组,计算矩阵的每个条目的 L_n(x) 。然后我必须使用表达式对这些值求和:
result = np.sum( [eval_laguerre(n, matrix) for n in array], axis=0)
我用过的地方import numpy as np
.
如果我要使用广播,我会评估:
result = np.sum( eval_laguerre( array[:, None, None], matrix ), axis=0)
其中 axis=0
表示要求和的维度。
我想使用“@jit”来编译此部分,但我不确定'numpy.sum()
的程序是什么。 。目前,上面的表达式带有 @jit
表达式给出语法错误。
result = np.sum( eval_laguerre( array[:, None, None], matrix ), axis=0)
^
SyntaxError: invalid syntax
正确的使用方法是什么@jit
和np.sum()
?
编辑:回应@hpaulj:
我的想法是numba
可以优化 for 循环,即
for n in array:
eval_laguerre(n, matrix)
这可能吗?如果没有numba
,然后用什么? Pythran
?
最佳答案
让我们更具体一点:
一个示例数组,我将其用于 n
和 x
(您可以选择更实际的值):
In [782]: A=np.arange(12.).reshape(3,4)
该版本充分利用了ufunc
广播能力
In [790]: special.eval_laguerre(A[:,None,:],A[None,:,:]).shape
Out[790]: (3, 3, 4)
或者求和:
In [784]: np.sum(special.eval_laguerre(A[:,None,:],A[None,:,:]),0)
Out[784]:
array([[ 3.00000000e+00, -1.56922399e-01, -4.86843034e-01,
7.27719156e-02],
[ 1.37460317e+00, -4.47492284e+00, 5.77714286e+00,
-9.71780654e-01],
[ -1.76222222e+01, 7.00178571e+00, 5.55396825e+01,
-1.32810866e+02]])
相当于 sum
中的列表压缩:
In [785]: np.sum([special.eval_laguerre(n,A) for n in A],0)
Out[785]:
array([[ 3.00000000e+00, -1.56922399e-01, -4.86843034e-01,
7.27719156e-02],
[ 1.37460317e+00, -4.47492284e+00, 5.77714286e+00,
-9.71780654e-01],
[ -1.76222222e+01, 7.00178571e+00, 5.55396825e+01,
-1.32810866e+02]])
或者显式循环:
In [786]: x=np.zeros_like(A)
In [787]: for n in A:
x += special.eval_laguerre(n, A)
最后一个版本有机会使用 numba
进行编译。
在简单的时间测试中,ufunc 广播速度更快:
In [791]: timeit np.sum([special.eval_laguerre(n,A) for n in A],axis=0)
10000 loops, best of 3: 84.8 µs per loop
In [792]: timeit np.sum(special.eval_laguerre(A[:,None,:],A[None,:,:]),0)
10000 loops, best of 3: 43.9 µs per loop
我的猜测是,numba 版本将改进理解版本和显式循环,但可能不会比广播版本更快。
关于python - Numba:矢量化标准 SciPy ufunc 和 numpy.sum() 语法错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/30800863/