I'm trying to understand the MPI_Reduce_scatter function, but my deductions always seem to be wrong :(
The documentation says (link):
MPI_Reduce_scatter first does an element-wise reduction on vector of count = ∑_i recvcounts[i] elements in the send buffer defined by sendbuf, count, and datatype. Next, the resulting vector of results is split into n disjoint segments, where n is the number of processes in the group. Segment i contains recvcounts[i] elements. The ith segment is sent to process i and stored in the receive buffer defined by recvbuf, recvcounts[i], and datatype.
I have the following (very simple) C program, and I expected to get the max of the first recvcounts[i] elements, but it seems I'm doing something wrong...
#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"
#define NUM_PE 5
#define NUM_ELEM 3
char *print(int arr[], int n);

int main(int argc, char *argv[]) {
    int rank, size, i, n;
    int sendbuf[5][3] = {
        {  1,  2,  3 },
        {  4,  5,  6 },
        {  7,  8,  9 },
        { 10, 11, 12 },
        { 13, 14, 15 }
    };
    int recvbuf[15] = {0};
    int recvcounts[5] = {
        3, 3, 3, 3, 3
    };

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    n = sizeof(sendbuf[rank]) / sizeof(int);
    printf("sendbuf (thread %d): %s\n", rank, print(sendbuf[rank], n));

    MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

    n = sizeof(recvbuf) / sizeof(int);
    printf("recvbuf (thread %d): %s\n", rank, print(recvbuf, n)); // <--- I receive the same output as with sendbuf :(

    MPI_Finalize();
    return 0;
}
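/* Returns a string formatted as in the output below, e.g. "[ 1, 2, 3 ]".
   The body was omitted in the post; this is one possible implementation.
   It uses a static buffer, so each call overwrites the previous result. */
char *print(int arr[], int n) {
    static char buf[256];
    int pos = snprintf(buf, sizeof buf, "[ ");
    for (int k = 0; k < n && pos < (int)sizeof buf; k++)
        pos += snprintf(buf + pos, sizeof buf - pos, "%s%d", k ? ", " : "", arr[k]);
    if (pos < (int)sizeof buf)
        snprintf(buf + pos, sizeof buf - pos, " ]");
    return buf;
}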
My program's output is the same for recvbuf and sendbuf. I expected recvbuf to contain the max values:
$ mpicc 03_reduce_scatter.c
$ mpirun -n 5 ./a.out
sendbuf (thread 4): [ 13, 14, 15 ]
sendbuf (thread 3): [ 10, 11, 12 ]
sendbuf (thread 2): [ 7, 8, 9 ]
sendbuf (thread 0): [ 1, 2, 3 ]
sendbuf (thread 1): [ 4, 5, 6 ]
recvbuf (thread 1): [ 4, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
recvbuf (thread 2): [ 7, 8, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
recvbuf (thread 0): [ 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
recvbuf (thread 3): [ 10, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
recvbuf (thread 4): [ 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
Best answer
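Yes, the documentation for Reduce_scatter is terse, and because the routine isn't widely used there aren't many examples. The first few slides of this OCW MIT lecture have a nice diagram and suggest a use case.
As is usually the case, the key is to read the MPI document and pay particular attention to the advice to implementors: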
"The MPI_REDUCE_SCATTER routine is functionally equivalent to: an MPI_REDUCE collective operation with count equal to the sum of recvcounts[i] followed by MPI_SCATTERV with sendcounts equal to recvcounts."
So let's go through your example. This line,
MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);
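is equivalent to this:
int totcounts = 15;                  // = sum of {3, 3, 3, 3, 3}
int displs[5] = {0, 3, 6, 9, 12};    // running sum of recvcounts
int tmpbuffer[15];                   // holds the reduced vector on rank 0
// every rank passes the same 15 values {1,2,3...15} as sendbuf
MPI_Reduce(sendbuf, tmpbuffer, totcounts, MPI_INT, MPI_MAX, 0, MPI_COMM_WORLD);
MPI_Scatterv(tmpbuffer, recvcounts, displs, MPI_INT,
             recvbuf, 3, MPI_INT, 0, MPI_COMM_WORLD);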
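The displacements that MPI_Scatterv needs here are just the running sum of recvcounts: segment i starts where segments 0..i-1 end. A minimal sketch in plain C (no MPI required); the helper name compute_displs is made up for illustration and is not part of MPI:

```c
/* Hypothetical helper (not an MPI routine): fills displs[i] with the offset
 * at which segment i starts, i.e. counts[0] + ... + counts[i-1].
 * Returns the total element count, which is the count MPI_Reduce would need. */
int compute_displs(const int *counts, int *displs, int n) {
    int total = 0;
    for (int i = 0; i < n; i++) {
        displs[i] = total;   /* segment i begins after all earlier segments */
        total += counts[i];
    }
    return total;
}
```

For recvcounts = {3, 3, 3, 3, 3} this gives displacements {0, 3, 6, 9, 12} and a total of 15.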
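So everyone sends the same numbers {1...15} (sendbuf is the whole 5x3 array on every rank, not just that rank's row), and each column of those numbers is max'ed across the processes, giving {max(1,1...1), max(2,2...2) ... max(15,15...15)} = {1,2,...15}.
These are then scattered to the processes, 3 at a time, giving {1,2,3}, {4,5,6}, {7,8,9}...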
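The two steps (element-wise max across ranks, then handing out one segment per rank) can be modelled serially to check this reasoning without launching MPI. This is only a sketch of the semantics, not how any MPI library implements the call; the function name and its parameter layout are assumptions made for the example:

```c
/* Serial model (no MPI) of what one rank receives from
 * MPI_Reduce_scatter(..., MPI_INT, MPI_MAX, ...).
 * send:   nprocs rows of sum(counts) ints each; row r is rank r's sendbuf
 * counts: the recvcounts array, one entry per rank
 * rank:   the rank whose receive buffer is being computed
 * recv:   receives counts[rank] ints                                    */
void model_reduce_scatter_max(const int *send, int nprocs,
                              const int *counts, int rank, int *recv) {
    int total = 0, offset = 0;
    for (int i = 0; i < nprocs; i++) {
        if (i < rank) offset += counts[i];  /* start of this rank's segment */
        total += counts[i];                 /* length of every send buffer */
    }
    for (int k = 0; k < counts[rank]; k++) {
        int m = send[offset + k];           /* rank 0's contribution */
        for (int r = 1; r < nprocs; r++) {  /* element-wise max over ranks */
            int v = send[r * total + offset + k];
            if (v > m) m = v;
        }
        recv[k] = m;
    }
}
```

Feeding it five identical rows {1, ..., 15} with counts {3, 3, 3, 3, 3} gives rank 2 the values {7, 8, 9}, which matches the recvbuf output in the question.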
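So that's what happens; how do we get what you want to happen? I understand that you want each row to be max'ed, and each processor to get "its" corresponding row max. For example, suppose the data looks like this: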
Proc 0: 1 5 9 13
Proc 1: 2 6 10 14
Proc 2: 3 7 11 15
Proc 3: 4 8 12 16
We want Proc 0 (say) to end up with the max of all the 0th pieces of data, Proc 1 with the max of all the 1st pieces, and so on, so that we end up with
Proc 0: 4
Proc 1: 8
Proc 2: 12
Proc 3: 16
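So let's see how to do that. First, everyone ends up with one value, so all the recvcounts are 1. Second, each process has to send its own distinct data. So we'd have something like this: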
#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"

int main(int argc, char *argv[]) {
    int rank, size;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int sendbuf[size];   /* one element per destination process */
    int recvbuf;         /* each process receives a single value */

    /* Each rank fills in its own, distinct data. */
    for (int i=0; i<size; i++)
        sendbuf[i] = 1 + rank + size*i;

    printf("Proc %d: ", rank);
    for (int i=0; i<size; i++) printf("%d ", sendbuf[i]);
    printf("\n");

    /* Everyone gets exactly one element of the reduced vector. */
    int recvcounts[size];
    for (int i=0; i<size; i++)
        recvcounts[i] = 1;

    MPI_Reduce_scatter(sendbuf, &recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

    printf("Proc %d: %d\n", rank, recvbuf);

    MPI_Finalize();
    return 0;
}
Running it gives (output reordered for clarity):
Proc 0: 1 5 9 13
Proc 1: 2 6 10 14
Proc 2: 3 7 11 15
Proc 3: 4 8 12 16
Proc 0: 4
Proc 1: 8
Proc 2: 12
Proc 3: 16
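These values follow directly from the fill pattern: rank r puts 1 + r + size*i into element i, so the largest element i comes from the last rank and equals size*(i+1). A small serial check of that arithmetic, with no MPI involved (expected_max is a hypothetical name used only for this sketch):

```c
/* Hypothetical serial check (no MPI): the value rank i should receive when
 * every rank r in 0..size-1 fills sendbuf[i] = 1 + r + size*i and all
 * recvcounts are 1, as in the program above. */
int expected_max(int size, int i) {
    int m = 1 + 0 + size * i;          /* rank 0's element i */
    for (int r = 1; r < size; r++) {
        int v = 1 + r + size * i;      /* rank r's element i */
        if (v > m) m = v;
    }
    return m;                          /* works out to size * (i + 1) */
}
```

With size = 4 this returns 4, 8, 12 and 16 for i = 0..3, the same numbers the four processes print.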
A similar question about being unable to understand MPI_Reduce_scatter in MPI can be found on Stack Overflow: https://stackoverflow.com/questions/25775353/