python - 在 numba.jit 装饰器中使用并行选项使函数给出错误的结果

标签 python python-3.x numpy jit numba

给定矩形的两个对角 (x1, y1)(x2, y2) 以及两个半径 r1 r2,找到位于由半径 r1r2 定义的圆之间的点与矩形中点总数的比率。

简单的 NumPy 方法:

def func_1(x1,y1,x2,y2,r1,r2,n):
     x11,y11 = np.meshgrid(np.linspace(x1,x2,n),np.linspace(y1,y2,n))
     z1 = np.sqrt(x11**2+y11**2)
     a = np.where((z1>(r1)) & (z1<(r2)))
     fill_factor = len(a[0])/(n*n)
     return fill_factor

接下来我尝试使用 numba 的 jit 装饰器优化这个函数。当我使用时:

nopython = True

函数速度更快,输出正确。但是当我还添加:

parallel = True

函数速度更快但给出了错误的结果。 我知道这与我的 z 矩阵有关,因为它没有正确更新。

@jit(nopython=True,parallel=True)
def func_2(x1,y1,x2,y2,r1,r2,n):
    x_ = np.linspace(x1,x2,n)
    y_ = np.linspace(y1,y2,n)
    z1 = np.zeros((n,n))
    for i in range(n):
        for j in range(n):
            z1[i][j] = np.sqrt((x_[i]*x_[i]+y_[j]*y_[j]))
    a = np.where((z1>(r1)) & (z1<(r2)))
    fill_factor = len(a[0])/(n*n)
    return fill_factor

测试值:

x1 = 1.0
x2 = -1.0
y1 = 1.0
y2 = -1.0
r1 = 0.5
r2 = 0.75
n = 25000

附加信息:Python 版本:3.6.1,Numba 版本:0.34.0+5.g1762237,NumPy 版本:1.13.1

最佳答案

parallel=True 的问题在于它是一个黑盒子。 Numba 甚至不保证它实际上会并行化任何东西。它使用启发式方法来确定它是否可并行化以及哪些可以并行完成。这些可能会失败,在您的示例中它们确实会失败,就像在 my experiments with parallel and numba 中一样.这使得 parallel 变得不可信,我建议反对使用它!

在较新的版本 (0.34) 中,添加了 prange,您可能会更幸运。它不能在这种情况下应用,因为 prange 的工作方式类似于 range 并且不同于 np.linspace...

请注意:您可以完全避免在函数中构建 z 和执行 np.where,您可以明确地进行检查:

import numpy as np
import numba as nb

@nb.njit   # equivalent to "jit(nopython=True)".
def func_2(x1,y1,x2,y2,r1,r2,n):
    x_ = np.linspace(x1,x2,n)
    y_ = np.linspace(y1,y2,n)
    cnts = 0
    for i in range(n):
        for j in range(n):
            z = np.sqrt(x_[i] * x_[i] + y_[j] * y_[j])
            if r1 < z < r2:
                cnts += 1
    fill_factor = cnts/(n*n)
    return fill_factor

与您的函数相比,这也应该提供一些加速,甚至可能比使用 parallel=True(如果它能正常工作)。

关于python - 在 numba.jit 装饰器中使用并行选项使函数给出错误的结果,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46009368/

相关文章:

python - 在 python 中处理可能是整数或字符的方法的参数的正确方法是什么?

python - 如何检查单词是否以字母表范围内开头

datetime - 在 Pandas 中绘制 TimeDeltas

windows - OpenCV - python 3.x 和 windows - Numpy 的版本是什么?

python - 如何通过在 Python 中调用相同的字典值每次都返回新的对象?

python-3.x - 对 numpy 的同情导致 AttributeError : 'Symbol' object has no attribute 'cos'

python - 在输出文件 [Python] 中写入超过 80 个字符的行

python - 在 pybind11 中混合 Python 和 C++ 源文件

python - 如何将kubernetes实现的 secret 环境变量获取到python中?

python - 如何内置对 python 调用的跟踪功能?