java - 三路归并排序问题排序不正确

标签 java mergesort

我一直在研究这种三路合并排序算法,该算法是基于我的正常合并排序代码;但是,它的排序不正确,所以我相信我的代码中可能存在一个小错误。有什么帮助吗?我已经研究了 3 个小时的代码,试图找到问题所在,但事实证明这很困难。

public class TriMergeSort {

    void merge(int arr[], int low, int mid1, int mid2, int high) { 
        int sizeA = mid1 - low + 1; 
        int sizeB =  mid2 - mid1;
        int sizeC = high - mid2;

        int A[] = new int[sizeA]; 
        int B[] = new int[sizeB]; 
        int C[] = new int[sizeC];

        for (int i = 0; i < sizeA; i++) 
            A[i] = arr[low + i]; 
        for (int j = 0; j < sizeB; j++) 
            B[j] = arr[mid1 + j + 1]; 
        for (int x = 0; x < sizeC; x++) 
            C[x] = arr[mid2 + x + 1];

        int i = 0, j = 0, x = 0; 
        int k = low; 
        
        while (i < sizeA && j < sizeB && x < sizeC) {
            
            if (A[i] < B[j] && A[i] < C[x]) { 
                arr[k] = A[i]; 
                i++; 
            } else
            if (A[i] >= B[j] && B[j] < C[x]) { 
                arr[k] = B[j]; 
                j++; 
            } else
            if (A[i] > C[x] && B[j] >= C[x]) { 
                arr[k] = C[x]; 
                x++; 
            } 
            k++; 
        } 

        while (i < sizeA) { 
            arr[k] = A[i]; 
            i++; 
            k++; 
        } 

        while (j < sizeB) { 
            arr[k] = B[j]; 
            j++; 
            k++; 
        } 
        
        while (x < sizeC) { 
            arr[k] = C[x]; 
            x++; 
            k++; 
        }
    } 

    void sort(int arr[], int low, int high) { 
        
        if (low < high) {  
            int mid1 = low + ((high - low) / 3); 
            int mid2 = low + 2 * ((high - low) / 3) + 1;

            sort(arr, low, mid1); 
            sort(arr, mid1 + 1, mid2); 
            sort(arr, mid2 + 1, high);

            merge(arr, low, mid1, mid2, high); 
        } 
    } 

    static void print(int arr[]) { 
        int n = arr.length; 
        for (int i = 0; i < n; ++i) 
            System.out.print(arr[i] + " "); 
        System.out.println(); 
    } 

    public static void main(String args[]) { 
        int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12 }; 

        TriMergeSort test = new TriMergeSort(); 
        test.sort(arr, 0, arr.length - 1); 

        print(arr); 
    }
} 

最佳答案

问题中发布的代码运行良好。您没有发布您遇到问题的 3 路合并代码。

请注意,不应将 high 作为要排序的切片中最后一项的索引传递,而应传递切片之外第一个元素的索引。这允许更简单的代码,而无需进行困惑且容易出错的 +1/-1 调整。

这是修改后的版本:

public class MergeSort { 

    void merge(int arr[], int low, int mid, int high) { 
        int sizeA = mid - low; 
        int sizeB = high - mid; 

        int A[] = new int[sizeA]; 
        int B[] = new int[sizeB]; 

        for (int i = 0; i < sizeA; i++) 
            A[i] = arr[low + i]; 
        for (int j = 0; j < sizeB; j++) 
            B[j] = arr[mid + j]; 

        int i = 0, j = 0; 
        int k = low; 
        
        while (i < sizeA && j < sizeB) { 
            if (A[i] <= B[j]) { 
                arr[k++] = A[i++]; 
            } else { 
                arr[k++] = B[j++]; 
            } 
        } 

        while (i < sizeA) {
            arr[k++] = A[i++];
        } 

        while (j < sizeB) { 
            arr[k++] = B[j++];
        } 
    } 

    void sort(int arr[], int low, int high) { 
        if (high - low >= 2) {  
            int mid = low + (high - low) / 2; 
            sort(arr, low, mid); 
            sort(arr, mid, high); 
            merge(arr, low, mid, high); 
        } 
    } 

    static void print(int arr[]) { 
        int n = arr.length; 
        for (int i = 0; i < n; ++i) {
            System.out.print(arr[i] + " ");
        }
        System.out.println(); 
    } 

    public static void main(String args[]) { 
        int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12, 10, 59 }; 
        MergeSort test = new MergeSort(); 
        test.sort(arr, 0, arr.length); 
        print(arr); 
    } 
}

要将其转换为 3 路合并版本,sort3 必须遵循以下步骤:

  • 将范围分为 3 个切片,而不是 2 个。第一个切片从 lowmid1 = low + (high - low)/3 排除,第二个切片从mid1mid2 = low + (high - low)*2/3 排除,第三个从 mid2high 排除。
  • 对 3 个子切片中的每一个进行递归排序
  • 调用merge3(arr, low, mid1, mid2, high)
    • 复制 3 个子切片
    • 为 3 个索引值编写一个循环,运行 3 个切片,直到其中一个耗尽
    • 为剩下的 2 个切片(A 和 B)或(B 和 C)或(A 和 C)编写 3 个循环,
    • 编写 3 个循环以从剩余切片 A、B 或 C 中复制剩余元素

编辑: TriMergeSort 类中的 merge 函数缺少 3 个循环,一旦 3 个初始切片之一被合并,就会合并 2 个切片。筋疲力尽的。这解释了为什么数组没有正确排序。在三路合并循环之后,您应该:

    while (i < sizeA && j < sizeB) {
        ...
    }
    while (i < sizeA && x < sizeC) {
        ...
    }
    while (j < sizeB && x < sizeC) {
        ...
    }

为了避免所有这些重复循环,您可以将对索引值的测试合并到单个循环体中:

public class TriMergeSort {

    void merge(int arr[], int low, int mid1, int mid2, int high) { 
        int sizeA = mid1 - low; 
        int sizeB = mid2 - mid1;
        int sizeC = high - mid2;

        int A[] = new int[sizeA]; 
        int B[] = new int[sizeB]; 
        int C[] = new int[sizeC];

        for (int i = 0; i < sizeA; i++) 
            A[i] = arr[low + i]; 
        for (int j = 0; j < sizeB; j++) 
            B[j] = arr[mid1 + j]; 
        for (int k = 0; k < sizeC; k++) 
            C[k] = arr[mid2 + k];

        int i = 0, j = 0, k = 0;
        
        while (low < high) {
            if (i < sizeA && (j >= sizeB || A[i] <= B[j])) {
                if (k >= sizeC || A[i] <= C[k]) {
                    arr[low++] = A[i++];
                } else {
                    arr[low++] = C[k++];
                }
            } else {
                if (j < sizeB && (k >= sizeC || B[j] <= C[k])) {
                    arr[low++] = B[j++];
                } else {
                    arr[low++] = C[k++];
                }
            }
        } 
    } 

    void sort(int arr[], int low, int high) { 
        if (high - low >= 2) {  
            int mid1 = low + (high - low) / 3; 
            int mid2 = low + (high - low) * 2 / 3;
            sort(arr, low, mid1); 
            sort(arr, mid1, mid2); 
            sort(arr, mid2, high);
            merge(arr, low, mid1, mid2, high); 
        } 
    } 

    static void print(int arr[]) { 
        int n = arr.length; 
        for (int i = 0; i < n; ++i) {
            System.out.print(arr[i] + " ");
        }
        System.out.println(); 
    } 

    public static void main(String args[]) { 
        int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12 }; 
        TriMergeSort test = new TriMergeSort(); 
        test.sort(arr, 0, arr.length); 
        print(arr); 
    }
}

上面的 while 循环可以进一步简化,但可读性稍差:

    while (low < high) {
        if (i < sizeA && (j >= sizeB || A[i] <= B[j])) {
            arr[low++] = (k >= sizeC || A[i] <= C[k]) ? A[i++] : C[k++];
        } else {
            arr[low++] = (j < sizeB && (k >= sizeC || B[j] <= C[k])) ? B[j++] : C[k++];
        }
    } 

甚至更进一步:

    while (low < high) {
        arr[low++] = (i < sizeA && (j >= sizeB || A[i] <= B[j])) ?
            ((k >= sizeC || A[i] <= C[k]) ? A[i++] : C[k++]) :
            (j < sizeB && (k >= sizeC || B[j] <= C[k])) ? B[j++] : C[k++];
    } 

关于java - 三路归并排序问题排序不正确,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64123996/

相关文章:

java - 合并排序,Java 代码不起作用?

javascript - Javascript 的原生排序方法是如何工作的?

c++ - 为什么我在这里遇到段错误?

java - java中归并排序的问题

java - 递归方法返回星号中的值

java - RabbitListener 不会接收使用 AsyncRabbitTemplate 发送的每条消息

java - android中以不同格式显示具有一个字符串值的日期的问题

java - 运算符 != 未定义参数类型 boolean, int

java - 使用 jsoup 解析表格

algorithm - 检查大小为 N 的 ArrayList 中是否有两个数字的总和为 N