r - 如何优化矩阵中行和列的交集?

标签 r algorithm performance intersection

在矩阵中,例如M1,行是国家,列是年份。这些国家没有同一年的观测资料。我想找到给我最多国家的“最佳”年份交叉点。最低年限和最低国家/地区的数量将预先确定。结果中包括哪些国家并不重要,年份不必是连续的。

> M1
      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14] [,15]
 [1,]   NA   NA   NA 2004   NA 2006   NA 2008 2009    NA  2011  2012    NA    NA    NA
 [2,]   NA 2002   NA 2004   NA   NA 2007   NA   NA  2010  2011    NA  2013  2014    NA
 [3,]   NA   NA   NA 2004 2005 2006 2007 2008 2009    NA    NA  2012  2013    NA  2015
 [4,]   NA 2002   NA 2004 2005 2006 2007 2008   NA  2010  2011    NA  2013    NA    NA
 [5,] 2001   NA   NA   NA 2005 2006 2007 2008   NA  2010    NA  2012  2013  2014    NA
 [6,] 2001   NA 2003 2004 2005 2006 2007 2008 2009  2010  2011  2012    NA  2014    NA
 [7,] 2001 2002   NA   NA 2005   NA 2007   NA 2009    NA  2011    NA    NA  2014  2015
 [8,] 2001 2002   NA 2004 2005 2006   NA   NA   NA  2010    NA    NA  2013    NA  2015
 [9,]   NA 2002   NA 2004 2005   NA 2007   NA   NA  2010  2011    NA    NA    NA    NA
[10,] 2001 2002   NA 2004   NA   NA   NA   NA   NA  2010    NA  2012    NA  2014  2015

因为没有明显的相交,所以一次 Reduce(intersect...) 尝试是行不通的,我通过连续排除一个国家达到定义的阈值 来重复这样做n.行。过滤结果至少为 n.col 年。我写了这个函数,

findBestIntersect <- function(M, min.row=5, min.col=3) {
  ## min.row: minimum number of rows (countries) to analyze
  ## min.col: minimum number of complete columns (years)
  # put matrices with row combn into list (HUGE!)
  L1 <- lapply(min.row:(nrow(M) - 1), function(x)
    combn(nrow(M), x, function(i) M[i, ], simplify=FALSE))
  # select lists w/ def. number of complete columns
  slc <- sapply(L1, function(y)  # numbers of lists
    which(sapply(y, function(x)
      sum(!(apply(x, 2, function(i) any(is.na(i))))))
      >= min.col))
  # list selected lists
  L2 <- Map(function(x, i)
    x[i], L1[lengths(slc) > 0], slc[lengths(slc) > 0])
  # find intersects
  L3 <- rapply(L2, function(l)
    as.integer(na.omit(Reduce(intersect, as.list(as.data.frame(t(l)))))),
    how="list")
  return(unique(unlist(L3, recursive=FALSE)))
}

它立即为我提供了 M1 的预期结果。

> system.time(best.yrs.1 <- findBestIntersect(M1))
   user  system elapsed 
   0.06    0.00    0.07 

> best.yrs.1
[[1]]
[1] 2002 2004 2010

但是 M2 的性能只是勉强可以接受(RAM 使用量约为 1.1 GB),

> system.time(best.yrs.2 <- findBestIntersect(M2))
   user  system elapsed 
  79.90    0.39   82.76 
> head(best.yrs.2, 3)
[[1]]
[1] 2002 2009 2015

[[2]]
[1] 2002 2014 2015

[[3]]
[1] 2003 2009 2010

并且您不想使用类似于我的真实矩阵的 M3(爆炸 32 GB RAM)来尝试此操作:

# best.yrs.3 <- findBestIntersect(M3)

可能该函数最大的缺陷是 L1 变得太大非常快。

所以,我的问题是,是否有更好的方法也适用于M3? “奖金”将最大化国家和年份。如果可能的话,我想在没有额外包的情况下做到这一点。

数据

set.seed(42)
tf <- matrix(sample(c(TRUE, FALSE), 150, replace=TRUE), 10)
M1 <- t(replicate(10, 2001:2015, simplify=TRUE))
M1[tf] <- NA

tf <- matrix(sample(c(TRUE, FALSE), 300, replace=TRUE), 20)
M2 <- t(replicate(20, 2001:2015, simplify=TRUE))
M2[tf] <- NA

tf <- matrix(sample(c(TRUE, FALSE), 1488, replace=TRUE), 31)
M3 <- t(replicate(31, 1969:2016, simplify=TRUE))
M3[tf] <- NA

最佳答案

我编写了一个 coded_best_intersect 函数,该函数依赖于在 code_maker 函数中动态创建 for 循环。它在 30 秒内评估 M3。因为代码生成了一个列表,所以我依赖于 data.table 来获取 rbindlist 和 print 方法。

library(data.table)

code_maker 函数:

code_maker <- function(non_NA_M, n, k, min.col) {
  ## initializing for results
  res <- list()
  z <- 1
  ## initializing naming
  col_names <- colnames(non_NA_M)
  i_s <- paste0('i', seq_len(k))
  ## create the foor loop text. It looks like this mostly
  ## for (i1 in 1:(n - k + 1)) { for (i2 in (i1 + 1):(n-k+2)) {}}
  for_loop <- paste0('for (', i_s, ' in ', c('1:', paste0('(', i_s[-k], ' + 1):')), 
                     n - k + seq_len(k), ')', ' {\n non_na_sums', seq_len(k), 
                     '=non_NA_M[', i_s, ', ] ',
                     c('', paste0('& ', rep('non_na_sums', k - 1), seq_len(k)[-k])), '', 
                     '\n if (sum(non_na_sums', seq_len(k), ') < ', min.col, ') {next} ', 
                     collapse='\n')
  ## create the assignment back to the results which looks like
  ## res[[z]] <- data.table(M=k, N=sum(non_na_sumsk), ROWS=list(c(i1, i2, ..., ik)), 
  ##                        YEARS=list(col_names[non_na_sumsk]))
  inner_text <- paste0('\nres[[z]] <- data.table(M=k, N=sum(non_na_sums',
                       k, '), ROWS=list(c( ', paste0(i_s, collapse=', '), 
                       ')), YEARS=list(col_names[non_na_sums', k , ']))\nz <- z + 1')
  ## combines the loop parts and closes the for with }}}
  for_loop <- paste(for_loop, 
                    inner_text, 
                    paste0(rep('}', k), collapse=''))
  ## evaluate - the evaluation will assign back to res[[i]]  
  eval(parse(text=for_loop))
  res <- rbindlist(res)
  if (length(res) == 0) { #to return emtpy data.table with the correct fields
    return(data.table(M=integer(), N=integer(), ROWS=list(), YEARS=list()))
  }
  res$M <- k
  return(res)
}

coded_best_intersect 函数:

coded_best_intersect <- function(M, min.row=5, min.col=3) {
  colnames(M) <- apply(M, 2, function(x) na.omit(x)[1])
  n_row <- nrow(M)
  non_NA <- !is.na(M)
  n_combos <- min.row:(n_row - 1)
  res2 <- list()
  for (i in seq_along(n_combos)) {
    res2[[i]] <- code_maker(non_NA, n=n_row, k=n_combos[i], min.col)
    if (nrow(res2[[i]]) == 0) {
      break
    }
  }
  return(res2)
}

这是例如为 k=5 即时生成的代码:

# for (i1 in 1:5) {
#   non_na_sums1=non_NA_M[i1, ] 
#   if (sum(non_na_sums1) < 3) {next} 
#   for (i2 in (i1 + 1):6) {
#     non_na_sums2=non_NA_M[i2, ] & non_na_sums1
#     if (sum(non_na_sums2) < 3) {next} 
#     for (i3 in (i2 + 1):7) {
#       non_na_sums3=non_NA_M[i3, ] & non_na_sums2
#       if (sum(non_na_sums3) < 3) {next} 
#       for (i4 in (i3 + 1):8) {
#         non_na_sums4=non_NA_M[i4, ] & non_na_sums3
#         if (sum(non_na_sums4) < 3) {next} 
#         for (i5 in (i4 + 1):9) {
#           non_na_sums5=non_NA_M[i5, ] & non_na_sums4
#           if (sum(non_na_sums5) < 3) {next} 
#           for (i6 in (i5 + 1):10) {
#             non_na_sums6=non_NA_M[i6, ] & non_na_sums5
#             if (sum(non_na_sums6) < 3) {next}  
#             res[[z]] <- data.table(M=k, N=sum(non_na_sums6), 
#                                    ROWS=list(c( i1, i2, i3, i4, i5, i6)),
#                                    YEARS=list(col_names[non_na_sums6]))
#             z <- z + 1 }}}}}}

您可能会注意到 {next},这是一种在无法获得最少 3 列的情况下跳过组合的方法。虽然看起来都是硬编码的,但代码实际上是生成、解析然后求值的字符串。

使用和性能

矩阵M1:

system.time(final1 <- coded_best_intersect(M1))
   user  system elapsed 
      0       0       0 
data.table::rbindlist(final1)[order(-M*N)]
   M N           ROWS          YEARS
1: 5 3  2, 4, 8, 9,10 2002,2004,2010

矩阵M2:

system.time(final2 <- coded_best_intersect(M2))
   user  system elapsed 
   0.08    0.00    0.08 
data.table::rbindlist(final2)[order(-M*N)]
     M N                  ROWS               YEARS
  1: 7 3  6, 8,11,12,13,16,...      2002,2012,2013
  2: 5 4         6, 8,13,16,17 2002,2012,2013,2015
  3: 5 4         8,11,12,13,17 2002,2012,2013,2014
  4: 6 3      1, 4, 8,13,17,20      2002,2014,2015
  5: 6 3      2, 5, 6,10,14,17      2003,2006,2008
 ---                                              
126: 5 3        10,12,13,17,20      2002,2008,2014
127: 5 3        10,12,14,17,20      2003,2008,2014
128: 5 3        11,12,13,16,17      2002,2012,2013
129: 5 3        11,12,13,17,20      2002,2012,2014
130: 5 3        12,13,15,16,19      2001,2002,2013

矩阵M3:

system.time(final3 <- coded_best_intersect(M3))
   user  system elapsed 
  29.37    0.05   29.54 
data.table::rbindlist(final3)[order(-M*N)]
       M N              ROWS                             YEARS
    1: 6 7  1, 3, 8,15,20,29 1969,1973,1980,1984,1985,1992,...
    2: 5 8     1, 3, 8,14,29 1969,1973,1976,1980,1984,1987,...
    3: 5 8     1, 3, 8,20,29 1969,1973,1980,1984,1985,1992,...
    4: 5 8     2, 7, 9,13,17 1974,1993,1994,2004,2012,2013,...
    5: 5 8     3, 6, 8, 9,27 1974,1980,1984,1987,1995,1998,...
   ---                                                        
52374: 5 3    23,24,25,30,31                    1979,1997,2002
52375: 5 3    23,25,28,30,31                    1979,1992,2002
52376: 5 3    24,25,26,30,31                    1983,1997,2002
52377: 5 3    24,25,28,30,31                    1979,1983,2002
52378: 5 3    24,26,28,30,31                    1983,1986,2002

要将结果的选定部分放入字符串中,您可以执行例如以下内容:

x <- data.table::rbindlist(final3)[order(-M*N)]
el(x$YEARS[1])  # select `YEARS` of result-row `1:`
# [1] "1969" "1973" "1980" "1984" "1985" "1992" "2003"

注意:请查看其他两种截然不同的方法的编辑历史记录。第一个是 melt 和 join 技术,它会耗尽内存。第二种方法是使用 RcppAlgos::comboGeneral 来评估函数。

关于r - 如何优化矩阵中行和列的交集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57534107/

相关文章:

r - 在 R 中对数据帧的多个向量进行操作然后取平均值

寻找非重叠序列的最大覆盖范围的算法。 (即,加权间隔调度概率。)

algorithm - 求和 - 封闭式 - 从哪里开始

asp.net-mvc - MVC 框架做了什么来避免它从大量使用反射中继承的低性能

R:计算每列满足条件的次数并且行名出现在列表中

r - 包 ‘stringr’ 在 R 4.0.0 之前安装 : please re-install it BiocManager Installation path not writeable, 无法更新包

r - 如何将州和县线添加到 R 中的 geom_raster ggplot?

algorithm - 有效地找到二维网格中最大的周围正方形

java - 为什么 String.equalsIgnoreCase 这么慢

c# - 在 C# 中通过引用传递数组中的结构对性能有何影响