java - 如何在 Java 中实现多线程 MergeSort

标签 java multithreading sorting mergesort fork-join

<分区>

我发现的大多数合并排序示例都在单个线程中运行。这首先破坏了使用合并排序算法的一些优势。有人可以展示使用多线程在 Java 中编写合并排序算法的正确方法。

该解决方案应使用最新版本的 java 的功能(如适用)。 Stackoverflow 上已有的许多解决方案都使用纯线程。我正在寻找一个演示 ForkJoin 与 RecursiveTask 的解决方案,这似乎是 RecursiveTask 类的主要用例。

重点应放在展示具有卓越性能特征的算法上,包括可能的时间和空间复杂度。

注意:所提出的重复问题都不适用,因为它们都没有提供使用递归任务的解决方案,而这正是该问题所要求的。

最佳答案

合并排序最方便的多线程范例是 fork-join 范例。这是从 Java 8 及更高版本提供的。以下代码演示了使用 fork 连接的合并排序。

import java.util.*;
import java.util.concurrent.*;

public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private List<N> elements;

    public MergeSort(List<N> elements) {
        this.elements = new ArrayList<>(elements);
    }

    @Override
    protected List<N> compute() {
        if(this.elements.size() <= 1)
            return this.elements;
        else {
            final int pivot = this.elements.size() / 2;
            MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
            MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));

            leftTask.fork();
            rightTask.fork();

            List<N> left = leftTask.join();
            List<N> right = rightTask.join();

            return merge(left, right);
        }
    }

    private List<N> merge(List<N> left, List<N> right) {
        List<N> sorted = new ArrayList<>();
        while(!left.isEmpty() || !right.isEmpty()) {
            if(left.isEmpty())
                sorted.add(right.remove(0));
            else if(right.isEmpty())
                sorted.add(left.remove(0));
            else {
                if( left.get(0).compareTo(right.get(0)) < 0 )
                    sorted.add(left.remove(0));
                else
                    sorted.add(right.remove(0));
            }
        }

        return sorted;
    }

    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,10,1)));
        System.out.println("result: " + result);
    }
}

虽然不那么直接,但以下代码变体消除了对 ArrayList 的过度复制。初始未排序列表仅创建一次,对子列表的调用本身不需要执行任何复制。在每次算法 fork 时我们都会复制数组列表。此外,现在,当合并列表而不是创建新列表并在每次重用左侧列表并将值插入其中时复制其中的值。通过避免额外的复制步骤,我们提高了性能。我们在这里使用 LinkedList 是因为与 ArrayList 相比,插入的成本相当低。我们还消除了对 remove 的调用,这在 ArrayList 上也是昂贵的。

import java.util.*;
import java.util.concurrent.*;

public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private List<N> elements;

    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    @Override
    protected List<N> compute() {
        if(this.elements.size() <= 1)
            return new LinkedList<>(this.elements);
        else {
            final int pivot = this.elements.size() / 2;
            MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
            MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));

            leftTask.fork();
            rightTask.fork();

            List<N> left = leftTask.join();
            List<N> right = rightTask.join();

            return merge(left, right);
        }
    }

    private List<N> merge(List<N> left, List<N> right) {
        int leftIndex = 0;
        int rightIndex = 0;
        while(leftIndex < left.size() || rightIndex < right.size()) {
            if(leftIndex >= left.size())
                left.add(leftIndex++, right.get(rightIndex++));
            else if(rightIndex >= right.size())
                return left;
            else {
                if( left.get(leftIndex).compareTo(right.get(rightIndex)) < 0 )
                    leftIndex++;
                else
                    left.add(leftIndex++, right.get(rightIndex++));
            }
        }

        return left;
    }

    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
        System.out.println("result: " + result);
    }
}

我们还可以通过使用迭代器而不是在执行合并时直接调用 get 来进一步改进代码。这样做的原因是通过索引获取 LinkedList 的时间性能较差(线性),因此通过使用迭代器,我们消除了在每次获取时内部迭代链表所导致的减速。迭代器上对 next 的调用是常数时间,而不是调用 get 的线性时间。以下代码被修改为使用迭代器。

import java.util.*;
import java.util.concurrent.*;

public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private List<N> elements;

    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    @Override
    protected List<N> compute() {
        if(this.elements.size() <= 1)
            return new LinkedList<>(this.elements);
        else {
            final int pivot = this.elements.size() / 2;
            MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
            MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));

            leftTask.fork();
            rightTask.fork();

            List<N> left = leftTask.join();
            List<N> right = rightTask.join();

            return merge(left, right);
        }
    }

    private List<N> merge(List<N> left, List<N> right) {
        ListIterator<N> leftIter = left.listIterator();
        ListIterator<N> rightIter = right.listIterator();
        while(leftIter.hasNext() || rightIter.hasNext()) {
            if(!leftIter.hasNext()) {
                leftIter.add(rightIter.next());
                rightIter.remove();
            }
            else if(!rightIter.hasNext())
                return left;
            else {
                N rightElement = rightIter.next();
                if( leftIter.next().compareTo(rightElement) < 0 )
                    rightIter.previous();
                else {
                    leftIter.previous();
                    leftIter.add(rightElement);
                }
            }
        }

        return left;
    }

    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(Arrays.asList(7,2,9,-7,777777,10,1)));
        System.out.println("result: " + result);
    }
}

最后是最复杂的代码版本,这次迭代使用了完全就地操作。仅创建初始 ArrayList,并且不会创建其他集合。因此逻辑特别难以遵循(所以我把它留到最后)。但应该尽可能接近理想的实现。

import java.util.*;
import java.util.concurrent.*;

public class MergeSort<N extends Comparable<N>> extends RecursiveTask<List<N>> {
    private List<N> elements;

    public MergeSort(List<N> elements) {
        this.elements = elements;
    }

    @Override
    protected List<N> compute() {
        if(this.elements.size() <= 1)
            return this.elements;
        else {
            final int pivot = this.elements.size() / 2;
            MergeSort<N> leftTask = new MergeSort<N>(this.elements.subList(0, pivot));
            MergeSort<N> rightTask = new MergeSort<N>(this.elements.subList(pivot, this.elements.size()));

            leftTask.fork();
            rightTask.fork();

            List<N> left = leftTask.join();
            List<N> right = rightTask.join();

            merge(left, right);
            return this.elements;
        }
    }

    private void merge(List<N> left, List<N> right) {
        int leftIndex = 0;
        int rightIndex = 0;
        while(leftIndex < left.size() ) {
            if(rightIndex == 0) {
                if( left.get(leftIndex).compareTo(right.get(rightIndex)) > 0 ) {
                    swap(left, leftIndex++, right, rightIndex++);
                } else {
                    leftIndex++;
                }
            } else {
                if(rightIndex >= right.size()) {
                    if(right.get(0).compareTo(left.get(left.size() - 1)) < 0 )
                        merge(left, right);
                    else
                        return;
                }
                else if( right.get(0).compareTo(right.get(rightIndex)) < 0 ) {
                    swap(left, leftIndex++, right, 0);
                } else {
                    swap(left, leftIndex++, right, rightIndex++);
                }
            }
        }

        if(rightIndex < right.size() && rightIndex != 0)
            merge(right.subList(0, rightIndex), right.subList(rightIndex, right.size()));
    }

    private void swap(List<N> left, int leftIndex, List<N> right, int rightIndex) {
        //N leftElement = left.get(leftIndex);
        left.set(leftIndex, right.set(rightIndex, left.get(leftIndex)));
    }

    public static void main(String[] args) {
        ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
        List<Integer> result = forkJoinPool.invoke(new MergeSort<Integer>(new ArrayList<>(Arrays.asList(5,9,8,7,6,1,2,3,4))));
        System.out.println("result: " + result);
    }
}

关于java - 如何在 Java 中实现多线程 MergeSort,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50036207/

相关文章:

java - Java Web 应用程序和 Web 服务之间的线程间通信

java - 用户输入的数字不重复

java - 在 hibernate 中使用合并后刷新实体实例?

java - 尝试执行 java.class 时出现 java.lang.NoClassDefFoundError

java - 从容器获取线程?

python - 使用用户定义的规则对项目进行排序

java - libGDX 中的枚举出现 ExceptionInInitializerError

java - 例如,调用 notifyAll 的顺序如何影响 Java 中的执行?

java - 正则表达式 * 或其他方法 * 对所有文本的整个 URL 进行排序

c# - Datagrid列排序生成错误