所以我在 Rosetta Code 上查看合并排序的 C 示例,我对 merge() 函数的工作原理有点困惑。我认为他们使用的语法让我对冒号和 ? 感到厌烦。
void merge (int *a, int n, int m) {
int i, j, k;
int *x = malloc(n * sizeof (int));
for (i = 0, j = m, k = 0; k < n; k++) {
x[k] = j == n ? a[i++]
: i == m ? a[j++]
: a[j] < a[i] ? a[j++]
: a[i++];
}
for (i = 0; i < n; i++) {
a[i] = x[i];
}
free(x);
}
void merge_sort (int *a, int n) {
if (n < 2)
return;
int m = n / 2;
merge_sort(a, m);
merge_sort(a + m, n - m);
merge(a, n, m);
}
merge() 函数的 for 循环中究竟发生了什么?有人可以解释一下吗?
最佳答案
阅读评论:
void merge (int *a, int n, int m) {
int i, j, k;
// inefficient: allocating a temporary array with malloc
// once per merge phase!
int *x = malloc(n * sizeof (int));
// merging left and right halfs of a into temporary array x
for (i = 0, j = m, k = 0; k < n; k++) {
x[k] = j == n ? a[i++] // right half exhausted, take from left
: i == m ? a[j++] // left half exhausted, take from right
: a[j] < a[i] ? a[j++] // right element smaller, take that
: a[i++]; // otherwise take left element
}
// copy temporary array back to original array.
for (i = 0; i < n; i++) {
a[i] = x[i];
}
free(x); // free temporary array
}
void merge_sort (int *a, int n) {
if (n < 2)
return;
int m = n / 2;
// inefficient: should not recurse if n == 2
// recurse to sort left half
merge_sort(a, m);
// recurse to sort right half
merge_sort(a + m, n - m);
// merge left half and right half in place (via temp array)
merge(a, n, m);
}
merge
函数的更简单、更高效的版本,仅使用一半的临时空间:
static void merge(int *a, int n, int m) {
int i, j, k;
int *x = malloc(m * sizeof (int));
// copy left half to temporary array
for (i = 0; i < m; i++) {
x[i] = a[i];
}
// merge left and right half
for (i = 0, j = m, k = 0; i < m && j < n; k++) {
a[k] = a[j] < x[i] ? a[j++] : x[i++];
}
// finish copying left half
while (i < m) {
a[k++] = x[i++];
}
}
merge_sort
的更快版本涉及分配大小为 n * sizeof(*a)
的临时数组 x
并将其传递给递归函数函数 merge_sort1
也使用额外参数调用 merge
。 merge
中的逻辑也得到了改进,对 i
和 j
的比较次数减少了一半:
static void merge(int *a, int n, int m, int *x) {
int i, j, k;
for (i = 0; i < m; i++) {
x[i] = a[i];
}
for (i = 0, j = m, k = 0;;) {
if (a[j] < x[i]) {
a[k++] = a[j++];
if (j >= n) break;
} else {
a[k++] = x[i++];
if (i >= m) return;
}
}
while (i < m) {
a[k++] = x[i++];
}
}
static void merge_sort1(int *a, int n, int *x) {
if (n >= 2) {
int m = n / 2;
if (n > 2) {
merge_sort1(a, m, x);
merge_sort1(a + m, n - m, x);
}
merge(a, n, m, x);
}
}
void merge_sort(int *a, int n) {
if (n < 2)
return;
int *x = malloc(n / 2 * sizeof (int));
merge_sort1(a, n, x);
free(x);
}
关于c - merge_sort() 中的 merge() 如何工作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29244420/