java - 为什么线程数增加时程序变慢

标签 java multithreading performance

我是java初学者。

最近,我正在编写一个计算矩阵乘法的程序。所以我写了一个类来做到这一点。

public class MultiThreadsMatrixMultipy{
 public   int[][] multipy(int[][] matrix1,int[][] matrix2) {
     if(!utils.CheckDimension(matrix1,matrix2)){
         return null;
     }
     int row1 = matrix1.length;
     int col1 = matrix1[0].length;
     int row2 = matrix2.length;
     int col2 = matrix2[0].length;
     int[][] ans = new int[row1][col2];
     Thread[][]  threads = new SingleRowMultipy[row1][col2];

     for(int i=0;i<row1;i++){
         for(int j=0;j<col2;j++){
             threads[i][j] = new SingleRowMultipy(i,j,matrix1,matrix2,ans));
             threads[i][j].start();
         }
     }
     return ans;
 }
}
public class SingleRowMultipy extends Thread{
        private int row;
        private int col;
        private int[][] A;
        private int[][] B;
        private int[][] ans;
        public SingleRowMultipy(int row,int col,int[][] A,int[][] B,int[][] C){
            this.row = row;
            this.col = col;
            this.A = A;
            this.B = B;
            this.ans = C;
        }
        public void run(){
            int sum =0;
            for(int i=0;i<A[row].length;i++){
                 sum+=(A[row][i]*B[i][col]);
            }
            ans[row][col] = sum;
        }
}

我想用一个线程来计算matrix1[i][:] * matrix2[:][j],矩阵的大小为1000*50005000*1000,所以线程数为1000*1000。当我运行这个程序时,速度非常慢,大约需要38s。如果我只是用单线程来计算结果,需要17s。单线程代码如下:

public class SimpleMatrixMultipy
{
    public int[][] multipy(int[][] matrix1,int[][] matrix2){
        int row1 = matrix1.length;
        int col1 = matrix1[0].length;
        int row2 = matrix2.length;
        int col2 = matrix2[0].length;
        int[][] ans = new int[row1][col2];
        for(int i=0;i<row1;i++){
            for(int j=0;j<col2;j++){
                for(int k=0;k<col1;k++){
                    ans[i][j] += matrix1[i][k]*matrix2[k][j];
                }
            }
        }
        return ans;
    }

}

我可以做什么来加速程序?

最佳答案

正如@Turing85所说,需要管理线程计数。有两种方法,要么使用 Executors.newFixedThreadPool 固定数量的线程,要么使用 Executors.newCachedThreadPool 使用现有的可用线程。

其他重要的一点是避免直接继承Thread类,而是实现runnable。

import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;

public class MultiThreadsMatrixMultipy {

    public static void main(final String[] args) {

    }

    public int[][] multipy(final int[][] matrix1, final int[][] matrix2) {
        if(!utils.CheckDimension(matrix1,matrix2)){
            return null;
        }
        final int row1 = matrix1.length;
        final int col2 = matrix2[0].length;
        final int[][] ans = new int[row1][col2];
        // final Executor executor = Executors.newCachedThreadPool(new CustomThreadFactory("Multiplier"));
        final Executor executor = Executors.newFixedThreadPool(20, new CustomThreadFactory("Multiplier"));

        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                executor.execute(new SingleRowMultipy(i, j, matrix1, matrix2, ans));
            }
        }
        return ans;
    }
}

class CustomThreadFactory implements ThreadFactory {
    private int counter;
    private final String name;
    private final List<String> stats;

    public CustomThreadFactory(final String name) {
        counter = 1;
        this.name = name;
        stats = new ArrayList<>();
    }

    @Override
    public Thread newThread(final Runnable runnable) {
        final Thread t = new Thread(runnable, name + "-Thread_" + counter);
        counter++;
        stats.add(String.format("Created thread %d with name %s on %s \n", t.getId(), t.getName(), new Date()));
        return t;
    }

    public String getStats() {
        final StringBuffer buffer = new StringBuffer();
        final Iterator<String> it = stats.iterator();
        while (it.hasNext()) {
            buffer.append(it.next());
        }
        return buffer.toString();
    }
}

class SingleRowMultipy implements Runnable {
    private final int row;
    private final int col;
    private final int[][] A;
    private final int[][] B;
    private final int[][] ans;

    public SingleRowMultipy(final int row, final int col, final int[][] A, final int[][] B, final int[][] C) {
        this.row = row;
        this.col = col;
        this.A = A;
        this.B = B;
        this.ans = C;
    }

    @Override
    public void run() {
        int sum = 0;
        for (int i = 0; i < A[row].length; i++) {
            sum += (A[row][i] * B[i][col]);
        }
        ans[row][col] = sum;
    }
}

关于java - 为什么线程数增加时程序变慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59466020/

相关文章:

java - 在类构造函数中调用线程的替代方法

c# - 打开文件和读取内容的最可靠方法是什么

c++ - R 调用 C 代码比 C++ 函数调用 C 代码更快?

java - 使用 JNA 和 Delphi DLL 在 PointerByReference 中获取值时出错

java - 我应该如何使用 jax rs/jersey 从 JSONArray 输出 json?

c# - 在 .NET 中如何最有效地利用多核进行短计算?

c++ - 如何从另一个类的函数调用std::async时传递另一个类的函数?

c++ - 更好的可读性和简单性与更高的复杂性和编程速度,该选择什么?

java - 使用此路径检索 SD 卡上的文件 : "/document/primary:..."

java - 是否可以将 chrome webdriver 文件设置为 URL?