r - 有没有更快的方法来检查列表中的列表是否相等?

标签 r list algorithm performance if-statement

这里我有四个不同分区的整数 1:7,即 {1}、{2,3,4}、{5,6} 和 {7},这些分区是写成列表,即 list(1,c(2,3,4),c(5,6),7)。我将分区视为集合,这样一个分区内元素的不同排列应被识别为相同的排列。例如,list(1,c(2,3,4),c(5,6),7)list(7,1,c(2,3,4) ,c(6,5)) 是等效的。

请注意,列表中的元素没有重复,例如,没有 list(c(1,2),c(2,1),c(1,2) ),因为这个问题正在讨论整个集合上的独占分区。

我在列表 lst 中列出了一些不同的排列,如下所示

lst <- list(list(1,c(2,3,4),c(5,6),7),
            list(c(2,3,4),1,7,c(5,6)),
            list(1,c(2,3,4),7,c(6,5)),
            list(7,1,c(3,2,4),c(5,6)))

我想做的是验证所有排列都是等价的。如果是,那么我们得到结果 TRUE

到目前为止,我所做的是对每个分区内的元素进行排序,并将 setdiff()interset()union() 一起使用> 来判断(见下面我的代码)

s <- Map(function(v) Map(sort,v),lst)
equivalent <- length(setdiff(Reduce(union,s),Reduce(intersect,s),))==0

但是,我猜当分区大小扩大时,这种方法会很慢。有没有更快的方法来制作它?提前致谢!

  • 一些测试用例(小数据)
# should return `TRUE`
lst1 <- list(list(1,c(2,3,4),c(5,6)),
            list(c(2,3,4),1,c(5,6)),
            list(1,c(2,3,4),c(6,5)))

# should return `TRUE`
lst2 <- list(list(1:2, 3:4), list(3:4, 1:2))

# should return `FALSE`
lst3 <- list(list(1,c(2,3,4),c(5,6)), list(c(2,3,4),1,c(5,6)), list(1,c(2,3,5),c(6,4)))

最佳答案

如果没有包含 的解决方案,有关 R 和任何 fast 变体的文章就不完整。

为了最大限度地提高效率,选择正确的数据结构至关重要。我们的数据结构需要存储唯一的值,并且还需要快速插入/访问。这正是 std::unordered_set 所体现的。我们只需要确定如何唯一地标识无序整数的每个向量

输入 Fundamental Theorem of Arithmetic

FTA 规定,每个数字都可以通过素数的乘积唯一地表示(最多可达因子的顺序)。

这里是一个示例,演示了我们如何使用 FTA 快速破译两个向量是否在阶上等价(注意,下面的 P 是素数列表... (2 、3、5、7、11 等):

                   Maps to                    Maps to              product
vec1 = (1, 2, 7)    -->>    P[1], P[2], P[7]   --->>   2,  3, 17     -->>   102
vec2 = (7, 3, 1)    -->>    P[7], P[3], P[1]   --->>  17,  5,  2     -->>   170
vec3 = (2, 7, 1)    -->>    P[2], P[7], P[1]   --->>   3, 17,  2     -->>   102

由此,我们看到 vec1vec3 正确映射到相同的数字,而 vec2 映射到不同的值。

由于我们的实际向量可能包含多达一百个小于 1000 的整数,因此应用 FTA 将产生非常大的数字。我们可以利用对数乘积规则来解决这个问题:

logb(xy) = logb(x) + logb(y)

有了这个,我们将能够处理更大的数字示例(这在非常大的示例上开始恶化)。

首先,我们需要一个简单的素数生成器(注意,我们实际上是生成每个素数的对数)。

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::plugins(cpp11)]]

void getNPrimes(std::vector<double> &logPrimes) {
    
    const int n = logPrimes.size();
    const int limit = static_cast<int>(2.0 * static_cast<double>(n) * std::log(n));
    std::vector<bool> sieve(limit + 1, true);
    
    int lastP = 3;
    const int fsqr = std::sqrt(static_cast<double>(limit));
    
    while (lastP <= fsqr) {
        for (int j = lastP * lastP; j <= limit; j += 2 * lastP)
            sieve[j] = false;
        
        int ind = 2;
        
        for (int k = lastP + 2; !sieve[k]; k += 2)
            ind += 2;
        
        lastP += ind;
    }
    
    logPrimes[0] = std::log(2.0);
    
    for (int i = 3, j = 1; i <= limit && j < n; i += 2)
        if (sieve[i])
            logPrimes[j++] = std::log(static_cast<double>(i));
}

这是主要实现:

// [[Rcpp::export]]
bool f_Rcpp_Hash(List x) {
    
    List tempLst = x[0];
    const int n = tempLst.length();
    int myMax = 0;

    // Find the max so we know how many primes to generate
    for (int i = 0; i < n; ++i) {
        IntegerVector v = tempLst[i];
        const int tempMax = *std::max_element(v.cbegin(), v.cend());
        
        if (tempMax > myMax)
            myMax = tempMax;
    }
    
    std::vector<double> logPrimes(myMax + 1, 0.0);
    getNPrimes(logPrimes);
    double sumMax = 0.0;
    
    for (int i = 0; i < n; ++i) {
        IntegerVector v = tempLst[i];
        double mySum = 0.0;
        
        for (auto j: v)
            mySum += logPrimes[j];
        
        if (mySum > sumMax)
            sumMax = mySum;
    }

    // Since all of the sums will be double values and we want to
    // ensure that they are compared with scrutiny, we multiply
    // each sum by a very large integer to bring the decimals to
    // the right of the zero and then convert them to an integer.
    // E.g. Using the example above v1 = (1, 2, 7) & v2 = (7, 3, 1)
    //              
    //    sum of log of primes for v1 = log(2) + log(3) + log(17)
    //                               ~= 4.62497281328427
    //
    //    sum of log of primes for v2 = log(17) + log(5) + log(2)
    //                               ~= 5.13579843705026
    //    
    //    multiplier = floor(.Machine$integer.max / 5.13579843705026)
    //    [1] 418140173
    //    
    // Now, we multiply each sum and convert to an integer
    //    
    //    as.integer(4.62497281328427 * 418140173)
    //    [1] 1933886932    <<--   This is the key for v1
    //
    //    as.integer(5.13579843705026 * 418140173)
    //    [1] 2147483646    <<--   This is the key for v2
    
    const uint64_t multiplier = std::numeric_limits<int>::max() / sumMax;
    std::unordered_set<uint64_t> canon;
    canon.reserve(n);
    
    for (int i = 0; i < n; ++i) {
        IntegerVector v = tempLst[i];
        double mySum = 0.0;
        
        for (auto j: v)
            mySum += logPrimes[j];
        
        canon.insert(static_cast<uint64_t>(multiplier * mySum));
    }
    
    const auto myEnd = canon.end();
    
    for (auto it = x.begin() + 1; it != x.end(); ++it) {
        List tempLst = *it;
        
        if (tempLst.length() != n)
            return false;

        for (int j = 0; j < n; ++j) {
            IntegerVector v = tempLst[j];
            double mySum = 0.0;
            
            for (auto k: v)
                mySum += logPrimes[k];
            
            const uint64_t key = static_cast<uint64_t>(multiplier * mySum);
            
            if (canon.find(key) == myEnd)
                return false;
        }
    }
    
    return true;
}

以下是应用于 @GKi 给出的 lst1、lst2、lst3 和 lst(大的) 时的结果。

f_Rcpp_Hash(lst)
[1] TRUE

f_Rcpp_Hash(lst1)
[1] TRUE

f_Rcpp_Hash(lst2)
[1] FALSE

f_Rcpp_Hash(lst3)
[1] FALSE

以下是一些将 units 参数设置为 relative 的基准测试。

microbenchmark(check = 'equal', times = 10
               , unit = "relative"
               , f_ThomsIsCoding(lst3)
               , f_chinsoon12(lst3)
               , f_GKi_6a(lst3)
               , f_GKi_6b(lst3)
               , f_Rcpp_Hash(lst3))
Unit: relative
                 expr       min        lq      mean    median        uq       max neval
f_ThomsIsCoding(lst3) 84.882393 63.541468 55.741646 57.894564 56.732118 33.142979    10
   f_chinsoon12(lst3) 31.984571 24.320220 22.148787 22.393368 23.599284 15.211029    10
       f_GKi_6a(lst3)  7.207269  5.978577  5.431342  5.761809  5.852944  3.439283    10
       f_GKi_6b(lst3)  7.399280  5.751190  6.350720  5.484894  5.893290  8.035091    10
    f_Rcpp_Hash(lst3)  1.000000  1.000000  1.000000  1.000000  1.000000  1.000000    10


microbenchmark(check = 'equal', times = 10
               , unit = "relative"
               , f_ThomsIsCoding(lst)
               , f_chinsoon12(lst)
               , f_GKi_6a(lst)
               , f_GKi_6b(lst)
               , f_Rcpp_Hash(lst))
Unit: relative
                expr        min         lq       mean     median        uq       max neval
f_ThomsIsCoding(lst) 199.776328 202.318938 142.909407 209.422530 91.753335 85.090838    10
   f_chinsoon12(lst)   9.542780   8.983248   6.755171   9.766027  4.903246  3.834358    10
       f_GKi_6a(lst)   3.169508   3.158366   2.555443   3.731292  1.902140  1.649982    10
       f_GKi_6b(lst)   2.992992   2.943981   2.019393   3.046393  1.315166  1.069585    10
    f_Rcpp_Hash(lst)   1.000000   1.000000   1.000000   1.000000  1.000000  1.000000    10

比较大示例中最快的解决方案快大约3 倍

What does this mean?

对我来说,这个结果充分说明了 @GKi、@chinsoon12、@Gregor、@ThomasIsCoding 等人所展示的 base R 的美丽和效率。我们编写了大约 100 行非常具体的 C++ 代码以获得适度的加速。公平地说,基本 R 解决方案最终会调用大部分已编译的代码,并最终使用哈希表,就像我们上面所做的那样。

关于r - 有没有更快的方法来检查列表中的列表是否相等?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59270225/

相关文章:

r - 从数据框中的点创建 shapefile

linux - 使用 Shell 脚本自动安装 R-Studio

python - 使用递归检查两个未排序列表的元素和长度是否相等

java - 在 Java 中按 y 值对坐标进行排序的最简单方法?

algorithm - booth算法的本质是什么?

R Shiny : Change tabs from within a module

r - 使用与列同名的变量对 data.table 进行子集化

python - 如何使用 Python 给出列表中每个数字元素的相反数

javascript - 我可以从嵌套数组中删除冗余数据结构包装器到嵌套 json 脚本吗?

algorithm - 计算两个整数序列之间的 Kendall Tau 距离的快速算法