我正在编写一个 MEX
代码,其中需要使用 pinv
函数。我正在尝试找到一种方法,以最有效的方式使用 mexCallMATLAB
将 double
类型的数组传递给 pinv
。举例来说,假设该数组名为 G
,其大小为 100。
double *G = (double*) mxMalloc( 100 * sizeof(double) );
哪里
G[0] = G11; G[1] = G12;
G[2] = G21; G[3] = G22;
这意味着G的每四个连续元素都是一个2×2
矩阵。 G
存储此 2×2
矩阵的 25
个不同值。
我应该注意到这些 2×2
矩阵的条件不好,它们的元素中可能包含全零。如何使用pinv
函数计算G
元素中的伪逆?例如,如何将数组传递给 mexCallMATLAB
以计算 G
中第一个 2×2
矩阵的伪逆?
我想到了以下方法:
mxArray *G_PINV_input = mxCreateDoubleMatrix(2, 2, mxREAL);
mxArray *G_PINV_output = mxCreateDoubleMatrix(2, 2, mxREAL);
double *G_PINV_input_ptr = mxGetPr(G_PINV_input);
memcpy( G_PINV_input_ptr, &G[0], 4 * sizeof(double));
mexCallMATLAB(1, G_PINV_output, 1, G_PINV_input, "pinv");
我不确定这种方法有多好。复制这些值根本不经济,因为在我的实际应用中 G
中的元素总数很大。有没有办法跳过这个复制?
最佳答案
这是我对 MEX 函数的实现:
my_pinv.cpp
#include "mex.h"
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
// validate arguments
if (nrhs!=1 || nlhs>1)
mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
if (mxGetNumberOfElements(prhs[0]) != 100)
mexErrMsgIdAndTxt("mex:error", "numel() != 100");
// create necessary arrays
mxArray *rhs[1], *lhs[1];
plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
double *in = mxGetPr(prhs[0]);
double *out = mxGetPr(plhs[0]);
double *x = mxGetPr(rhs[0]), *y;
// for each 2x2 matrix
for (mwIndex i=0; i<100; i+=4) {
// copy 2x2 matrix into rhs
x[0] = in[i+0];
x[2] = in[i+1];
x[1] = in[i+2];
x[3] = in[i+3];
// lhs = pinv(rhs)
mexCallMATLAB(1, lhs, 1, rhs, "pinv");
// copy 2x2 matrix from lhs
y = mxGetPr(lhs[0]);
out[i+0] = y[0];
out[i+1] = y[1];
out[i+2] = y[2];
out[i+3] = y[3];
// free array
mxDestroyArray(lhs[0]);
}
// cleanup
mxDestroyArray(rhs[0]);
}
这是 MATLAB 中的基线实现,以便我们可以验证结果是否正确:
my_pinv0.m
function y = my_pinv0(x)
y = zeros(size(x));
for i=1:4:numel(x)
y(i:i+3) = pinv(x([0 1; 2 3]+i));
end
end
现在我们测试 MEX 函数:
% some vector
x = randn(100,1);
% MEX vs. MATLAB function
y = my_pinv0(x);
yy = my_pinv(x);
% compare
assert(isequal(y,yy))
编辑:
这是另一个实现:
my_pinv2.cpp
#include "mex.h"
inline void call_pinv(const double &a, const double &b, const double &c,
const double &d, double *out)
{
mxArray *rhs[1], *lhs[1];
// create input matrix [a b; c d]
rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
double *x = mxGetPr(rhs[0]);
x[0] = a;
x[1] = c;
x[2] = b;
x[3] = d;
// lhs = pinv(rhs)
mexCallMATLAB(1, lhs, 1, rhs, "pinv");
// get values from output matrix
const double *y = mxGetPr(lhs[0]);
out[0] = y[0];
out[1] = y[1];
out[2] = y[2];
out[3] = y[3];
// cleanup
mxDestroyArray(lhs[0]);
mxDestroyArray(rhs[0]);
}
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
// validate arguments
if (nrhs!=1 || nlhs>1)
mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
if (mxGetNumberOfElements(prhs[0]) != 100)
mexErrMsgIdAndTxt("mex:error", "numel() != 100");
// allocate output
plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
double *out = mxGetPr(plhs[0]);
const double *in = mxGetPr(prhs[0]);
// for each 2x2 matrix
for (mwIndex i=0; i<100; i+=4) {
// 2x2 input matrix [a b; c d], and its determinant
const double a = in[i+0];
const double b = in[i+1];
const double c = in[i+2];
const double d = in[i+3];
const double det = (a*d - b*c);
if (det != 0) {
// inverse of 2x2 matrix [d -b; -c a]/det
out[i+0] = d/det;
out[i+1] = -c/det;
out[i+2] = -b/det;
out[i+3] = a/det;
}
else {
// singular matrix, fallback to pseudo-inverse
call_pinv(a, b, c, d, &out[i]);
}
}
}
这次我们计算2x2矩阵的行列式,如果不为零,我们自己计算逆矩阵:
否则,我们会回退到从 MATLAB 调用 PINV 来获取 pseudo-inverse .
这是快速基准测试:
% 100x1 vector
x = randn(100,1); % average case, with normal 2x2 matrices
% running time
funcs = {@my_pinv0, @my_pinv1, @my_pinv2};
t = cellfun(@(f) timeit(@() f(x)), funcs, 'Uniform',true);
% compare results
y = cellfun(@(f) f(x), funcs, 'Uniform',false);
assert(isequal(y{1},y{2}))
我得到以下时间:
>> fprintf('%.6f\n', t);
0.002111 % MATLAB function
0.001498 % first MEX-file with mexCallMATLAB
0.000010 % second MEX-file with "unrolled" matrix inverse (+ PINV as fallback)
误差是可以接受的并且在机器精度范围内:
>> norm(y{1}-y{3})
ans =
2.1198e-14
您还可以测试最坏的情况,即许多 2x2 矩阵都是奇异的:
x = randi([0 1], [100 1]);
关于matlab - 使用 mexCallMATLAB 将 Double* 转换为 mxArray* 的最有效方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/26769577/