python - 使用 numba jit,Python 的段间距离

标签 python numpy jit numba

在过去的一周里,我一直在询问有关此堆栈的相关问题,以尝试找出我不了解的关于在 Python 中将 @jit 装饰器与 Numba 结合使用的问题。但是,我碰壁了,所以我只写下整个问题。

当前的问题是计算成对大量 段之间的最小距离。这些段由它们的 3D 起点和终点表示。在数学上,每个段都被参数化为 [AB] = A + (B-A)*s,其中 s 在 [0,1] 中,A 和 B 是段的起点和终点。对于两个这样的线段,可以计算出最小距离并给出公式here .

我已经在另一个thread上暴露了这个问题,并且给出的答案涉及通过向量化问题来替换我的代码的双循环,但是这会遇到大量段的内存问题。因此,我决定坚持使用循环,并改用 numba 的 jit

由于问题的解需要很多点积,而numpy的点积是not supported by numba , 我首先实现了自己的 3D 点积。

import numpy as np
from numba import jit, autojit, double, float64, float32, void, int32

def my_dot(a,b):
    res = a[0]*b[0]+a[1]*b[1]+a[2]*b[2]
    return res

dot_jit = jit(double(double[:], double[:]))(my_dot)    #I know, it's not of much use here.

计算N段中所有对的最小距离的函数以Nx6数组(6个坐标)作为输入

def compute_stuff(array_to_compute):
    N = len(array_to_compute)
    con_mat = np.zeros((N,N))
    for i in range(N):
        for j in range(i+1,N):

            p0 = array_to_compute[i,0:3]
            p1 = array_to_compute[i,3:6]
            q0 = array_to_compute[j,0:3]
            q1 = array_to_compute[j,3:6]

            s = ( dot_jit((p1-p0),(q1-q0))*dot_jit((q1-q0),(p0-q0)) - dot_jit((q1-q0),(q1-q0))*dot_jit((p1-p0),(p0-q0)))/( dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(q1-q0)) - dot_jit((p1-p0),(q1-q0))**2 )
            t = ( dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(p0-q0)) -dot_jit((p1-p0),(q1-q0))*dot_jit((p1-p0),(p0-q0)))/( dot_jit((p1-p0),(p1-p0))*dot_jit((q1-q0),(q1-q0)) - dot_jit((p1-p0),(q1-q0))**2 )

            con_mat[i,j] = np.sum( (p0+(p1-p0)*s-(q0+(q1-q0)*t))**2 ) 

return con_mat

fast_compute_stuff = jit(double[:,:](double[:,:]))(compute_stuff)

因此,compute_stuff(arg) 将二维 np.array (double[:,:]) 作为参数,执行一系列 numba 支持的 (?) 操作,并返回另一个二维 np.array (double[:, :]).然而,

v = np.random.random( (100,6) )
%timeit compute_stuff(v)
%timeit fast_compute_stuff(v)

每个循环我得到 134 和 123 毫秒。你能解释一下为什么我不能加速我的功能吗?任何反馈将不胜感激。

最佳答案

这是我的代码版本,速度明显更快:

@jit(nopython=True)
def dot(a,b):
    res = a[0]*b[0]+a[1]*b[1]+a[2]*b[2]
    return res

@jit
def compute_stuff2(array_to_compute):
    N = array_to_compute.shape[0]
    con_mat = np.zeros((N,N))

    p0 = np.zeros(3)
    p1 = np.zeros(3)
    q0 = np.zeros(3)
    q1 = np.zeros(3)

    p0m1 = np.zeros(3)
    p1m0 = np.zeros(3)
    q0m1 = np.zeros(3)
    q1m0 = np.zeros(3)
    p0mq0 = np.zeros(3)

    for i in range(N):
        for j in range(i+1,N):

            for k in xrange(3):
                p0[k] = array_to_compute[i,k]
                p1[k] = array_to_compute[i,k+3]
                q0[k] = array_to_compute[j,k]
                q1[k] = array_to_compute[j,k+3]

                p0m1[k] = p0[k] - p1[k]
                p1m0[k] = -p0m1[k]

                q0m1[k] = q0[k] - q1[k]
                q1m0[k] = -q0m1[k]

                p0mq0[k] = p0[k] - q0[k]

            s = ( dot(p1m0, q1m0)*dot(q1m0, p0mq0) - dot(q1m0, q1m0)*dot(p1m0, p0mq0))/( dot(p1m0, p1m0)*dot(q1m0, q1m0) - dot(p1m0, q1m0)**2 )
            t = ( dot(p1m0, p1m0)*dot(q1m0, p0mq0) - dot(p1m0, q1m0)*dot(p1m0, p0mq0))/( dot(p1m0, p1m0)*dot(q1m0, q1m0) - dot(p1m0, q1m0)**2 )


            for k in xrange(3):
                con_mat[i,j] += (p0[k]+(p1[k]-p0[k])*s-(q0[k]+(q1[k]-q0[k])*t))**2 

    return con_mat

时间安排:

In [38]:

v = np.random.random( (100,6) )
%timeit compute_stuff(v)
%timeit fast_compute_stuff(v)
%timeit compute_stuff2(v)

np.allclose(compute_stuff2(v), compute_stuff(v))

#10 loops, best of 3: 107 ms per loop
#10 loops, best of 3: 108 ms per loop
#10000 loops, best of 3: 114 µs per loop
#True

我加快速度的基本策略是:

  • 摆脱所有数组表达式并显式展开矢量化(尤其是因为您的数组非常小)
  • 在调用 dot 方法时摆脱冗余计算(减去两个向量)。
  • 将所有数组创建移至嵌套 for 循环之外,以便 numba 可以执行一些操作 loop lifting .这也避免了创建许多昂贵的小阵列。最好分配一次并重用内存。

另一件需要注意的事情是,对于最近版本的 numba,过去被称为 autojit(即让 numba 对输入进行类型推断)并且现在只是没有类型提示的装饰器通常是就像根据我的经验指定您的输入类型一样快。

此外,在 OS X 上使用带有 Python 2.7.9 的 Anaconda python 发行版使用 numba 0.17.0 运行计时。

关于python - 使用 numba jit,Python 的段间距离,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/28970883/

相关文章:

python - 接近 1 的纹理坐标表现异常

python - 如何使用 pandas 行形成新列

java - 使用 JIT 编译器的 Collections.emptyList 和空 ArrayList 的性能

python - Python/Django 中的 URL 映射

python - 尝试解析嵌入在网页中的列表中的项目

python - 在 python 中替换 if elif block

python - 如何使用季度和年份的日期时间索引过滤 Pandas 系列

python - 如何读取Python中的polyfit函数?

javascript - 编写高性能 Javascript 代码而不去优化

c++ - "Live"代码和使用 C++ 和 LLVM JIT 的快速原型(prototype)制作?