r - 寻找跨多维数组实现 logSumExp 的更快方法

标签 r multidimensional-array

我正在编写的一些 R 代码中有一行非常慢。它使用 apply 命令在 4 维数组中应用 logSumExp。我想知道有什么方法可以加快速度!

Reprex:(这可能需要 10 秒或更长时间才能运行)

library(microbenchmark)
library(matrixStats)

array4d <- array( runif(5*500*50*5 ,-1,0),
                  dim = c(5, 500, 50, 5) )
microbenchmark(
    result <- apply(array4d, c(1,2,3), logSumExp)
)

任何建议表示赞赏!

最佳答案

rowSums 是 apply 的一个不太通用的版本,它针对相加时的速度进行了优化,因此可以用来加快计算速度。如果在计算中保持 NANaN 之间的差异很重要,请注意帮助文件 ?rowSums 中的警告。

library(microbenchmark)
library(matrixStats)

array4d <- array( runif(5*500*50*5 ,-1,0),
                  dim = c(5, 500, 50, 5) )
microbenchmark(
  result <- apply(array4d, c(1,2,3), logSumExp),
  result2 <- log(rowSums(exp(array4d), dims=3))
)


# Unit: milliseconds
#                                            expr      min       lq      mean    median        uq      max neval
# result <- apply(array4d, c(1, 2, 3), logSumExp) 249.4757 274.8227 305.24680 297.30245 328.90610 405.5038   100
# result2 <- log(rowSums(exp(array4d), dims = 3))  31.8783  32.7493  35.20605  33.01965  33.45205 133.3257   100

all.equal(result, result2)

#TRUE

这会使我的计算机速度提高 9 倍

关于r - 寻找跨多维数组实现 logSumExp 的更快方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62891781/

相关文章:

c - 尝试使用两个初始化列表初始化二维结构数组时发出警告

r - 如何结合R语言制作交互式图表

R dplyr pivot wider with duplicates 并生成变量名

PHP - 遍历多个表并填充多维数组

c - 下面的代码有什么问题?如何纠正这个问题?

php - 为mysql准备多维数组

Python - 在 n 维列表中查找所有项目 x 的位置

r - 从tar捕获错误并继续处理

R highcharts多堆积条形图

使用knitr编译.rnw文件后自动删除.tex文件