r - 更快地查找侧翼非 NA 值的索引

标签 r data.table

这是一个速度优化问题。

这是我的示例数据。真实数据有超过 100k 行和 >300 列。

library(data.table)
dt <- data.table(ref=1:20, tgt1=11:30, tgt2=21:40)
dt[c(3,8,9,15,16,17), "tgt1"] = NA
dt[c(4,5,15,17), "tgt2"] = NA
dt

#>     ref tgt1 tgt2
#>  1:   1   11   21
#>  2:   2   12   22
#>  3:   3   NA   23
#>  4:   4   14   NA
#>  5:   5   15   NA
#>  6:   6   16   26
#>  7:   7   17   27
#>  8:   8   NA   28
#>  9:   9   NA   29
#> 10:  10   20   30
#> 11:  11   21   31
#> 12:  12   22   32
#> 13:  13   23   33
#> 14:  14   24   34
#> 15:  15   NA   NA
#> 16:  16   NA   36
#> 17:  17   NA   NA
#> 18:  18   28   38
#> 19:  19   29   39
#> 20:  20   30   40

有些列在某些位置有 NA,我的目标是获取最接近的非 NA 侧翼值的位置。例如,对于第二列 tgt1,我使用以下代码

tgt = dt[, tgt1]
tgt.na = which(is.na(tgt))
tgt.non.na = which(!is.na(tgt))
start = sapply(tgt.na, function(x) max(tgt.non.na[tgt.non.na < x]))
stop = sapply(tgt.na, function(x) min(tgt.non.na[tgt.non.na > x]))
data.frame(start, stop)

#>   start stop
#> 1     2    4
#> 2     7   10
#> 3     7   10
#> 4    14   18
#> 5    14   18
#> 6    14   18

在这里,对于tgt1列,我得到了我想要的。例如,对于第 3 行的 NA,最接近的侧翼非 NA 值位于 2 和 4,对于其他行依此类推。我的问题是 sapply 非常慢。想象一下运行超过 300 列和 100k 行。以目前的形式,需要几个小时才能完成。最终,当找到这些位置时,它们将用于索引 ref 列中的值,以计算 tgt1 等列中的缺失值。但这是另一个话题了。

有什么办法可以让它更快吗?任何 data.table 方式的解决方案。

编辑:所有出色的解决方案,这是我的基准测试,您可以看到与我原来的 sapply 相比,所有建议的方法都运行得快如闪电。我选择 lapply,不仅因为它是最快的,而且还因为它与我当前的代码语法非常吻合。

Unit: milliseconds
           expr         min          lq        mean      median          uq         max neval
         sapply 3755.118949 3787.288609 3850.322669 3819.458269 3897.924530 3976.390790     3
 dt.thelatemail    9.145551    9.920238   10.242885   10.694925   10.791552   10.888180     3
  lapply.andrew    2.626525    3.038480    3.446682    3.450434    3.856760    4.263086     3
   zoo.chinsoon    6.457849    6.578099    6.629839    6.698349    6.715834    6.733318     3

最佳答案

这是使用 rle 的基本 R 替代方案。我使用了 lapply ,因为我不确定你想如何保存所有输出数据帧。希望这可以帮助!

dt <- data.table(ref=1:20, tgt1=11:30, tgt2=21:40)
dt[c(3,8,9,15,16,17), "tgt1"] = NA
dt[c(4,5,15,17), "tgt2"] = NA


lapply(dt[,-1], function(x) {
  na_loc <- which(is.na(x))
  rle_x <- rle(is.na(x))
  reps <- rle_x$lengths[rle_x$values == T]

  start <- na_loc - 1
  start <- start[!start %in% na_loc]
  end <- na_loc + 1
  end <- end[!end %in% na_loc]

  data.frame(start = rep(start, reps),
             end = rep(end, reps))
})

$tgt1
   start end
1:     2   4
2:     7  10
3:     7  10
4:    14  18
5:    14  18
6:    14  18

$tgt2
   start end
1:     3   6
2:     3   6
3:    14  16
4:    16  18

对于包含 300 列的示例数据帧,它在我的笔记本电脑上也能很好地扩展:

df1 <- data.frame(ref = 1:1e5)
df1[paste0("tgt", 1:300)] <- replicate(300, sample(c(1:50, rep(NA, 5)), 1e5, replace = T))

microbenchmark::microbenchmark(
  base = {
    lapply(df1[,-1], function(x) {
      na_loc <- which(is.na(x))
      rle <- rle(is.na(x))
      reps <- rle$lengths[rle$values == T]

      start <- na_loc - 1
      start <- start[!start %in% na_loc]
      end <- na_loc + 1
      end <- end[!end %in% na_loc]

      data.frame(start = rep(start, reps),
                 end = rep(end, reps))
    }
  )},
  times = 5
)

Unit: seconds
 expr      min       lq     mean   median       uq      max neval
 base 1.863319 1.888617 1.897651 1.892166 1.898196 1.945954     5

关于r - 更快地查找侧翼非 NA 值的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55736576/

相关文章:

r - 列槽不足

r - Data.table 和 get() 命令 (R)

r - 通过 R 中的 IV 值对连续变量进行分箱

mysql - 如何使用sparklyr spark_write_jdbc连接MySql

r - 选择表示两个位置之间范围的行,以便仅包括至少包含另一个表的一个位置的间隔

r - data.table/data.frame rbind 无法正常工作

r - 基准测试 data.frame (base)、data.frame(package dataframe) 和 data.table

r - 在 R 中使用过滤器功能。需要为赛马数据库分配 NA 并保持数据集的长度相同

r - 使用代码范围转换 data.table

r - 使用 Reduce/do.call 和 ifelse