c++ - Rcpp Armadillo : Lambda expression with each_slice

标签 c++ r lambda rcpp armadillo

我有一个具有正定矩阵的三维数组,我想获得一个具有所有矩阵的 Cholesky 因子的相同大小的数组。我正在使用 Armadillo 库和 cube 类型,其中有一个我正在尝试使用的便捷函数 each_slice 。但我没有让 lambda 表达式正常工作,所以希望有人可以帮助我并指出我的错误。

这是一个最小的例子:

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// [[Rcpp::export]]
arma::cube chol_array(arma::cube Sigma) {
  arma::cube Sigma_chol = Sigma;
  Sigma_chol.each_slice([](arma::mat X) {return arma::chol(X);});
  return Sigma_chol;
}
// [[Rcpp::export]]
arma::cube chol_array2(arma::cube Sigma) {
  arma::cube Sigma_chol(size(Sigma));
  for (arma::uword i = 0; i < Sigma.n_slices; i++) {
    Sigma_chol.slice(i) = arma::chol(Sigma.slice(i));
  }
  return Sigma_chol;
}

/*** R
Sigma <- array(crossprod(matrix(rnorm(9), 3, 3)), dim = c(3, 3, 2))
chol_array(Sigma)
chol_array2(Sigma)
*/

函数chol_array2完成这项工作,但chol_array只返回原始矩阵。我错过了什么?

最佳答案

这里的问题是 .each_slice() 中缺少引用。称呼。 Armadillo 使用 lambda 表达式需要引用来更新对象,而不是 return 语句。特别是,我们有:

For form 3:

apply the given lambda_function to each slice; the function must accept a reference to a Mat object with the same element type as the underlying cube

所以,改变:

Sigma_chol.each_slice([](arma::mat X) {return arma::chol(X);});

至:

Sigma_chol.each_slice([](arma::mat& X) {X = arma::chol(X);});

固定代码

// [[Rcpp::depends(RcppArmadillo)]]
#include <RcppArmadillo.h>

// Enable lambda expressions.... 
// [[Rcpp::plugins(cpp11)]]

// [[Rcpp::export]]
arma::cube chol_array(arma::cube Sigma) {
  arma::cube Sigma_chol = Sigma;

  // NOTE: the '&' and saving _back_ into the object are crucial
  Sigma_chol.each_slice( [](arma::mat& X) { X = arma::chol(X); } ); 

  return Sigma_chol;
}

测试代码

set.seed(1113)
Sigma = array(crossprod(matrix(rnorm(9), 3, 3)), dim = c(3, 3, 2))
all.equal(chol_array(Sigma), chol_array2(Sigma))
# [1] TRUE

关于c++ - Rcpp Armadillo : Lambda expression with each_slice,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48111392/

相关文章:

ruby - 何时何地使用 Lambda?

c++ - std::map 的这种可移动类型有什么问题?

c++ - 编写 Bash 脚本来计算平均值

r - R函数环境中自动注入(inject)变量

r - 首先按组,然后对于 R data.table 中的每个其他组成员都相同

java - 有什么方法可以从 Lambda 闭包中停止 Stream.generate 吗?

c# - 根据 LINQ/lambda 中的 Group By 语句计数创建计数

c# - 从 C++ DLL 中编写的 C++ 方法返回数据到 C#

c++ - isdigit() 和 isalnum() 给出错误,因为输入是 const char 并且无法转换。其他可能查看输入是否为数字的方法?

r - 如何在绘图中手动将线型分配给线条