c++ - 使用 BLAS 和 LAPACKE 在 C++ 中使用 SVD 计算伪逆

标签 c++ lapack svd lapacke cblas

我正在尝试实现矩阵的伪逆计算 A*,以便在 C++ 中求解具有维度的方 nxn 矩阵 A 的 Ax=b。 A*的算术公式是通过SVD分解得到的。

所以我首先计算 SVD(A)=USV^T 然后 A*=VSU^T,其中S为反对角线S,其非零元素si在S*中变为1/si。最后我计算解决方案 x=A*b

但是我没有得到正确的结果。我正在使用 C++ 的 LAPACKE 接口(interface)和矩阵乘法的 cblas。这是我的代码:

double a[n * n] = {2, -1, 2,1};
double b[n]={3,4};
double u[n * n], s[n],vt[n * n];

int lda = n, ldu = n, ldvt = n;

int info = LAPACKE_dgesdd(LAPACK_COL_MAJOR, 'A', n, n, a, lda, s,
               u, ldu, vt, ldvt);




for (int i = 0; i < n; i++) {
        s[i] = 1.0 / s[i];       
}

const int a = 1;
const int c = 0;

double r1[n];
double r2[n];
double res[n];

//compute the  first multiplication s*u^T
cblas_dgemm( CblasColMajor,CblasNoTrans, CblasTrans, n, n, n, a, u, ldvt, s, ldu, c, r1, n);

//compute the second multiplication v^T^T=vs*u^T
cblas_dgemm( CblasColMajor,CblasTrans, CblasNoTrans, n, n, n, a, vt, ldvt, r1, ldu, c, r2, n);

//now that we have the pseudoinverse A* solve by multiplying with b.
cblas_dgemm( CblasColMajor,CblasNoTrans, CblasNoTrans, n, 1, n, a, r2, ldvt, b, ldu, c, res, n);

在第二个 cblas_dgemm 之后,预期在 r2 中有 A* 伪逆。然而,在与 matlab pinv 进行比较后,我没有得到相同的结果。如果我打印 r2,结果给出:

 0.25   0.50
 0.25   0.50

但应该是

0.25   -0.50
0.25   0.50

最佳答案

LAPACKE_dgesdd() 的参数 S 表示SVD decomposition中矩阵的奇异值.虽然它的长度为 n,但它并不描述 vector ,因为它表示对角矩阵。实际上,S.u^T 的结果是一个大小为 的矩阵n*n.

例程 cblas_dscal() 可以在循环中应用以计算涉及对角矩阵的矩阵乘积,尽管生成的 S.u^t 仍然是转置的。参见 what is the best way to multiply a diagonal matrix in fortran

以下代码可以通过g++ main.cpp -o main -llapacke -llapack -lgslcblas -lblas -lm -Wall(或-lcblas`...)编译

#include <iostream>
#include <string>
#include <fstream>  

#include <stdlib.h>
#include <stdio.h>
#include <math.h>



extern "C" { 
#include <lapacke.h>
#include <cblas.h>
}

int main(int argc, char *argv[])
{
const int n=2;

double a[n * n] = {2, -1, 2,1};
double b[n]={3,4};
double u[n * n], s[n],vt[n * n];

int lda = n, ldu = n, ldvt = n;

//computing the SVD
int info = LAPACKE_dgesdd(LAPACK_COL_MAJOR, 'A', n, n, a, lda, s,
               u, ldu, vt, ldvt);
if (info !=0){
std::cerr<<"Lapack error occured in dgesdd. error code :"<<info<<std::endl;
}


for (int i = 0; i < n; i++) {
        s[i] = 1.0 / s[i];       
}

const int aa = 1;
const int c = 0;

//double r1[n*n];
double r2[n*n];
double res[n];

//compute the  first multiplication s*u^T
// here : s is not a vector : it is a diagonal matrix. The ouput must be of size n*n
//cblas_dgemm( CblasColMajor,CblasNoTrans, CblasTrans, n, n, n, aa, u, ldvt, s, ldu, c, r1, n);
for (int i = 0; i < n; i++) {
cblas_dscal(n,s[i],&u[i*n],1);
}

//compute the second multiplication v^T^T=vs*u^T
cblas_dgemm( CblasColMajor,CblasTrans, CblasTrans, n, n, n, aa, vt, ldvt, u, ldu, c, r2, n);
//now, r2 is the pseudoinverse of a.
//now that we have the pseudoinverse A* solve by multiplying with b.
cblas_dgemm( CblasColMajor,CblasNoTrans, CblasNoTrans, n, 1, n, aa, r2, ldvt, b, ldu, c, res, n);


for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
    std::cout<<r2[i*n+j]<<" ";
}
}

std::cout<<std::endl;
}

它打印出预期的结果:

0.25 0.25 -0.5 0.5 

关于c++ - 使用 BLAS 和 LAPACKE 在 C++ 中使用 SVD 计算伪逆,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55599950/

相关文章:

c++ - 为什么我的 vector 中的字符串不改变值?

python - Windows 上 cvxopt 的导入问题

lapack - BLAS 和 LAPACK 的 Bazel 构建规则

c++ - MKL 库在 mex 文件和独立 C++ 中的行为不同

python - 在 Python 中解决奇异值分解 (SVD)

python - 来自 numpy.linalg.svd 的大型矩阵的 MemoryError

python - python中的稀疏矩阵svd

c++ - 购买元素后最小化硬币数量

c++ - C++哈希程序中的 undefined symbol 错误

c++ - 使用可变参数模板构建函数参数