r - 数据表 |组内更快的逐行递归更新

标签 r recursion data.table shift

我必须执行以下递归逐行操作才能获得z :

myfun = function (xb, a, b) {

z = NULL

for (t in 1:length(xb)) {

    if (t >= 2) { a[t] = b[t-1] + xb[t] }
    z[t] = rnorm(1, mean = a[t])
    b[t] = a[t] + z[t]

}

return(z)

}

set.seed(1)

n_smpl = 1e6 
ni = 5

id = rep(1:n_smpl, each = ni)

smpl = data.table(id)
smpl[, time := 1:.N, by = id]

a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

smpl[, z := myfun(xb, a, b), by = id]

我想得到这样的结果:
      id time a b  xb            z
  1:   1    1 1 1   1    0.3735462
  2:   1    2 1 1   2    2.7470924
  3:   1    3 1 1   3    8.4941848
  4:   1    4 1 1   4   20.9883695
  5:   1    5 1 1   5   46.9767390
 ---                              
496: 100    1 1 1 100    0.3735462
497: 100    2 1 1 200  200.7470924
498: 100    3 1 1 300  701.4941848
499: 100    4 1 1 400 1802.9883695
500: 100    5 1 1 500 4105.9767390

这确实有效,但需要时间:
system.time(smpl[, z := myfun(xb, a, b), by = id])
   user  system elapsed 
 33.646   0.994  34.473

考虑到我的实际数据的大小(超过 200 万次观察),我需要让它更快。我猜do.call(myfun, .SD), .SDcols = c('xb', 'a', 'b')by = .(id, time)会更快,避免 myfun 中的 for 循环.但是,我不确定如何更新 b及其滞后(可能使用 shift )当我在 data.table 中运行此逐行操作时.有什么建议么?

最佳答案

好问题!

从一个新的 R session 开始,显示 500 万行的演示数据,这是问题中的函数和我笔记本电脑上的时间。内联一些评论。

require(data.table)   # v1.10.0
n_smpl = 1e6
ni = 5
id = rep(1:n_smpl, each = ni)
smpl = data.table(id)
smpl[, time := 1:.N, by = id]
a_init = 1; b_init = 1
smpl[, ':=' (a = a_init, b = b_init)]
smpl[, xb := (1:.N)*id, by = id]

myfun = function (xb, a, b) {

  z = NULL
  # initializes a new length-0 variable

  for (t in 1:length(xb)) {

      if (t >= 2) { a[t] = b[t-1] + xb[t] }
      # if() on every iteration. t==1 could be done before loop

      z[t] = rnorm(1, mean = a[t])
      # z vector is grown by 1 item, each time

      b[t] = a[t] + z[t]
      # assigns to all of b vector when only really b[t-1] is
      # needed on the next iteration 
  }
  return(z)
}

set.seed(1); system.time(smpl[, z := myfun(xb, a, b), by = id][])
   user  system elapsed 
 19.216   0.004  19.212

smpl
              id time a b      xb            z
      1:       1    1 1 1       1 3.735462e-01
      2:       1    2 1 1       2 3.557190e+00
      3:       1    3 1 1       3 9.095107e+00
      4:       1    4 1 1       4 2.462112e+01
      5:       1    5 1 1       5 5.297647e+01
     ---                                      
4999996: 1000000    1 1 1 1000000 1.618913e+00
4999997: 1000000    2 1 1 2000000 2.000000e+06
4999998: 1000000    3 1 1 3000000 7.000003e+06
4999999: 1000000    4 1 1 4000000 1.800001e+07
5000000: 1000000    5 1 1 5000000 4.100001e+07

所以 19.2s 是时候打败了。在所有这些时间里,我已经在本地运行了 3 次命令以确保它是一个稳定的时间。在这个任务中时间差异是微不足道的,所以我只报告一个时间来让答案更快地阅读。

处理 myfun() 中的上述内联评论:
myfun2 = function (xb, a, b) {

  z = numeric(length(xb))
  # allocate size up front rather than growing

  z[1] = rnorm(1, mean=a[1])
  prevb = a[1]+z[1]
  t = 2L
  while(t<=length(xb)) {
    at = prevb + xb[t]
    z[t] = rnorm(1, mean=at)
    prevb = at + z[t]
    t = t+1L
  }
  return(z)
}
set.seed(1); system.time(smpl[, z2 := myfun2(xb, a, b), by = id][])
   user  system elapsed 
 13.212   0.036  13.245 
smpl[,identical(z,z2)]
[1] TRUE

这相当不错(从 19.2 秒降至 13.2 秒),但它仍然是 for在 R 级循环。乍一看,它不能被矢量化,因为 rnorm() call 取决于先前的值。实际上,它可能可以通过使用 m+sd*rnorm(mean=0,sd=1) == rnorm(mean=m, sd=sd) 的属性进行矢量化。并调用矢量化 rnorm(n=5e6)一次而不是 5e6 次。但可能会有 cumsum()参与与团体打交道。所以我们不要去那里,因为这可能会使代码更难阅读,并且会针对这个精确的问题。

因此,让我们尝试 Rcpp,它看起来与您编写的样式非常相似,并且适用范围更广:
require(Rcpp)   # v0.12.8
cppFunction(
'NumericVector myfun3(IntegerVector xb, NumericVector a, NumericVector b) {
  NumericVector z = NumericVector(xb.length());
  z[0] = R::rnorm(/*mean=*/ a[0], /*sd=*/ 1);
  double prevb = a[0]+z[0];
  int t = 1;
  while (t<xb.length()) {
    double at = prevb + xb[t];
    z[t] = R::rnorm(at, 1);
    prevb = at + z[t];
    t++;
  }
  return z;
}')

set.seed(1); system.time(smpl[, z3 := myfun3(xb, a, b), by = id][])
   user  system elapsed 
  1.800   0.020   1.819 
smpl[,identical(z,z3)]
[1] TRUE

好多了: 19.2s 降至 1.8s .但是每次调用该函数都会调用第一行( NumericVector() ),它会根据组中的行数分配一个新向量。然后将其填写并返回,将其复制到该组正确位置的最后一列(由 := ),仅用于发布。所有这 100 万个小型临时向量(每组一个)的分配和管理都有些复杂。

为什么我们不一口气完成整个专栏?您已经以 for 循环样式编写了它,这没有任何问题。让我们调整 C 函数以接受 id列并添加 if当它到达一个新的组时。
cppFunction(
'NumericVector myfun4(IntegerVector id, IntegerVector xb, NumericVector a, NumericVector b) {

  // ** id must be pre-grouped, such as via setkey(DT,id) **

  NumericVector z = NumericVector(id.length());
  int previd = id[0]-1;  // initialize to anything different than id[0]
  for (int i=0; i<id.length(); i++) {
    double prevb;
    if (id[i]!=previd) {
      // first row of new group
      z[i] = R::rnorm(a[i], 1);
      prevb = a[i]+z[i];
      previd = id[i];
    } else {
      // 2nd row of group onwards
      double at = prevb + xb[i];
      z[i] = R::rnorm(at, 1);
      prevb = at + z[i];
    }
  }
  return z;
}')

system.time(setkey(smpl,id))  # ensure grouped by id
   user  system elapsed
  0.028   0.004   0.033
set.seed(1); system.time(smpl[, z4 := myfun4(id, xb, a, b)][])
   user  system elapsed 
  0.232   0.004   0.237 
smpl[,identical(z,z4)]
[1] TRUE

这样更好: 19.2s 降至 0.27s .

关于r - 数据表 |组内更快的逐行递归更新,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41886507/

相关文章:

r - 条件 data.table 与 .EACHI 合并

r - ggplot2:按颜色和线型区分线条

r - 使用 R 对具有不同行数的多个矩阵进行学生 t 检验(配对)

Java递归按值/引用传递

Haskell:递归地将十六进制字符串转换为整数?

php - 带 <ul><li> 的递归 PHP 函数

R:向 match.call 输出添加参数

r - 如何在R Shiny 中实现重置按钮?

r - 如何比浅拷贝更快地通过引用进行更新?

在 R 数据表中运行总和