Java: adding two double[][] with parallel streams

Tags: java arrays java-8 java-stream

Suppose I have these two matrices:

double[][] a = new double[2][2];
a[0][0] = 1;
a[0][1] = 2;
a[1][0] = 3;
a[1][1] = 4;

double[][] b = new double[2][2];
b[0][0] = 1;
b[0][1] = 2;
b[1][0] = 3;
b[1][1] = 4;

The traditional way to add these matrices is a nested for loop:

int rows = a.length;
int cols = a[0].length;
double[][] res = new double[rows][cols];
for(int i = 0; i < rows; i++){
    for(int j = 0; j < cols; j++){
        res[i][j] = a[i][j] + b[i][j];
    }
}

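I'm fairly new to the stream API, but this looks like a good fit for parallelStream. So my question is: is there a way to do this that takes advantage of parallel processing?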

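Edit: I'm not sure this is the right place for it, but here goes. Following some of the suggestions, I benchmarked the stream versions. The setup is as follows. Classic approach: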

public class ClassicMatrix {

    private final double[][] components;
    private final int cols;
    private final int rows;

    public ClassicMatrix(final double[][] components) {
        this.components = components;
        this.rows = components.length;
        this.cols = components[0].length;
    }

    public ClassicMatrix addComponents(final ClassicMatrix a) {
        final double[][] res = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            // The inner loop runs over columns, so non-square matrices work too.
            for (int j = 0; j < cols; j++) {
                res[i][j] = components[i][j] + a.components[i][j];
            }
        }
        return new ClassicMatrix(res);
    }

}

Using @dkatzel's suggestion:

import java.util.stream.IntStream;

public class MatrixStream1 {

    private final double[][] components;
    private final int cols;
    private final int rows;

    public MatrixStream1(final double[][] components) {
        this.components = components;
        this.rows = components.length;
        this.cols = components[0].length;
    }

    public MatrixStream1 addComponents(final MatrixStream1 a) {
        final double[][] res = new double[rows][cols];
        // Iterate one flat index over all cells and map it back to (row, column).
        IntStream.range(0, rows * cols).parallel().forEach(i -> {
            int x = i / cols; // row index (divide by the column count)
            int y = i % cols; // column index

            res[x][y] = components[x][y] + a.components[x][y];
        });
        return new MatrixStream1(res);
    }
}
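
The mapping from a flat index back to a row and a column is easy to get wrong when the matrix is not square. The following is a minimal standalone sketch, not part of the benchmark; the class name FlatIndexAdd and its add method are made up for illustration. It runs the same flat-index idea on a 2 x 3 matrix:

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical helper class, for illustration only.
class FlatIndexAdd {

    // Adds two matrices of identical shape by iterating one flat index in parallel.
    static double[][] add(double[][] a, double[][] b) {
        int rows = a.length;
        int cols = a[0].length;
        double[][] res = new double[rows][cols];
        IntStream.range(0, rows * cols).parallel().forEach(k -> {
            int x = k / cols; // row index
            int y = k % cols; // column index
            res[x][y] = a[x][y] + b[x][y]; // each cell is written by exactly one task
        });
        return res;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}, {4, 5, 6}};       // 2 rows, 3 columns
        double[][] b = {{10, 20, 30}, {40, 50, 60}};
        // prints [[11.0, 22.0, 33.0], [44.0, 55.0, 66.0]]
        System.out.println(Arrays.deepToString(add(a, b)));
    }
}
```

Dividing or taking the remainder by the row count instead would give the same result only when rows equals cols, as in the square matrices used in the benchmark.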

Using @Eugene's suggestion:

import java.util.Arrays;
import java.util.stream.IntStream;

public class MatrixStream2 {

    private final double[][] components;
    private final int cols;
    private final int rows;

    public MatrixStream2(final double[][] components) {
        this.components = components;
        this.rows = components.length;
        this.cols = components[0].length;
    }

    public MatrixStream2 addComponents(final MatrixStream2 a) {
        final double[][] res = new double[rows][cols];
        // Rows are visited sequentially; the cells of each row are filled in parallel.
        IntStream.range(0, rows)
            .forEach(i -> Arrays.parallelSetAll(res[i], j -> components[i][j] + a.components[i][j]));
        return new MatrixStream2(res);
    }
}

And a test class, which I ran 3 independent times for each method (just replace the method name in main()):

import java.util.ArrayList;
import java.util.List;

public class MatrixTest {

    private final static String path = "/media/manuel/workspace/data/";

    public static void main(String[] args) {
        final List<Double[]> lst = new ArrayList<>();
        for (int i = 100; i < 8000; i = i + 400) {
            final Double[] d = testClassic(i);
            System.out.println(d[0] + " : " + d[1]);
            lst.add(d);
        }
        // IOUtils.saveToFile is not shown in the post.
        IOUtils.saveToFile(path + "classic.csv", lst);
    }

    // Each test returns { matrix size, elapsed time in ms } for one addition.
    public static Double[] testClassic(final int i) {
        final ClassicMatrix a = new ClassicMatrix(rand(i));
        final ClassicMatrix b = new ClassicMatrix(rand(i));

        final long start = System.currentTimeMillis();
        final ClassicMatrix sum = a.addComponents(b);
        final long now = System.currentTimeMillis();
        final double elapsed = (now - start);

        return new Double[] { (double) i, elapsed };
    }

    public static Double[] testStream1(final int i) {
        final MatrixStream1 a = new MatrixStream1(rand(i));
        final MatrixStream1 b = new MatrixStream1(rand(i));

        final long start = System.currentTimeMillis();
        final MatrixStream1 sum = a.addComponents(b);
        final long now = System.currentTimeMillis();
        final double elapsed = (now - start);

        return new Double[] { (double) i, elapsed };
    }

    public static Double[] testStream2(final int i) {
        final MatrixStream2 a = new MatrixStream2(rand(i));
        final MatrixStream2 b = new MatrixStream2(rand(i));

        final long start = System.currentTimeMillis();
        final MatrixStream2 sum = a.addComponents(b);
        final long now = System.currentTimeMillis();
        final double elapsed = (now - start);

        return new Double[] { (double) i, elapsed };
    }

    // Builds a size x size matrix filled with random values in [0, 1).
    private static double[][] rand(final int size) {
        final double[][] rnd = new double[size][size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                rnd[i][j] = Math.random();
            }
        }
        return rnd;
    }
}

Results:

Classic: matrix size, time (ms)
100.0,1.0
500.0,5.0
900.0,5.0
1300.0,43.0
1700.0,94.0
2100.0,26.0
2500.0,33.0
2900.0,46.0
3300.0,265.0
3700.0,71.0
4100.0,87.0
4500.0,380.0
4900.0,432.0
5300.0,215.0
5700.0,238.0
6100.0,577.0
6500.0,677.0
6900.0,609.0
7300.0,584.0
7700.0,592.0

Stream1: matrix size, time (ms)
100.0,86.0
500.0,13.0
900.0,9.0
1300.0,47.0
1700.0,92.0
2100.0,29.0
2500.0,33.0
2900.0,46.0
3300.0,253.0
3700.0,71.0
4100.0,90.0
4500.0,352.0
4900.0,373.0
5300.0,497.0
5700.0,485.0
6100.0,579.0
6500.0,711.0
6900.0,800.0
7300.0,780.0
7700.0,902.0

Stream2: matrix size, time (ms)
100.0,111.0
500.0,42.0
900.0,12.0
1300.0,54.0
1700.0,97.0
2100.0,110.0
2500.0,177.0
2900.0,71.0
3300.0,250.0
3700.0,106.0
4100.0,359.0
4500.0,143.0
4900.0,233.0
5300.0,261.0
5700.0,289.0
6100.0,406.0
6500.0,814.0
6900.0,830.0
7300.0,828.0
7700.0,911.0

For easier comparison, I made a plot: [Figure: Performance Test]

There is no improvement at all. Where is the flaw? Are the matrices too small (7700 x 7700)? Anything larger than that blows up my computer's memory.

Best answer

One way to do it is with Arrays.parallelSetAll:

int rows = a.length;
int cols = a[0].length;
double[][] res = new double[rows][cols];

Arrays.parallelSetAll(res, i -> {
    Arrays.parallelSetAll(res[i], j -> a[i][j] + b[i][j]);
    return res[i];
});

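I'm not 100% sure, but I think the inner call to Arrays.parallelSetAll may not be worth the overhead of parallelizing the columns of every row. It may be enough to parallelize at the row level only and fill each row sequentially: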

Arrays.parallelSetAll(res, i -> {
    Arrays.setAll(res[i], j -> a[i][j] + b[i][j]);
    return res[i];
});
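
To try the row-level variant end to end, here is a minimal self-contained sketch that wraps it in a method and checks it against the plain nested loop. The class name RowParallelAdd and the method names are made up for illustration, and both matrices are assumed to be rectangular with the same shape:

```java
import java.util.Arrays;

// Hypothetical helper class, for illustration only.
class RowParallelAdd {

    // Reference implementation: plain nested loops.
    static double[][] addSequential(double[][] a, double[][] b) {
        int rows = a.length;
        int cols = a[0].length;
        double[][] res = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                res[i][j] = a[i][j] + b[i][j];
            }
        }
        return res;
    }

    // Rows are distributed across threads; each row is filled sequentially.
    static double[][] addRowParallel(double[][] a, double[][] b) {
        int rows = a.length;
        int cols = a[0].length;
        double[][] res = new double[rows][cols];
        Arrays.parallelSetAll(res, i -> {
            Arrays.setAll(res[i], j -> a[i][j] + b[i][j]);
            return res[i]; // keep the already allocated row in place
        });
        return res;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{1, 2}, {3, 4}};
        double[][] sum = addRowParallel(a, b);
        // prints [[2.0, 4.0], [6.0, 8.0]]
        System.out.println(Arrays.deepToString(sum));
        // prints true
        System.out.println(Arrays.deepEquals(sum, addSequential(a, b)));
    }
}
```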

In any case, measure carefully before adding parallelism to an algorithm, because the overhead is often large enough that it is not worth it.
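
As a rough illustration of what measuring carefully could look like, here is a sketch that runs a few untimed warm-up calls first (so the JIT compiler has had a chance to compile the code) and then reports the best of several runs timed with System.nanoTime. The class AddTiming and the helper bestMillis are made-up names, the warm-up and run counts are arbitrary assumptions, and a dedicated harness such as JMH would give more trustworthy numbers than this:

```java
import java.util.Arrays;
import java.util.function.BinaryOperator;

// Hypothetical timing sketch, for illustration only.
class AddTiming {

    static double[][] addSequential(double[][] a, double[][] b) {
        double[][] res = new double[a.length][a[0].length];
        for (int i = 0; i < a.length; i++) {
            for (int j = 0; j < a[0].length; j++) {
                res[i][j] = a[i][j] + b[i][j];
            }
        }
        return res;
    }

    static double[][] addRowParallel(double[][] a, double[][] b) {
        double[][] res = new double[a.length][a[0].length];
        Arrays.parallelSetAll(res, i -> {
            Arrays.setAll(res[i], j -> a[i][j] + b[i][j]);
            return res[i];
        });
        return res;
    }

    // Runs op untimed `warmups` times, then returns the best of `runs` timed calls in ms.
    static double bestMillis(BinaryOperator<double[][]> op, double[][] a, double[][] b,
                             int warmups, int runs) {
        for (int i = 0; i < warmups; i++) {
            op.apply(a, b);
        }
        long best = Long.MAX_VALUE;
        for (int i = 0; i < runs; i++) {
            long start = System.nanoTime();
            double[][] res = op.apply(a, b);
            long elapsed = System.nanoTime() - start;
            if (res.length != a.length) {
                throw new IllegalStateException("unexpected result size"); // also uses the result
            }
            best = Math.min(best, elapsed);
        }
        return best / 1_000_000.0;
    }

    // Builds a size x size matrix filled with random values in [0, 1).
    static double[][] rand(int size) {
        double[][] rnd = new double[size][size];
        for (double[] row : rnd) {
            Arrays.setAll(row, j -> Math.random());
        }
        return rnd;
    }

    public static void main(String[] args) {
        double[][] a = rand(1000);
        double[][] b = rand(1000);
        System.out.println("sequential:   " + bestMillis(AddTiming::addSequential, a, b, 5, 10) + " ms");
        System.out.println("row-parallel: " + bestMillis(AddTiming::addRowParallel, a, b, 5, 10) + " ms");
    }
}
```

The timings depend on the machine and JVM, so no expected output is shown.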

Source: a similar question on Stack Overflow about adding two double[][] with parallel streams in Java: https://stackoverflow.com/questions/45304215/
