我一直在研究这种三路合并排序算法,该算法是基于我的正常合并排序代码;但是,它的排序不正确,所以我相信我的代码中可能存在一个小错误。有什么帮助吗?我已经研究了 3 个小时的代码,试图找到问题所在,但事实证明这很困难。
public class TriMergeSort {
void merge(int arr[], int low, int mid1, int mid2, int high) {
int sizeA = mid1 - low + 1;
int sizeB = mid2 - mid1;
int sizeC = high - mid2;
int A[] = new int[sizeA];
int B[] = new int[sizeB];
int C[] = new int[sizeC];
for (int i = 0; i < sizeA; i++)
A[i] = arr[low + i];
for (int j = 0; j < sizeB; j++)
B[j] = arr[mid1 + j + 1];
for (int x = 0; x < sizeC; x++)
C[x] = arr[mid2 + x + 1];
int i = 0, j = 0, x = 0;
int k = low;
while (i < sizeA && j < sizeB && x < sizeC) {
if (A[i] < B[j] && A[i] < C[x]) {
arr[k] = A[i];
i++;
} else
if (A[i] >= B[j] && B[j] < C[x]) {
arr[k] = B[j];
j++;
} else
if (A[i] > C[x] && B[j] >= C[x]) {
arr[k] = C[x];
x++;
}
k++;
}
while (i < sizeA) {
arr[k] = A[i];
i++;
k++;
}
while (j < sizeB) {
arr[k] = B[j];
j++;
k++;
}
while (x < sizeC) {
arr[k] = C[x];
x++;
k++;
}
}
void sort(int arr[], int low, int high) {
if (low < high) {
int mid1 = low + ((high - low) / 3);
int mid2 = low + 2 * ((high - low) / 3) + 1;
sort(arr, low, mid1);
sort(arr, mid1 + 1, mid2);
sort(arr, mid2 + 1, high);
merge(arr, low, mid1, mid2, high);
}
}
static void print(int arr[]) {
int n = arr.length;
for (int i = 0; i < n; ++i)
System.out.print(arr[i] + " ");
System.out.println();
}
public static void main(String args[]) {
int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12 };
TriMergeSort test = new TriMergeSort();
test.sort(arr, 0, arr.length - 1);
print(arr);
}
}
最佳答案
问题中发布的代码运行良好。您没有发布您遇到问题的 3 路合并代码。
请注意,不应将 high
作为要排序的切片中最后一项的索引传递,而应传递切片之外第一个元素的索引。这允许更简单的代码,而无需进行困惑且容易出错的 +1
/-1
调整。
这是修改后的版本:
public class MergeSort {
void merge(int arr[], int low, int mid, int high) {
int sizeA = mid - low;
int sizeB = high - mid;
int A[] = new int[sizeA];
int B[] = new int[sizeB];
for (int i = 0; i < sizeA; i++)
A[i] = arr[low + i];
for (int j = 0; j < sizeB; j++)
B[j] = arr[mid + j];
int i = 0, j = 0;
int k = low;
while (i < sizeA && j < sizeB) {
if (A[i] <= B[j]) {
arr[k++] = A[i++];
} else {
arr[k++] = B[j++];
}
}
while (i < sizeA) {
arr[k++] = A[i++];
}
while (j < sizeB) {
arr[k++] = B[j++];
}
}
void sort(int arr[], int low, int high) {
if (high - low >= 2) {
int mid = low + (high - low) / 2;
sort(arr, low, mid);
sort(arr, mid, high);
merge(arr, low, mid, high);
}
}
static void print(int arr[]) {
int n = arr.length;
for (int i = 0; i < n; ++i) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String args[]) {
int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12, 10, 59 };
MergeSort test = new MergeSort();
test.sort(arr, 0, arr.length);
print(arr);
}
}
要将其转换为 3 路合并版本,sort3
必须遵循以下步骤:
- 将范围分为 3 个切片,而不是 2 个。第一个切片从
low
到mid1 = low + (high - low)/3
排除,第二个切片从mid1
到mid2 = low + (high - low)*2/3
排除,第三个从mid2
到high
排除。 - 对 3 个子切片中的每一个进行递归排序
- 调用
merge3(arr, low, mid1, mid2, high)
- 复制 3 个子切片
- 为 3 个索引值编写一个循环,运行 3 个切片,直到其中一个耗尽
- 为剩下的 2 个切片(A 和 B)或(B 和 C)或(A 和 C)编写 3 个循环,
- 编写 3 个循环以从剩余切片 A、B 或 C 中复制剩余元素
编辑: TriMergeSort
类中的 merge
函数缺少 3 个循环,一旦 3 个初始切片之一被合并,就会合并 2 个切片。筋疲力尽的。这解释了为什么数组没有正确排序。在三路合并循环之后,您应该:
while (i < sizeA && j < sizeB) {
...
}
while (i < sizeA && x < sizeC) {
...
}
while (j < sizeB && x < sizeC) {
...
}
为了避免所有这些重复循环,您可以将对索引值的测试合并到单个循环体中:
public class TriMergeSort {
void merge(int arr[], int low, int mid1, int mid2, int high) {
int sizeA = mid1 - low;
int sizeB = mid2 - mid1;
int sizeC = high - mid2;
int A[] = new int[sizeA];
int B[] = new int[sizeB];
int C[] = new int[sizeC];
for (int i = 0; i < sizeA; i++)
A[i] = arr[low + i];
for (int j = 0; j < sizeB; j++)
B[j] = arr[mid1 + j];
for (int k = 0; k < sizeC; k++)
C[k] = arr[mid2 + k];
int i = 0, j = 0, k = 0;
while (low < high) {
if (i < sizeA && (j >= sizeB || A[i] <= B[j])) {
if (k >= sizeC || A[i] <= C[k]) {
arr[low++] = A[i++];
} else {
arr[low++] = C[k++];
}
} else {
if (j < sizeB && (k >= sizeC || B[j] <= C[k])) {
arr[low++] = B[j++];
} else {
arr[low++] = C[k++];
}
}
}
}
void sort(int arr[], int low, int high) {
if (high - low >= 2) {
int mid1 = low + (high - low) / 3;
int mid2 = low + (high - low) * 2 / 3;
sort(arr, low, mid1);
sort(arr, mid1, mid2);
sort(arr, mid2, high);
merge(arr, low, mid1, mid2, high);
}
}
static void print(int arr[]) {
int n = arr.length;
for (int i = 0; i < n; ++i) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
public static void main(String args[]) {
int arr[] = { 15, 2, 6, 7, 55, 0, 28, 41, 12 };
TriMergeSort test = new TriMergeSort();
test.sort(arr, 0, arr.length);
print(arr);
}
}
上面的 while
循环可以进一步简化,但可读性稍差:
while (low < high) {
if (i < sizeA && (j >= sizeB || A[i] <= B[j])) {
arr[low++] = (k >= sizeC || A[i] <= C[k]) ? A[i++] : C[k++];
} else {
arr[low++] = (j < sizeB && (k >= sizeC || B[j] <= C[k])) ? B[j++] : C[k++];
}
}
甚至更进一步:
while (low < high) {
arr[low++] = (i < sizeA && (j >= sizeB || A[i] <= B[j])) ?
((k >= sizeC || A[i] <= C[k]) ? A[i++] : C[k++]) :
(j < sizeB && (k >= sizeC || B[j] <= C[k])) ? B[j++] : C[k++];
}
关于java - 三路归并排序问题排序不正确,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64123996/