我用 R 编写了这个简单的“for 循环”,我必须将其转换为 C++。以下是 R 代码的可重现示例:
# Parameters required
a <- 1.8
b <- 1
time.dt <- 0.1
yp <- 40
insp.int <- 7
ph <- 2000
dt <- seq(0,ph,time.dt) # Time sequence
MD.set <- c(seq(insp.int, ph, insp.int), ph) # Decision points to check and set next inspection date
# Initialization
cum.y <- rep(0,length = length(dt))
init.y <- 0
flag <- FALSE
# At each iteration, the following loop generates a gamma distributed random number and cum.y keeps taking cumulative sum
# The objective is to return a vector cum.y with a conditional cumulative sum of previous iteration
# When dt[i] matches any values in MD.Set AND corresponding cum.y[i] is also >= yp it changes the flag to true (the last if)
# At the start of the loop it checks if dt[i] matches any values in MD.Set AND flag is also true. If yes, then cum.y is reset to 0.
for (i in 2:length(dt)){
if (dt[i] %in% MD.set && flag == TRUE){
cum.y[i] <- 0
init.y <- 0
flag <- FALSE
next
} else {
cum.y[i] <- init.y + rgamma(n = 1, shape = a*time.dt, scale = b)
init.y <- cum.y[i]
if (dt[i] %in% MD.set && cum.y[i] >= yp){
flag <- TRUE
}
}
}
res <- cbind(dt, cum.y)
由于我之前没有使用 C++ 的经验,因此在尝试这样做时遇到了许多问题。我只需要进行此转换,以便能够在 R 中的 Rcpp 包中使用它。因为代码在 R 中运行缓慢,特别是当 time.dt
变小时,我猜测 C++ 会完成这项工作快点。你能帮忙吗?
更新2:
这是我在评论和答案的帮助下进行转换的建议。但是,我不确定 C++
中的 next
相当于什么。如果我使用 continue
它会继续执行其余代码(并执行 else
之后的代码。如果我使用 break
那么它就会出来条件为真后循环。
NumericVector cumy(double a, double b, double timedt, NumericVector dt, NumericVector MDSet, double yp){
bool flag = false;
int n = dt.size();
double total = 0;
NumericVector out(n);
unordered_set<int> sampleSet(MDSet.begin(), MDSet.end());
for (int i = 0; i < n; ++i){
if (sampleSet.find(dt[i]) != sampleSet.end() && flag == true){
out[i] = 0;
total = 0;
flag = false;
continue;
} else {
out[i] = total + rgamma(1, a*timedt, b)[0];
total = out[i];
if (sampleSet.find(dt[i]) != sampleSet.end() && out[i] >= yp){
flag = true;
}
}
}
return out;
}
最佳答案
您收到的错误只是因为没有从 NumericVector
进行自动转换。至std::unordered_set<int>
。您可以通过执行以下操作来解决此问题:
std::unordered_set<int> sampleSet( MDSet.begin(), MDSet.end() )
这称为unordered_set
的构造函数,其开始和结束迭代器为 MDSet
,这将用所有值填充该集合。
您的代码中还存在另一个问题:
if (sampleSet.find(dt[i]) == sampleSet.begin())
这仅在 dt[i]
时成立。位于 sampleSet
的第一个元素处。从您的 r 代码中,我假设您只是检查值是否 dt[i]
在sampleSet
内,在这种情况下,您需要:
if (sampleSet.find(dt[i]) != sampleSet.end())
在C++中,STL find方法通常返回一个迭代器,当未找到值时,它返回结束迭代器,所以如果find
的返回值不是end
,然后在集合中找到该值。
关于c++ - 将 R 代码转换为 C++ 以进行 Rcpp 实现,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67367887/