java - for each loops as streams in Java8 - k-means

Tags: java java-8 multiprocessing java-stream

I have implemented the k-means algorithm and would like to speed up my process by using Java 8 streams and multi-core processing.

I have this code in Java 7:

//Step 2: For each point p:
//find nearest clusters c
//assign the point p to the closest cluster c
for (Point p : points) {
   double minDst = Double.MAX_VALUE;
   int minClusterNr = 1;
   for (Cluster c : clusters) {
      double tmpDst = determineDistance(p, c);
      if (tmpDst < minDst) {
         minDst = tmpDst;
         minClusterNr = c.clusterNumber;
      }
   }
   clusters.get(minClusterNr - 1).points.add(p);
}
//Step 3: For each cluster c
//find the central point of all points p in c
//set c to the center point
ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
for (Cluster c : clusters) {
   double newX = 0;
   double newY = 0;
   for (Point p : c.points) {
      newX += p.x;
      newY += p.y;
   }
   newX = newX / c.points.size();
   newY = newY / c.points.size();
   newClusters.add(new Cluster(newX, newY, c.clusterNumber));
}

I would like to use Java 8 with parallel streams to speed up the process. I gave it a try and came up with this solution:

points.stream().forEach(p -> {
   minDst = Double.MAX_VALUE; //<- THESE ARE GLOBAL VARIABLES NOW
   minClusterNr = 1;          //<- THESE ARE GLOBAL VARIABLES NOW
   clusters.stream().forEach(c -> {
      double tmpDst = determineDistance(p, c);
      if (tmpDst < minDst) {
         minDst = tmpDst;
         minClusterNr = c.clusterNumber;
      }
   });
   clusters.get(minClusterNr - 1).points.add(p);
});
ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
clusters.stream().forEach(c -> {
   newX = 0;  //<- THESE ARE GLOBAL VARIABLES NOW
   newY = 0;  //<- THESE ARE GLOBAL VARIABLES NOW
   c.points.stream().forEach(p -> {
      newX += p.x;
      newY += p.y;
   });
   newX = newX / c.points.size();
   newY = newY / c.points.size();
   newClusters.add(new Cluster(newX, newY, c.clusterNumber));
});

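This solution with streams is considerably faster than the one without streams. I am wondering whether it already uses multi-core processing. Otherwise, why would it suddenly be about twice as fast?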

Without streams: elapsed time 202 msec. With streams: elapsed time 116 msec.

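Also, would it be useful to use parallelStream in any of these methods to speed them up? When I change the streams to stream().parallel().forEach(CODE), all it does now is cause ArrayIndexOutOfBoundsException and NullPointerException.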

---- EDIT (source code added as requested, so that you can try it yourself) ----

--- Clustering.java ---

package algo;

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.util.ArrayList;
import java.util.Random;
import java.util.function.BiFunction;

import graphics.SimpleColorFun;

/**
 * An implementation of the k-means-algorithm.
 * <p>
 * Step 0: Determine the max size of the canvas
 * <p>
 * Step 1: Place clusters at random
 * <p>
 * Step 2: For each point p:<br>
 * find nearest clusters c<br>
 * assign the point p to the closest cluster c
 * <p>
 * Step 3: For each cluster c<br>
 * find the central point of all points p in c<br>
 * set c to the center point
 * <p>
 * Stop when none of the cluster x,y values change
 * @author makt
 *
 */
public class Clustering {

   private BiFunction<Integer, Integer, Color> colorFun = new SimpleColorFun();
   //   private BiFunction<Integer, Integer, Color> colorFun = new GrayScaleColorFun();

   public Random rngGenerator = new Random();

   public double max_x;
   public double max_y;
   public double max_xy;

   //---------------------------------
   //TODO: IS IT GOOD TO HAVE THOSE VALUES UP HERE?
   double minDst = Double.MAX_VALUE;
   int minClusterNr = 1;

   double newX = 0;
   double newY = 0;
   //----------------------------------

   public boolean workWithStreams = false;

   public ArrayList<ArrayList<Cluster>> allGeneratedClusterLists = new ArrayList<ArrayList<Cluster>>();
   public ArrayList<BufferedImage> allGeneratedImages = new ArrayList<BufferedImage>();

   public Clustering(int seed) {
      rngGenerator.setSeed(seed);
   }

   public Clustering(Random rng) {
      rngGenerator = rng;
   }

   public void setup(int centroidCount, ArrayList<Point> points, int maxIterations) {

      //Step 0: Determine the max size of the canvas
      determineSize(points);

      ArrayList<Cluster> clusters = new ArrayList<Cluster>();
      //Step 1: Place clusters at random
      for (int i = 0; i < centroidCount; i++) {
         clusters.add(new Cluster(rngGenerator.nextInt((int) max_x), rngGenerator.nextInt((int) max_y), i + 1));
      }

      int iterations = 0;

      if (workWithStreams) {
         allGeneratedClusterLists.add(doClusteringWithStreams(points, clusters));
      } else {
         allGeneratedClusterLists.add(doClustering(points, clusters));
      }

      iterations += 1;

      //do until maxIterations is reached or until none of the cluster x and y values change anymore
      while (iterations < maxIterations) {
         //Step 2: happens inside doClustering
         if (workWithStreams) {
            allGeneratedClusterLists.add(doClusteringWithStreams(points, allGeneratedClusterLists.get(iterations - 1)));
         } else {
            allGeneratedClusterLists.add(doClustering(points, allGeneratedClusterLists.get(iterations - 1)));
         }

         if (!didPointsChangeClusters(allGeneratedClusterLists.get(iterations - 1), allGeneratedClusterLists.get(iterations))) {
            break;
         }

         iterations += 1;
      }

      System.out.println("Finished with " + iterations + " out of " + maxIterations + " max iterations");
   }

   /**
    * checks if the cluster x and y values changed compared to the previous x and y values
    * @param previousCluster
    * @param currentCluster
    * @return true if any cluster x or y values changed, false if all of them are the same
    */
   private boolean didPointsChangeClusters(ArrayList<Cluster> previousCluster, ArrayList<Cluster> currentCluster) {
      for (int i = 0; i < previousCluster.size(); i++) {
         if (previousCluster.get(i).x != currentCluster.get(i).x || previousCluster.get(i).y != currentCluster.get(i).y) {
            return true;
         }
      }
      return false;
   }

   /**
    * 
    * @param points - all given points
    * @param clusters - its point list gets filled in this method
    * @return a new Clusters Array which has an <b> empty </b> point list.
    */
   private ArrayList<Cluster> doClustering(ArrayList<Point> points, ArrayList<Cluster> clusters) {
      //Step 2: For each point p:
      //find nearest clusters c
      //assign the point p to the closest cluster c

      for (Point p : points) {
         double minDst = Double.MAX_VALUE;
         int minClusterNr = 1;
         for (Cluster c : clusters) {
            double tmpDst = determineDistance(p, c);
            if (tmpDst < minDst) {
               minDst = tmpDst;
               minClusterNr = c.clusterNumber;
            }
         }
         clusters.get(minClusterNr - 1).points.add(p);
      }

      //Step 3: For each cluster c
      //find the central point of all points p in c
      //set c to the center point
      ArrayList<Cluster> newClusters = new ArrayList<Cluster>();
      for (Cluster c : clusters) {
         double newX = 0;
         double newY = 0;
         for (Point p : c.points) {
            newX += p.x;
            newY += p.y;
         }
         newX = newX / c.points.size();
         newY = newY / c.points.size();
         newClusters.add(new Cluster(newX, newY, c.clusterNumber));
      }

      allGeneratedImages.add(createImage(clusters));

      return newClusters;
   }

   /**
    * Does the same as doClustering but about twice as fast!<br>
    * Uses Java8 streams to achieve this
    * @param points
    * @param clusters
    * @return
    */
   private ArrayList<Cluster> doClusteringWithStreams(ArrayList<Point> points, ArrayList<Cluster> clusters) {
      points.stream().forEach(p -> {
         minDst = Double.MAX_VALUE;
         minClusterNr = 1;
         clusters.stream().forEach(c -> {
            double tmpDst = determineDistance(p, c);
            if (tmpDst < minDst) {
               minDst = tmpDst;
               minClusterNr = c.clusterNumber;
            }
         });
         clusters.get(minClusterNr - 1).points.add(p);
      });

      ArrayList<Cluster> newClusters = new ArrayList<Cluster>();

      clusters.stream().forEach(c -> {
         newX = 0;
         newY = 0;
         c.points.stream().forEach(p -> {
            newX += p.x;
            newY += p.y;
         });
         newX = newX / c.points.size();
         newY = newY / c.points.size();
         newClusters.add(new Cluster(newX, newY, c.clusterNumber));
      });

      allGeneratedImages.add(createImage(clusters));

      return newClusters;
   }

   //draw all centers from clusters
   //draw all points
   //color points according to cluster value
   private BufferedImage createImage(ArrayList<Cluster> clusters) {
      //add 10% of the max size left and right to the image bounds
      //BufferedImage bi = new BufferedImage((int) (max_xy * 1.05), (int) (max_xy * 1.05), BufferedImage.TYPE_BYTE_INDEXED);
      BufferedImage bi = new BufferedImage((int) (max_xy * 1.05), (int) (max_xy * 1.05), BufferedImage.TYPE_INT_ARGB); // support 32-bit RGBA values
      Graphics2D g2d = bi.createGraphics();

      int numClusters = clusters.size();
      for (Cluster c : clusters) {
         //color points according to cluster value
         Color col = colorFun.apply(c.clusterNumber, numClusters);
         //draw all points
         g2d.setColor(col);
         for (Point p : c.points) {
            g2d.fillRect((int) p.x, (int) p.y, (int) (max_xy * 0.02), (int) (max_xy * 0.02));
         }
         //draw all centers from clusters
         g2d.setColor(new Color(160, 80, 80, 200)); // use RGBA: transparency=200
         g2d.fillOval((int) c.x, (int) c.y, (int) (max_xy * 0.03), (int) (max_xy * 0.03));
      }

      return bi;
   }

   /**
    * Calculates the euclidean distance without square root
    * @param p
    * @param c
    * @return
    */
   private double determineDistance(Point p, Cluster c) {
      //math.sqrt not needed because the relative distance does not change by applying the square root
      //        return Math.sqrt(Math.pow((p.x - c.x), 2)+Math.pow((p.y - c.y),2));

      return Math.pow((p.x - c.x), 2) + Math.pow((p.y - c.y), 2);
   }

   //TODO: What if coordinates can also be negative?
   private void determineSize(ArrayList<Point> points) {
      for (Point p : points) {
         if (p.x > max_x) {
            max_x = p.x;
         }
         if (p.y > max_y) {
            max_y = p.y;
         }
      }
      if (max_x > max_y) {
         max_xy = max_x;
      } else {
         max_xy = max_y;
      }
   }

}

--- Point.java ---

package algo;

public class Point {

    public double x;
    public double y;

    public Point(int x, int y) {
        this.x = x;
        this.y = y;
    }

    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }


}

--- Cluster.java ---

package algo;

import java.util.ArrayList;

public class Cluster {

    public double x;
    public double y;

    public int clusterNumber;

    public ArrayList<Point> points = new ArrayList<Point>();

    public Cluster(double x, double y, int clusterNumber) {
        this.x = x;
        this.y = y;
        this.clusterNumber = clusterNumber;
    }

}

--- SimpleColorFun.java ---

package graphics;

import java.awt.Color;
import java.util.function.BiFunction;

/**
 * Simple function for selecting a color for a specific cluster identified by an integer ID.
 * 
 * @author makl, hese
 */
public class SimpleColorFun implements BiFunction<Integer, Integer, Color> {

   /**
    * Selects a color value.
    * @param n current index
    * @param numCol number of color-values possible
    */
   @Override
   public Color apply(Integer n, Integer numCol) {
      Color col = Color.BLACK;
      //color points according to cluster value
      switch (n) {
         case 1:
            col = Color.RED;
            break;
         case 2:
            col = Color.GREEN;
            break;
         case 3:
            col = Color.BLUE;
            break;
         case 4:
            col = Color.ORANGE;
            break;
         case 5:
            col = Color.MAGENTA;
            break;
         case 6:
            col = Color.YELLOW;
            break;
         case 7:
            col = Color.CYAN;
            break;
         case 8:
            col = Color.PINK;
            break;
         case 9:
            col = Color.LIGHT_GRAY;
            break;
         default:
            break;
      }
      return col;
   }

}

--- Main.java --- (replace Stopwatch with some other time-logging mechanism; I got this one from our work environment)

package main;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import javax.imageio.ImageIO;

import algo.Clustering;
import algo.Point;
import eu.lbase.common.util.Stopwatch;
// import persistence.DataHandler;

public class Main {

   private static final String OUTPUT_DIR = (new File("./output/withoutStream")).getAbsolutePath() + File.separator;
   private static final String OUTPUT_DIR_2 = (new File("./output/withStream")).getAbsolutePath() + File.separator;

   public static void main(String[] args) {
      Random rng = new Random();
      int numPoints = 300;
      int seed = 2;

      ArrayList<Point> points = new ArrayList<Point>();
      rng.setSeed(rng.nextInt());
      for (int i = 0; i < numPoints; i++) {
         points.add(new Point(rng.nextInt(1000), rng.nextInt(1000)));
      }

      Stopwatch stw = Stopwatch.create(TimeUnit.MILLISECONDS);
      {
         // Stopwatch start
         System.out.println("--- Started without streams ---");
         stw.start();

         Clustering algo = new Clustering(seed);
         algo.setup(8, points, 25);

         // Stopwatch stop
         stw.stop();
         System.out.println("--- Finished without streams ---");
         System.out.printf("Elapsed time: %d msec%n%n", stw.getElapsed());

         System.out.printf("Writing images to '%s' ...%n", OUTPUT_DIR);

         deleteOldFiles(new File(OUTPUT_DIR));
         makeImages(OUTPUT_DIR, algo);

         System.out.println("Finished writing.\n");
      }

      {
         System.out.println("--- Started with streams ---");
         stw.start();

         Clustering algo = new Clustering(seed);
         algo.workWithStreams = true;
         algo.setup(8, points, 25);

         // Stopwatch stop
         stw.stop();
         System.out.println("--- Finished with streams ---");
         System.out.printf("Elapsed time: %d msec%n%n", stw.getElapsed());

         System.out.printf("Writing images to '%s' ...%n", OUTPUT_DIR_2);

         deleteOldFiles(new File(OUTPUT_DIR_2));
         makeImages(OUTPUT_DIR_2, algo);

         System.out.println("Finished writing.\n");
      }
   }

   /**
    * creates one image for each iteration in the given directory
    * @param algo
    */
   private static void makeImages(String dir, Clustering algo) {
      int i = 1;
      for (BufferedImage img : algo.allGeneratedImages) {
         try {
            String filename = String.format("%03d.png", i);
            ImageIO.write(img, "png", new File(dir + filename));
         } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
         }
         i++;
      }
   }

   /**
    * deletes old files from the target directory<br>
    * Does <b>not</b> delete directories!
    * @param file - file or directory to delete files from
    * @return true if the given path was a regular file and was deleted, false otherwise
    */
   private static boolean deleteOldFiles(File file) {
      File[] allContents = file.listFiles();
      if (allContents != null) {
         for (File f : allContents) {
            deleteOldFiles(f);
         }
      }
      if (!file.isDirectory()) {
         return file.delete();
      }
      return false;
   }

}

Best answer

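When you want to use streams efficiently, you should stop using forEach to write what is basically the same code as a loop, and instead learn about the aggregate operations. See also the comprehensive package documentation.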
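
To illustrate the difference, here is a minimal sketch, assuming points and centres are stored as plain double[] {x, y} arrays rather than the question's Point and Cluster classes. The class and method names (AggregateOpsSketch, nearest, mean) are made up for this example. It uses min(...) and average() in place of a forEach that updates shared fields:

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Hypothetical demo class; the names are illustrative and not from the question.
class AggregateOpsSketch {

    // Squared Euclidean distance between two 2D points stored as {x, y}.
    static double squaredDistance(double[] a, double[] b) {
        double dx = a[0] - b[0];
        double dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }

    // Nearest centre via min(...) instead of a forEach that updates shared fields.
    static double[] nearest(double[] p, List<double[]> centres) {
        return centres.stream()
                .min(Comparator.comparingDouble(c -> squaredDistance(p, c)))
                .orElseThrow(() -> new IllegalArgumentException("no centres"));
    }

    // Mean of one coordinate via average() instead of accumulating into a field.
    static double mean(List<double[]> pts, int axis) {
        return pts.stream().mapToDouble(p -> p[axis]).average().orElse(Double.NaN);
    }

    public static void main(String[] args) {
        List<double[]> centres = Arrays.asList(new double[] {0, 0}, new double[] {10, 10});
        List<double[]> pts = Arrays.asList(new double[] {1, 2}, new double[] {3, 6});
        System.out.println(Arrays.toString(nearest(new double[] {8, 9}, centres)));
        System.out.println(mean(pts, 0) + ", " + mean(pts, 1));
    }
}
```

Because neither method writes to a variable outside the lambda, both give the same result whether the stream is sequential or parallel.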

A thread-safe solution could look like this:

// needs: import java.util.Comparator; import java.util.List; import java.util.stream.Collectors;
points.stream().forEach(p -> {
    Cluster min = clusters.stream()
        .min(Comparator.comparingDouble(c -> determineDistance(p, c))).get();
    // your original code used the clusterNumber to look up the Cluster in
    // the list; I don't know whether this is really necessary
    min = clusters.get(min.clusterNumber - 1);

    // didn't find a better way considering your current code structure
    synchronized(min) {
        min.points.add(p);
    }
});
List<Cluster> newClusters = clusters.stream()
    .map(c -> new Cluster(
        c.points.stream().mapToDouble(p -> p.x).sum() / c.points.size(),
        c.points.stream().mapToDouble(p -> p.y).sum() / c.points.size(),
        c.clusterNumber))
    .collect(Collectors.toList());

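But you did not provide enough context to test it. There are some open questions. For example, you use the clusterNumber of a Cluster instance to look back into the clusters list; I don't know whether clusterNumber represents the actual list index of the Cluster instance we already have, that is, whether this is an unnecessary redundancy or has a different meaning.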

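I also don't know a better solution than synchronizing on the particular Cluster to make the manipulation of its list thread safe (given your current code structure). This is only needed if you decide to use a parallel stream, i.e. points.parallelStream().forEach(p -> …); the other operations are not affected.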
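
If changing the code structure is an option, one possible alternative (not what the code above does) is to let a collector build the per-cluster lists, so that no shared list is mutated and no synchronized block is needed. The following is only a sketch under that assumption: it uses plain double[] {x, y} arrays instead of Point and Cluster, the names GroupingSketch, nearestIndex, assign and update are hypothetical, and keeping the old position for a centre with no points is a choice made here, not something taken from the question.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical demo class; the names are illustrative and not from the question.
class GroupingSketch {

    // Squared Euclidean distance between two 2D points stored as {x, y}.
    static double squaredDistance(double[] a, double[] b) {
        double dx = a[0] - b[0];
        double dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }

    // Index of the centre closest to p (ties go to the lowest index).
    static int nearestIndex(double[] p, List<double[]> centres) {
        int best = 0;
        for (int i = 1; i < centres.size(); i++) {
            if (squaredDistance(p, centres.get(i)) < squaredDistance(p, centres.get(best))) {
                best = i;
            }
        }
        return best;
    }

    // Assignment step: the collector builds partial maps per thread and merges them,
    // so the lambda never touches shared mutable state.
    static Map<Integer, List<double[]>> assign(List<double[]> points, List<double[]> centres) {
        return points.parallelStream()
                .collect(Collectors.groupingBy(p -> nearestIndex(p, centres)));
    }

    // Update step: mean of each group. A centre without points keeps its old
    // position (an assumption of this sketch).
    static List<double[]> update(Map<Integer, List<double[]>> groups, List<double[]> centres) {
        return IntStream.range(0, centres.size())
                .mapToObj(i -> {
                    List<double[]> g = groups.get(i);
                    if (g == null) {
                        return centres.get(i);
                    }
                    return new double[] {
                        g.stream().mapToDouble(p -> p[0]).average().getAsDouble(),
                        g.stream().mapToDouble(p -> p[1]).average().getAsDouble()
                    };
                })
                .collect(Collectors.toList());
    }
}
```

One k-means iteration would then be update(assign(points, centres), centres), repeated until the centres stop changing.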

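You now have several streams that you can try out, both parallel and sequential, to find out where you gain a benefit. Usually, only the outer stream yields a significant benefit, if any…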
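
A rough way to try this out is sketched below. It is not a rigorous benchmark: the workload (summing each point's squared distance to its nearest centre), the helper timeMillis and the class name StreamTimingSketch are all assumptions made for illustration. A few warm-up rounds are run first, because a single timed run is strongly affected by class loading and JIT compilation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;

// Hypothetical demo class; the names and the workload are illustrative only.
class StreamTimingSketch {

    // Squared Euclidean distance between two 2D points stored as {x, y}.
    static double squaredDistance(double[] a, double[] b) {
        double dx = a[0] - b[0];
        double dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }

    // Sum over all points of the squared distance to the nearest centre.
    // Only the outer stream is switched between sequential and parallel.
    static double cost(List<double[]> points, List<double[]> centres, boolean parallel) {
        Stream<double[]> s = parallel ? points.parallelStream() : points.stream();
        return s.mapToDouble(p -> centres.stream()
                        .mapToDouble(c -> squaredDistance(p, c))
                        .min().orElse(0))
                .sum();
    }

    // Wall-clock time of one run of the task, in milliseconds.
    static long timeMillis(Runnable task) {
        long start = System.nanoTime();
        task.run();
        return (System.nanoTime() - start) / 1_000_000;
    }

    public static void main(String[] args) {
        Random rng = new Random(2);
        List<double[]> points = new ArrayList<>();
        for (int i = 0; i < 200_000; i++) {
            points.add(new double[] {rng.nextInt(1000), rng.nextInt(1000)});
        }
        List<double[]> centres = new ArrayList<>();
        for (int i = 0; i < 8; i++) {
            centres.add(new double[] {rng.nextInt(1000), rng.nextInt(1000)});
        }
        // Warm-up rounds so that JIT compilation does not dominate the timed runs.
        for (int i = 0; i < 5; i++) {
            cost(points, centres, false);
            cost(points, centres, true);
        }
        long seq = timeMillis(() -> cost(points, centres, false));
        long par = timeMillis(() -> cost(points, centres, true));
        System.out.println("sequential: " + seq + " msec, parallel: " + par + " msec");
    }
}
```

With only 300 points, as in the question's Main, the overhead of splitting the work may well outweigh any gain, so measuring on the real input size is the only reliable guide.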

Regarding java - for each loops as streams in Java8 - k-means, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/48704263/
