r - 性能提升

标签 r data.table

我正在尝试获取基于滞后/转发的函数应用。我广泛使用 data.table 并且我什至有工作代码,但是知道 data.table 的强大功能我认为必须有一种更简单的方法来实现相同的目标,并可能改进性能(我在函数内创建了大量变量)。以下是函数的工作代码(可在 https://gist.github.com/tomaskrehlik/5262087#file-gistfile1-r 中找到)

# Lag-function lags the given variable by the date_variable

lag_variable <- function(data, variable, lags, date_variable = c("Date")) {
    if (lags == 0) {
      return(data)
    }
    if (lags>0) {
      name <- "lag"
    } else {
      name <- "forward"
    }
    require(data.table)
    setkeyv(data, date_variable)
    if (lags>0) {
      data[,index:=seq(1:.N)]  
    } else {
      data[,index:=rev(seq(1:.N))]
    }
    setkeyv(data, "index")
    lags <- abs(lags)
    position <- which(names(data)==variable)
    for ( j in 1:lags ) {
      lagname <- paste(variable,"_",name,j,sep="")
      lag <- paste("data[, ",lagname,":=data[list(index-",j,"), ",variable,", roll=TRUE][[",position,"L]]]", sep = "")
      eval(parse( text = lag ))
    }
    setkeyv(data, date_variable)
    data[,index:=NULL]
}

# window_func applies the function to the lagged or forwarded variables created by lag_variable
window_func <- function(data, func.name, variable, direction = "window", steps, date_variable = c("Date"), clean = TRUE) {
    require(data.table)
    require(stringr)
    transform <- match.fun(func.name)
    l <- length(names(data))
    if (direction == "forward") {
      lag_variable(data, variable, -steps, date_variable)
      cols <- which((!(is.na(str_match(names(a), paste(variable,"_forward(",paste(1:steps,collapse="|"),")",sep=""))[,1])))*1==1)
    } else {
      if (direction == "backward") {
        lag_variable(data, variable, steps, date_variable)
        cols <- which((!(is.na(str_match(names(a), paste(variable,"_lag(",paste(1:steps,collapse="|"),")",sep=""))[,1])))*1==1)
      } else {
        if (direction == "window") {
          lag_variable(data, variable, -steps, date_variable)
          lag_variable(data, variable, steps, date_variable)
          cols <- which((!(is.na(str_match(names(a), paste(variable,"_lag(",paste(1:steps,collapse="|"),")",sep=""))[,1])))*1==1)
          cols <- c(cols,which((!(is.na(str_match(names(a), paste(variable,"_forward(",paste(1:steps,collapse="|"),")",sep=""))[,1])))*1==1))
        } else {
          stop("The direction must be either backward, forward or window.")
        }
      }
    }
    data[,transf := apply(data[,cols, with=FALSE], 1, transform)]
    if (clean) {
      data[,cols:=NULL,with=FALSE]
    }
    return(data)
}

# Typical use:
# I have a data.table DT with variables Date (class IDate), value1, value2
# I want to get cumulative sum of next five days
# window_func(DT, "sum", "value1", direction = "forward", steps = 5)

编辑:示例数据可以通过以下方式创建:

a <- data.table(Date = 1:1000, value = rnorm(1000))

对于每个日期(此处为整数,仅作为示例,并不重要),我想创建接下来十个观察值的总和。要运行代码并获取输出,请执行以下操作:

window_func(data = a, func.name = "sum", variable = "value", 
      direction = "forward", steps = 10, date_variable = "Date", clean = TRUE)

该函数首先获取变量并创建十个滞后变量(使用函数lag_variable),然后按列应用函数并在自身之后进行清理。代码臃肿,因为我有时只需要在滞后观察上使用函数,有时在前向观察上使用函数,有时在两者上使用函数,这称为窗口。

有什么建议可以更好地实现这一点吗?我的代码似乎太大了。

最佳答案

我不确定你的函数的其余部分,但你可以相当有效地获得滞后总和,如下所示:

a[ , lagSum := 
       a[, list(sum=sum(value)), by=list(grp=(Date+lag-i) %/% lag)] [grp!=0, sum]
   , by=list(i=Date %% lag)]

例如:

set.seed(1)
a[ , lagSum := 
       a[, list(sum=sum(value)), by=list(grp=(Date+lag-i) %/% lag)] [grp!=0, sum]
   , by=list(i=Date %% lag)]

> a
      Date      value      lagSum
   1:    1 -0.6264538  1.32202781
   2:    2  0.1836433  3.46026279
   3:    3 -0.8356286  3.66646270
   4:    4  1.5952808  3.88085074
   5:    5  0.3295078  0.07087005
  ---                            
 996:  996 -0.3132929 -3.79332038
 997:  997 -0.8806707 -3.48002750
 998:  998 -0.4192869 -2.59935677
 999:  999 -1.4827517 -2.18006988
1000: 1000 -0.6973182 -1.88854602

确认正确的值:

# first n values
n <- 5
for (i in seq(n))
  a[seq(i, length.out=10), print(sum(value))]

#  [1] 1.322028
#  [1] 3.460263
#  [1] 3.666463
#  [1] 3.880851
#  [1] 0.07087005

基准(针对 for 循环,所以不太公平)

set.seed(1)
a <- data.table(Date = 1:1000, value = rnorm(1000))

system.time({    a[ , lagSum := 
           a[, list(sum=sum(value)), by=list(grp=(Date+lag-i) %/% lag)] [grp!=0, sum]
       , by=list(i=Date %% lag)]
})

#  user  system elapsed 
# 0.049   0.001   0.056 



set.seed(1)
a <- data.table(Date = 1:1000, value = rnorm(1000))

system.time({    for (i in seq(nrow(a)-lag+1))
      a[seq(i, length.out=10), lagSum := sum(value)]})

#  user  system elapsed 
# 1.526   0.019   2.220 

关于r - 性能提升,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/15691216/

相关文章:

r - R 中的基本 2D 颜色平滑

r - 根据分组变量将值粘贴在一起

r - 如何删除百分号并使用 r 查找列中某些值的平均值?

r - ffdfdply,R 中的分割和内存限制

R : Getting the sum of columns in a data. 按某列进行帧组

r - 错误 : Case 1 (`.` ) must be a two-sided formula, 不是 `data.frame` 对象

r - 是否有等效于 unlist() 的 S4?

r - 聚合和 DCast

r - 为 R 中的各种 ID 跨多行连接值

r - 对 2 个变量进行分组总和并从总计中减去