假设我有两个矩阵:
A=50;
B=50;
C=1000;
X = rand(A,B);
Y = rand(A,B,C);
我想从 Y
的每个切片 C
中减去 X
。这是一个相当常见的问题,我找到了三种替代解决方案:
% Approach 1: for-loop
tic
Z1 = zeros(size(Y));
for i=1:C
Z1(:,:,i) = Y(:,:,i) - X;
end
toc
% Approach 2: repmat
tic
Z2 = Y - repmat(X,[1 1 C]);
toc
% Approach 3: bsxfun
tic
Z3=bsxfun(@minus,Y,X);
toc
我正在构建一个经常(即数千次)解决此类问题的程序,因此我正在寻找最有效的解决方案。以下是常见的结果模式:
Elapsed time is 0.013527 seconds.
Elapsed time is 0.004080 seconds.
Elapsed time is 0.006310 seconds.
循环明显慢一些,bsxfun 比repmat 慢一点。当我对 Y
的切片进行逐元素乘法(而不是减法)X
时,我发现了相同的模式,尽管repmat 和 bsxfun 在乘法方面更接近一些。
增加数据大小...
A=500;
B=500;
C=1000;
Elapsed time is 2.049753 seconds.
Elapsed time is 0.570809 seconds.
Elapsed time is 1.016121 seconds.
在这里,repmat 是明显的赢家。我想知道 SO 社区中是否有人有一个很酷的技巧来加速这个操作。
最佳答案
根据您的实际情况,bsxfun
和 repmat
有时会比另一个更有优势,就像@rayryeng 建议的那样。您还可以考虑另一种选择:mex 。我在这里硬编码了一些参数以获得更好的性能。
#include "mex.h"
#include "matrix.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
double *A, *B, *C;
int ind_1, ind_2, ind_3, ind_21, ind_32, ind_321, Dims[3] = {500,500,5000};
plhs[0] = mxCreateNumericArray(3, Dims, mxDOUBLE_CLASS, mxREAL);
A = mxGetPr(prhs[0]);
B = mxGetPr(prhs[1]);
C = mxGetPr(plhs[0]);
for ( int ind_3 = 0; ind_3 < 5000; ind_3++)
{
ind_32 = ind_3*250000;
for ( int ind_2 = 0; ind_2 < 500; ind_2++)
{
ind_21 = ind_2*500; // taken out of the innermost loop to save some calculation
ind_321 = ind_32 + ind_21;
for ( int ind_1 = 0 ; ind_1 < 500; ind_1++)
{
C[ind_1 + ind_321] = A[ind_1 + ind_321] - B[ind_1 + ind_21];
}
}
}
}
要使用它,请在命令窗口中输入(假设您将上面的 c 文件命名为 mexsubtract.c )
mex -WIN64 mexsubtract.c
然后你可以像这样使用它:
Z4 = mexsubtract(Y,X);
以下是我的计算机上使用 A=500、B=500、C=5000 的测试结果:
(repmat) Elapsed time is 3.441695 seconds.
(bsxfun) Elapsed time is 3.357830 seconds.
(cmex) Elapsed time is 3.391378 seconds.
这是一个势均力敌的竞争者,在某些更极端的情况下,它会有优势。例如,这是我使用 A = 10、B = 500、C = 200000 得到的结果:
(repmat) Elapsed time is 2.769177 seconds.
(bsxfun) Elapsed time is 3.178385 seconds.
(cmex) Elapsed time is 2.552115 seconds.
关于matlab - MATLAB 中的高效 3D 逐元素运算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33272191/