Java 与 Python 特定的代码片段性能改进

标签 java python algorithm performance

我对 Java 和 Python 中一段特定代码的性能有疑问。

算法:
我正在生成随机 N 维点,然后对彼此之间一定距离阈值以下的所有点进行一些处理。处理本身在这里并不重要,因为它不会影响总执行时间。在这两种情况下生成点也只需要几分之一秒,所以我只对进行比较的部分感兴趣。

执行时间:
对于 3000 个点和二维的固定输入,Java 在 2 到 4 秒 内完成,而 Python 需要 15 到 200 秒

我对 Python 的执行时间有点怀疑。这段 Python 代码中有什么我遗漏的吗?是否有任何算法改进建议(例如预分配/重用内存、降低 Big-Oh 复杂性的方法等)?


Java

double random_points[][] = new double[number_of_points][dimensions];
for(i = 0; i < number_of_points; i++)
  for(d = 0; d < dimensions; d++)
    random_points[i][d] = Math.random();

double p1[], p2[];
for(i = 0; i < number_of_points; i++)
{
  p1 = random_points[i];
  for(j = i + 1; j < number_of_points; j++)
  {
    p2 = random_points[j];

    double sum_of_squares = 0;
    for(d = 0; d < DIM_; d++)
      sum_of_squares += (p2[d] - p1[d]) * (p2[d] - p1[d]);

    double distance = Math.sqrt(ss);
    if(distance > SOME_THRESHOLD) continue;

    //...else do something with p1 and p2

  }
}

python 3.2

random_points = [[random.random() for _d in range(0,dimensions)] for _n in range(0,number_of_points)]

for i, p1 in enumerate(random_points):
  for j, p2 in enumerate(random_points[i+1:]):
    distance = math.sqrt(sum([(p1[d]-p2[d])**2 for d in range(0,dimensions)]))
    if distance > SOME_THRESHOLD: continue

    #...else do something with p1 and p2

最佳答案

您可能要考虑使用 numpy .

我刚刚尝试了以下方法:

import numpy
from scipy.spatial.distance import pdist
D=2
N=3000
p=numpy.random.uniform(size=(N,D))
dist=pdist(p, 'euclidean')

最后一行计算距离矩阵(这相当于在您的代码中为每对点计算 distance)。在我的电脑上大约需要 0.07 秒。

此方法的主要缺点是距离矩阵需要 O(n^2) 内存。如果这是一个问题,以下可能是更好的选择:

for i in xrange(1, N):
  v = p[:N-i] - p[i:]
  dist = numpy.sqrt(numpy.sum(numpy.square(v), axis=1))
  for j in numpy.nonzero(dist > 1.4)[0]:
    print j, i+j

对于 N=3000,这在我的电脑上需要大约 0.33 秒。

关于Java 与 Python 特定的代码片段性能改进,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/5705378/

相关文章:

Java - 使用 Scanner 的编译错误

python - 特定行出现了多少次?

algorithm - 以 3 的有符号幂计数

java - 避免在目标/生成源的循环中构建项目

java - Eclipse 中 win32com.dll 错误

java - Hibernate 持久错误,id 字段为空

algorithm - 检查 10 亿个手机号码是否重复

python - 如何使用 selenium python 读取表数据?

python - 如何使用 selenium 和 python 从下拉列表中获取值列表

算法校正数字线上数据点之间的噪声距离测量