python - 计算平方欧氏距离的可能优化

标签 python c performance numpy

我在一个Python项目中每天需要做几亿次欧式距离计算。

这是我的开始:

def euclidean_dist_square(x, y):
    diff = np.array(x) - np.array(y)
    return np.dot(diff, diff)

这非常快,我已经放弃了 sqrt 计算,因为我只需要对项目进行排名(最近邻搜索)。尽管如此,它仍然是脚本的瓶颈。因此我写了一个 C 扩展,它计算距离。计算始终使用 128 维 vector 完成。

#include "euclidean.h"
#include <math.h>

double euclidean(double x[128], double y[128])
{
    double Sum;
    for(int i=0;i<128;i++)
    {
        Sum = Sum + pow((x[i]-y[i]),2.0);
    }
    return Sum;
}

扩展的完整代码在这里:https://gist.github.com/herrbuerger/bd63b73f3c5cf1cd51de

与 numpy 版本相比,现在这提供了很好的加速。

但是有没有办法进一步加快速度(这是我的第一个 C 扩展,所以我假设有)?随着每天使用此功能的次数,每一微秒实际上都会带来好处。

你们中的一些人可能会建议将它从 Python 完全移植到另一种语言,不幸的是这是一个更大的项目而不是一个选项:(

谢谢。

编辑

我已经在 CodeReview 上发布了这个问题:https://codereview.stackexchange.com/questions/52218/possible-optimizations-for-calculating-squared-euclidean-distance

我会在一个小时内删除这个问题,以防有人开始写答案。

最佳答案

据我所知,在 NumPy 中计算欧氏距离的最快方法是 the one in scikit-learn , 可以概括为

def squared_distances(X, Y):
    """Return a distance matrix for each pair of rows i, j in X, Y."""
    # http://stackoverflow.com/a/19094808/166749
    X_row_norms = np.einsum('ij,ij->i', X, X)
    Y_row_norms = np.einsum('ij,ij->i', Y, Y)
    distances = np.dot(X, Y)
    distances *= -2
    distances += X_row_norms
    distances += Y_row_norms

    np.maximum(distances, 0, distances)  # get rid of negatives; optional
    return distances

这段代码的瓶颈是矩阵乘法 (np.dot),因此请确保您的 NumPy 与良好的 BLAS 实现相关联;在多核机器上使用多线程 BLAS 和足够大的输入矩阵,它应该比你在 C 中能想到的任何东西都快。注意它依赖于二项式公式

||x - y||² = ||x||² + ||y||² - 2 x⋅y    

X_row_normsY_row_norms 可以在 k-NN 用例的调用中缓存。

(我是这段代码的合著者,我花了很多时间来优化它和 SciPy 实现;scikit-learn 以牺牲一些准确性为代价更快,但对于 k-NN 来说这应该无关紧要很多。scipy.spatial.distance 中可用的 SciPy 实现实际上是您刚刚编写的代码的优化版本,并且更准确。)

关于python - 计算平方欧氏距离的可能优化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/23983748/

相关文章:

c - LPSTREAM 将整个 Stream 内容读入 unsigned char* 数组

c - 将二维动态数组的一维传递给函数

performance - 为什么在 Go 中交换 []float64 的元素比在 Rust 中交换 Vec<f64> 的元素更快?

PHP文件上传。发布与放置?

python - 与 Arduino Edison/Galileo 的语言兼容性

python - 错误 "AttributeError: ' unicode' object has no attribute 'read' "on file upload

c - xv6源码中cga_init()中的0xa55a是什么意思?

c# - 减少纬度和经度点数量的最快方法

python - cv2 videowriter,写得很好,但在应用程序关闭时抛出错误

python - "Merging"具有共同维度的 numpy 数组