matlab - 使用 mexCallMATLAB 将 Double* 转换为 mxArray* 的最有效方法

标签 matlab mex matrix-inverse

我正在编写一个 MEX 代码,其中需要使用 pinv 函数。我正在尝试找到一种方法,以最有效的方式使用 mexCallMATLABdouble 类型的数组传递给 pinv。举例来说,假设该数组名为 G,其大小为 100。

double *G = (double*) mxMalloc( 100 * sizeof(double) );

哪里

G[0] = G11; G[1] = G12;
G[2] = G21; G[3] = G22;

这意味着G的每四个连续元素都是一个2×2矩阵。 G 存储此 2×2 矩阵的 25 个不同值。

enter image description here

我应该注意到这些 2×2 矩阵的条件不好,它们的元素中可能包含全零。如何使用pinv函数计算G元素中的伪逆?例如,如何将数组传递给 mexCallMATLAB 以计算 G 中第一个 2×2 矩阵的伪逆?

我想到了以下方法:

mxArray *G_PINV_input     = mxCreateDoubleMatrix(2, 2, mxREAL);
mxArray *G_PINV_output    = mxCreateDoubleMatrix(2, 2, mxREAL);
double  *G_PINV_input_ptr = mxGetPr(G_PINV_input);

memcpy( G_PINV_input_ptr, &G[0], 4 * sizeof(double));
mexCallMATLAB(1, G_PINV_output, 1, G_PINV_input, "pinv");

我不确定这种方法有多好。复制这些值根本不经济,因为在我的实际应用中 G 中的元素总数很大。有没有办法跳过这个复制?

最佳答案

这是我对 MEX 函数的实现:

my_pinv.cpp

#include "mex.h"

void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
    // validate arguments
    if (nrhs!=1 || nlhs>1)
        mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
        mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
    if (mxGetNumberOfElements(prhs[0]) != 100)
        mexErrMsgIdAndTxt("mex:error", "numel() != 100");

    // create necessary arrays
    mxArray *rhs[1], *lhs[1];
    plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
    rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
    double *in = mxGetPr(prhs[0]);
    double *out = mxGetPr(plhs[0]);
    double *x = mxGetPr(rhs[0]), *y;

    // for each 2x2 matrix
    for (mwIndex i=0; i<100; i+=4) {
        // copy 2x2 matrix into rhs
        x[0] = in[i+0];
        x[2] = in[i+1];
        x[1] = in[i+2];
        x[3] = in[i+3];

        // lhs = pinv(rhs)
        mexCallMATLAB(1, lhs, 1, rhs, "pinv");

        // copy 2x2 matrix from lhs
        y = mxGetPr(lhs[0]);
        out[i+0] = y[0];
        out[i+1] = y[1];
        out[i+2] = y[2];
        out[i+3] = y[3];

        // free array
        mxDestroyArray(lhs[0]);
    }

    // cleanup
    mxDestroyArray(rhs[0]);
}

这是 MATLAB 中的基线实现,以便我们可以验证结果是否正确:

my_pinv0.m

function y = my_pinv0(x)
    y = zeros(size(x));
    for i=1:4:numel(x)
        y(i:i+3) = pinv(x([0 1; 2 3]+i));
    end
end

现在我们测试 MEX 函数:

% some vector
x = randn(100,1);

% MEX vs. MATLAB function
y = my_pinv0(x);
yy = my_pinv(x);

% compare
assert(isequal(y,yy))

编辑:

这是另一个实现:

my_pinv2.cpp

#include "mex.h"

inline void call_pinv(const double &a, const double &b, const double &c,
                      const double &d, double *out)
{
    mxArray *rhs[1], *lhs[1];

    // create input matrix [a b; c d]
    rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
    double *x = mxGetPr(rhs[0]);
    x[0] = a;
    x[1] = c;
    x[2] = b;
    x[3] = d;

    // lhs = pinv(rhs)
    mexCallMATLAB(1, lhs, 1, rhs, "pinv");

    // get values from output matrix
    const double *y = mxGetPr(lhs[0]);
    out[0] = y[0];
    out[1] = y[1];
    out[2] = y[2];
    out[3] = y[3];

    // cleanup
    mxDestroyArray(lhs[0]);
    mxDestroyArray(rhs[0]);
}

void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
    // validate arguments
    if (nrhs!=1 || nlhs>1)
        mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
        mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
    if (mxGetNumberOfElements(prhs[0]) != 100)
        mexErrMsgIdAndTxt("mex:error", "numel() != 100");

    // allocate output
    plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
    double *out = mxGetPr(plhs[0]);
    const double *in = mxGetPr(prhs[0]);

    // for each 2x2 matrix
    for (mwIndex i=0; i<100; i+=4) {
        // 2x2 input matrix [a b; c d], and its determinant
        const double a = in[i+0];
        const double b = in[i+1];
        const double c = in[i+2];
        const double d = in[i+3];
        const double det = (a*d - b*c);

        if (det != 0) {
            // inverse of 2x2 matrix [d -b; -c a]/det
            out[i+0] =  d/det;
            out[i+1] = -c/det;
            out[i+2] = -b/det;
            out[i+3] =  a/det;
        }
        else {
            // singular matrix, fallback to pseudo-inverse
            call_pinv(a, b, c, d, &out[i]);
        }
    }
}

这次我们计算2x2矩阵的行列式,如果不为零,我们自己计算逆矩阵:

2x2_matrix_inverse

否则,我们会回退到从 MATLAB 调用 PINV 来获取 pseudo-inverse .

这是快速基准测试:

% 100x1 vector
x = randn(100,1);           % average case, with normal 2x2 matrices

% running time
funcs = {@my_pinv0, @my_pinv1, @my_pinv2};
t = cellfun(@(f) timeit(@() f(x)), funcs, 'Uniform',true);

% compare results
y = cellfun(@(f) f(x), funcs, 'Uniform',false);
assert(isequal(y{1},y{2}))

我得到以下时间:

>> fprintf('%.6f\n', t);
0.002111   % MATLAB function
0.001498   % first MEX-file with mexCallMATLAB
0.000010   % second MEX-file with "unrolled" matrix inverse (+ PINV as fallback)

误差是可以接受的并且在机器精度范围内:

>> norm(y{1}-y{3})
ans =
   2.1198e-14

您还可以测试最坏的情况,即许多 2x2 矩阵都是奇异的:

x = randi([0 1], [100 1]);

关于matlab - 使用 mexCallMATLAB 将 Double* 转换为 mxArray* 的最有效方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/26769577/

相关文章:

python - 使用 python 求解 7000x7000 线性系统时的最佳性能方法

c++ - 如何让 MatLab 找到 Visual C++ 编译器?

c++ - MATLAB任意代码执行

python - python中使用statsmodels错误的逻辑回归

matlab - 如何根据表格中的数据修改合适的单元格颜色(在Matlab中)?

performance - 在Matlab中获取矩阵的对角线

matlab - 具有包装方法的集成分类器

matlab - 如何在matlab中保存二进制数据并进行十六进制转换

c++ - 如何使用 Microsoft Visual C++ 2010 通过 MATLAB 7.10 (R2010a) 创建 MEX 文件?

c++ - LAPACKE C++实矩阵求逆