I'm a Java beginner.
Recently I've been writing a program that computes matrix multiplication, so I wrote a class to do it.
public class MultiThreadsMatrixMultipy{
    public int[][] multipy(int[][] matrix1,int[][] matrix2) throws InterruptedException {
        if(!utils.CheckDimension(matrix1,matrix2)){
            return null;
        }
        int row1 = matrix1.length;
        int col1 = matrix1[0].length;
        int row2 = matrix2.length;
        int col2 = matrix2[0].length;
        int[][] ans = new int[row1][col2];
        Thread[][] threads = new Thread[row1][col2];
        for(int i=0;i<row1;i++){
            for(int j=0;j<col2;j++){
                // one thread per result cell ans[i][j]
                threads[i][j] = new SingleRowMultipy(i,j,matrix1,matrix2,ans);
                threads[i][j].start();
            }
        }
        // wait for every thread; otherwise ans can be returned before all cells are written
        for(int i=0;i<row1;i++){
            for(int j=0;j<col2;j++){
                threads[i][j].join();
            }
        }
        return ans;
    }
}
public class SingleRowMultipy extends Thread{
    private int row;
    private int col;
    private int[][] A;
    private int[][] B;
    private int[][] ans;
    public SingleRowMultipy(int row,int col,int[][] A,int[][] B,int[][] C){
        this.row = row;
        this.col = col;
        this.A = A;
        this.B = B;
        this.ans = C;
    }
    public void run(){
        int sum =0;
        for(int i=0;i<A[row].length;i++){
            sum+=(A[row][i]*B[i][col]);
        }
        ans[row][col] = sum;
    }
}
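I want each thread to compute matrix1[i][:] * matrix2[:][j]. The matrices are 1000 × 5000 and 5000 × 1000, so there are 1000 × 1000 threads. When I run this program it is very slow and takes about 38 s. If I compute the result with a single thread, it takes 17 s. The single-threaded code is as follows: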
public class SimpleMatrixMultipy
{
    public int[][] multipy(int[][] matrix1,int[][] matrix2){
        int row1 = matrix1.length;
        int col1 = matrix1[0].length;
        int row2 = matrix2.length;
        int col2 = matrix2[0].length;
        int[][] ans = new int[row1][col2];
        for(int i=0;i<row1;i++){
            for(int j=0;j<col2;j++){
                for(int k=0;k<col1;k++){
                    ans[i][j] += matrix1[i][k]*matrix2[k][j];
                }
            }
        }
        return ans;
    }
}
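The total work is fixed at 1000 × 1000 result cells × 5000 multiply-adds each, i.e. 5 × 10^9 multiply-adds, no matter how many threads share it. Part of the cost of SimpleMatrixMultipy is likely memory access: the innermost loop reads matrix2[k][j] with k changing, so every step jumps to a different int[] row of a 5000 × 1000 array, which is cache-unfriendly. A common single-threaded remedy, not taken from this thread, is to reorder the loops to i-k-j so the inner loop walks rows contiguously. The class name LoopOrderMatrixMultiply is hypothetical, and any speed-up on these sizes should be measured rather than assumed.

```java
// Hypothetical helper class; not part of the original question or answer.
class LoopOrderMatrixMultiply {

    // Same result as SimpleMatrixMultipy.multipy, with the loops in i-k-j order.
    // The inner loop walks b[k] and ans[i] left to right, i.e. through one
    // contiguous int[] each, instead of switching to a new row of b on every step.
    static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;   // assumed equal to a[0].length
        int cols = b[0].length;
        int[][] ans = new int[rows][cols];
        for (int i = 0; i < rows; i++) {
            int[] ansRow = ans[i];
            for (int k = 0; k < inner; k++) {
                int aik = a[i][k];  // reused across the whole inner loop
                int[] bRow = b[k];
                for (int j = 0; j < cols; j++) {
                    ansRow[j] += aik * bRow[j];
                }
            }
        }
        return ans;
    }
}
```

Each ans[i][j] still receives the same sum of a[i][k] * b[k][j] terms, only accumulated in a different order, so the result matches the i-j-k version for int arithmetic.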
What can I do to speed up the program?
Best answer
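As @Turing85 said, you need to manage the number of threads. There are two ways: either use Executors.newFixedThreadPool to get a fixed number of threads, or use Executors.newCachedThreadPool, which reuses previously created threads when they are available.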
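A minimal sketch of managing the thread count, assuming a pool sized to Runtime.getRuntime().availableProcessors() and one task per result row instead of one per cell. The row granularity and the class name RowTaskMatrixMultiply are choices made here for illustration, not something the answer specifies. invokeAll blocks until every task has finished, so the returned array is complete.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical helper class: one task per result row, run on a fixed-size pool.
class RowTaskMatrixMultiply {

    static int[][] multiply(int[][] a, int[][] b) throws InterruptedException {
        int rows = a.length;
        int inner = b.length;   // assumed equal to a[0].length
        int cols = b[0].length;
        int[][] ans = new int[rows][cols];

        // One worker per available core instead of one thread per result cell.
        int nThreads = Runtime.getRuntime().availableProcessors();
        ExecutorService pool = Executors.newFixedThreadPool(nThreads);
        try {
            List<Callable<Void>> tasks = new ArrayList<>(rows);
            for (int i = 0; i < rows; i++) {
                final int row = i;
                tasks.add(() -> {
                    // Same formula as the question, but a whole row per task.
                    for (int j = 0; j < cols; j++) {
                        int sum = 0;
                        for (int k = 0; k < inner; k++) {
                            sum += a[row][k] * b[k][j];
                        }
                        ans[row][j] = sum;
                    }
                    return null;
                });
            }
            // Blocks until every task has completed.
            pool.invokeAll(tasks);
        } finally {
            pool.shutdown(); // lets the pool's non-daemon threads exit
        }
        return ans;
    }
}
```

With 1000 rows this queues 1000 tasks of 5,000,000 multiply-adds each, rather than 1,000,000 tasks of 5000, so scheduling overhead is small relative to the work.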
Another important point is to avoid extending the Thread class directly and to implement Runnable instead.
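One reason for this advice, sketched under the assumption of Java 8+ lambdas: a Runnable only describes the work, so the same kind of task can be started on its own Thread or handed to an executor such as the ones above, whereas a Thread subclass bundles the work with a thread object that an executor would never start. RunnableDemo is a hypothetical name.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical example: work described as Runnable runs on a plain Thread or on a pool.
class RunnableDemo {

    static int[] runBothWays() throws InterruptedException {
        int[] results = new int[2];

        // The work is described only as Runnable objects.
        Runnable first = () -> results[0] = 1;
        Runnable second = () -> results[1] = 2;

        // 1) Run on a dedicated thread, then wait for it.
        Thread t = new Thread(first);
        t.start();
        t.join();

        // 2) Hand it to an executor that owns and reuses its worker threads.
        ExecutorService pool = Executors.newFixedThreadPool(1);
        pool.execute(second);
        pool.shutdown();
        pool.awaitTermination(1, TimeUnit.MINUTES);
        return results;
    }
}
```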
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
public class MultiThreadsMatrixMultipy {
    public static void main(final String[] args) {
    }
    public int[][] multipy(final int[][] matrix1, final int[][] matrix2) {
        if(!utils.CheckDimension(matrix1,matrix2)){
            return null;
        }
        final int row1 = matrix1.length;
        final int col2 = matrix2[0].length;
        final int[][] ans = new int[row1][col2];
        // final ExecutorService executor = Executors.newCachedThreadPool(new CustomThreadFactory("Multiplier"));
        final ExecutorService executor = Executors.newFixedThreadPool(20, new CustomThreadFactory("Multiplier"));
        for (int i = 0; i < row1; i++) {
            for (int j = 0; j < col2; j++) {
                executor.execute(new SingleRowMultipy(i, j, matrix1, matrix2, ans));
            }
        }
        // Stop accepting new tasks and wait for the queued ones, so ans is complete on return
        // and the pool's non-daemon threads do not keep the JVM alive.
        executor.shutdown();
        try {
            executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        }
        return ans;
    }
}
class CustomThreadFactory implements ThreadFactory {
    private int counter;
    private final String name;
    private final List<String> stats;
    public CustomThreadFactory(final String name) {
        counter = 1;
        this.name = name;
        stats = new ArrayList<>();
    }
    // synchronized: a thread pool may call newThread from more than one thread
    @Override
    public synchronized Thread newThread(final Runnable runnable) {
        final Thread t = new Thread(runnable, name + "-Thread_" + counter);
        counter++;
        stats.add(String.format("Created thread %d with name %s on %s%n", t.getId(), t.getName(), new Date()));
        return t;
    }
    public synchronized String getStats() {
        final StringBuilder buffer = new StringBuilder();
        for (final String line : stats) {
            buffer.append(line);
        }
        return buffer.toString();
    }
}
class SingleRowMultipy implements Runnable {
    private final int row;
    private final int col;
    private final int[][] A;
    private final int[][] B;
    private final int[][] ans;
    public SingleRowMultipy(final int row, final int col, final int[][] A, final int[][] B, final int[][] C) {
        this.row = row;
        this.col = col;
        this.A = A;
        this.B = B;
        this.ans = C;
    }
    @Override
    public void run() {
        int sum = 0;
        for (int i = 0; i < A[row].length; i++) {
            sum += (A[row][i] * B[i][col]);
        }
        ans[row][col] = sum;
    }
}
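The answer's version still queues row1 × col2 (here 1,000,000) tasks of 5000 multiply-adds each and keeps the column-wise reads of B[i][col]. A coarser alternative, not taken from the answer, is to parallelize over result rows with a parallel stream, which runs on the common ForkJoinPool (by default sized to roughly the number of cores), and to use i-k-j loop order inside each row so that B is read row by row. ParallelRowsMatrixMultiply is a hypothetical name; its speed on the question's sizes is an assumption to measure.

```java
import java.util.stream.IntStream;

// Hypothetical alternative: rows processed in parallel on the common ForkJoinPool,
// with the cache-friendly i-k-j loop order inside each row.
class ParallelRowsMatrixMultiply {

    static int[][] multiply(int[][] a, int[][] b) {
        int rows = a.length;
        int inner = b.length;   // assumed equal to a[0].length
        int cols = b[0].length;
        int[][] ans = new int[rows][cols];
        // Each row of ans is written by exactly one task, so no locking is needed.
        IntStream.range(0, rows).parallel().forEach(i -> {
            int[] ansRow = ans[i];
            for (int k = 0; k < inner; k++) {
                int aik = a[i][k];
                int[] bRow = b[k];
                for (int j = 0; j < cols; j++) {
                    ansRow[j] += aik * bRow[j];
                }
            }
        });
        // forEach is a terminal operation: it returns only after every row is done.
        return ans;
    }
}
```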
Regarding "java - Why does the program get slower as the number of threads increases", a similar question on Stack Overflow: https://stackoverflow.com/questions/59466020/