java - 欧氏距离,Scipy、纯Python、Java结果不同

标签 java python scipy knn euclidean-distance

我正在研究欧几里得距离度量的不同实现,我注意到我得到了 Scipy、纯 Python 和 Java 的不同结果。

下面是我如何使用 Scipy(= 选项 1)计算距离:

distance = scipy.spatial.distance.euclidean(sample, training_vector)

这是我在论坛中找到的 Python 实现(选项 2):

distance = math.sqrt(sum([(a - b) ** 2 for a, b in zip(training_vector, sample)]))

最后,这是我在 Java 中的实现(选项 3):

public double distance(int[] a, int[] b) {
    assert a.length == b.length;
    double squaredDistance = 0.0;
    for(int i=0; i<a.length; i++){
        squaredDistance += Math.pow(a[i] - b[i], 2.0);
    }
    return Math.sqrt(squaredDistance);
}

sampletraining_vector 都是长度为 784 的一维数组,取自 MNIST 数据集。我用相同的 sampletraining_vector 尝试了所有三种方法。问题是三种不同的方法导致了三个明显不同的距离(即选项 1 大约在 1936 年左右,选项 2 大约在 1914 年左右,选项 3 大约在 1382 年左右)。有趣的是,当我在选项 1 和 2 中对 sampletraining_vector 使用相同的参数顺序时(即将参数翻转到选项 1),我得到了相同的结果两个选项。但是距离度量应该是对称的,对吧...?

同样有趣的是:我将这些指标用于 MNIST 数据集的 k-NN 分类器。对于 100 个测试样本和 2700 个训练样本,我的 Java 实现产生了大约 94% 的准确率。然而,使用选项 1 的 Python 实现只能产生大约 75% 的准确率......

关于为什么我会得到这些不同的结果,您有什么想法吗?如果您有兴趣,我可以在线发布两个阵列的 CSV,并在此处发布链接。

我正在使用 Java 8、Python 2.7 和 Scipy 1.0.0。

编辑: 将选项 2 更改为

distance = math.sqrt(sum([(float(a) - float(b)) ** 2 for a, b in zip(training_vector, sample)]))

这有以下效果:

  • 它摆脱了 ubyte 溢出警告(我之前一定错过了这个警告......)
  • 更改选项 1 和 2 的参数顺序不再产生影响。
  • 选项 2(纯 Python)和选项 3(Java)的结果现在相等

因此,这只会留下以下问题:为什么使用 SciPy 时结果不同(即错误?)?

最佳答案

好的,我找到了解决方案:我使用带有 dtype=np.uint8 的 pandas 导入了训练和测试数据集。因此,sampletraining_vector 都是 uint8 类型的 numpy 数组。我将数据类型更改为 np.float32,现在我的所有三个选项都给出了相同的结果。我还尝试了 np.uint32,它也能正常工作。

我不太清楚为什么,但显然,在使用 uint8 时,SciPy 没有给出“预期”的结果。也许 SciPy 中有一些内部溢出?不太确定,但至少它现在有效。感谢所有提供帮助的人!

关于java - 欧氏距离,Scipy、纯Python、Java结果不同,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49029995/

相关文章:

Java浅拷贝子类实例的父类(super class)实例

python - scrapy、splash、lua、按钮点击

python - 使用最小/最有效路径(Scipy distance.cdist)3D 浏览坐标?

java - 从 Spring 渲染 View 时出现 StackOverFlowError

java - 计算相似性度量

java - 我将GWT依赖项添加到vaadin项目时无法启动tomcat

python - tensorflow MNIST : terminate called after throwing an instance of 'std::bad_alloc'

python - 如何在 Python3 中将数据框转换为字典

python - 使用 RDKit 计算 sdf 文件和结构 SMILE 之间的 Tanimoto 相似度?

python - 确定 scipy.optimize 的合理初始猜测的函数?