c++ - 直接在 RcppArmadillo 中调用 LAPACK 例程

标签 c++ r rcpp armadillo

由于 Armadillo (据我所知)没有三角求解器,我想使用 dtrtrs 中提供的 LAPACK 三角求解器。我查看了以下两个( firstsecond )SO 线程并将一些内容拼凑在一起,但它不起作用。

我使用 RStudio 创建了一个新的包,同时还启用了 RcppArmadillo。我有一个头文件 header.h:

#include <RcppArmadillo.h>

#ifdef ARMA_USE_LAPACK
#if !defined(ARMA_BLAS_CAPITALS)
#define arma_dtrtrs dtrtrs
#else
#define arma_dtrtrs DTRTRS
#endif
#endif

extern "C" {
  void arma_fortran(arma_dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                    double* A, int* LDA, double* B, int* LDB, int* INFO);
}

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb);

static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x);

这本质上是第一个链接问题的答案,还有一个包装函数和主函数。函数的主要内容位于 trisolve.cpp 中,如下所示:

#include "header.h"

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
  int info = 0;
  wrapper_dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
  return info;
}


static int trisolve(const arma::mat &in_A, const arma::mat &in_b, arma::mat &out_x) {
  size_t  rows = in_A.n_rows;
  size_t  cols = in_A.n_cols;

  double *A = new double[rows*cols];
  double *b = new double[in_b.size()];

  //Lapack has column-major order
  for(size_t col=0, D1_idx=0; col<cols; ++col)
  {
    for(size_t row = 0; row<rows; ++row)
    {
      // Lapack uses column major format
      A[D1_idx++] = in_A(row, col);
    }
    b[col] = in_b(col);
  }

  for(size_t row = 0; row<rows; ++row)
  {
    b[row] = in_b(row);
  }

  int info = trtrs('U', 'N', 'N', cols, 1, A, rows, b, rows);

  for(size_t col=0; col<cols; col++) {
    out_x(col)=b[col];
  }

  delete[] A;
  delete[] b;

  return 0;
}


// [[Rcpp::export]]

arma::mat RtoRcpp(arma::mat A, arma::mat b) {
  arma::uword n = A.n_rows;
  arma::mat x = arma::mat(n, 1, arma::fill::zeros);

  int info = trisolve(A, b, x);
  return x;
}

我(至少)有两个问题:

  1. 尝试编译时,我从头文件中得到:conflicting types for 'dtrtrs_'。但是,我看不出输入有什么问题(这实际上是从第二个链接线程复制的)。
  2. 毫不奇怪,wrapper_dtrtrts_ 不正确。但从 Armadillo 的compiler_setup.hpp中我可以看出, arma_fortran 应该为我创建一个名为 wrapper_dtrtrs_ 的函数。我应该在主 cpp 文件中使用什么名称?

最佳答案

Armadillo 已经使用dtrtrs用于解决三对角问题。部分代码引用:

因此,如果我们可以触发此调试语句,我们就可以确定 dtrtrs确实使用过:

#define ARMA_EXTRA_DEBUG
// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// [[Rcpp::export]]
void testTrisolve() {
  arma::mat A = arma::randu<arma::mat>(5,5);
  arma::mat B = arma::randu<arma::mat>(5,5);

  arma::mat X1 = arma::solve(A, B);
  arma::mat X3 = arma::solve(arma::trimatu(A), B);
}

/*** R
testTrisolve()
*/

这会产生大量调试消息,其中:

lapack::gesvx()
[...]
lapack::trtrs()

所以我们清楚地看到dtrtrs用于三对角的情况。

至于你原来的问题:

  1. 冲突类型错误是 Aramdillo 已经使用 dtrtrs 的结果。 ,但签名略有不同( Aconst )。
  2. Fortran 函数的 C 级名称取决于 ARMA_BLAS_UNDERSCORE 的值和ARMA_USE_WRAPPER 。我不确定情况是否总是如此,但对我来说,前者已定义,后者未定义(参见 config.hpp ),导致 dtrtrs_作为名称。

确实,如果我添加 const Armadillo 使用它并将该函数调用为 dtrtrs_ ,您的代码编译时不会出现错误或警告(未使用的变量除外......):

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

extern "C" {
  void arma_fortran(dtrtrs)(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                    const double* A, int* LDA, double* B, int* LDB, int* INFO);
}

int trtrs(char uplo, char trans, char diag, int n, int nrhs, double* A, int lda, double* B, int ldb) {
  int info = 0;
  dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, A, &lda, B, &ldb, &info);
  return info;
}

[...]

关于c++ - 直接在 RcppArmadillo 中调用 LAPACK 例程,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52908185/

相关文章:

rcpp 一个文件中有多个函数且没有匹配的函数

c++ - 这是不确定的吗?

multithreading - 使用 Rcpp 时抛出错误的线程安全方法是什么

c++ - openGL 2D 移动和对象 1 个单位

r - 如何通过将索引保留在第一列和第二列中将矩阵转换为 R 中的数据帧?

r - 异步进程阻塞 R Shiny 应用程序

r - 基于另一个矩阵对一个矩阵进行子集化

带条件的 R 累积和

c++ - 内存中有什么值(value)?

c++ - C++ 中的结构*