c++ - 根据另一个可变参数包查找可变参数包的收缩

标签 c++ c++11 multidimensional-array template-meta-programming

我正在研究一个静态多维数组收缩框架,我遇到了一个有点难以解释的问题,但我会尽力而为。假设我们有一个 N维数组类

template<typename T, int ... dims>
class Array {}

可以实例化为

Array<double> scalar;
Array<double,4> vector_of_4s;
Array<float,2,3> matrix_of_2_by_3;
// and so on

现在我们有另一个类叫做 Indices

template<int ... Idx>
struct Indices {}

我有一个函数 contraction现在谁的签名应该如下所示

template<T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
Array<T,apply_to_dims<Dims...,do_contract<Idx...>>> 
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a)

我可能没有得到正确的语法,但我基本上想要返回的 Array具有基于 Indices 条目的维度.让我举例说明什么是contraction可以执行。请注意,在此上下文中,收缩 表示删除索引列表中参数相等的维度

auto arr = contraction(Indices<0,0>, Array<double,3,3>) 
// arr is Array<double> as both indices contract 0==0

auto arr = contraction(Indices<0,1>, Array<double,3,3>) 
// arr is Array<double,3,3> as no contraction happens here, 0!=1

auto arr = contraction(Indices<0,1,0>, Array<double,3,4,3>) 
// arr is Array<double,4> as 1st and 3rd indices contract 0==0  

auto arr = contraction(Indices<0,1,0,7,7,2>, Array<double,3,4,3,5,5,6>) 
// arr is Array<double,4,6> as (1st and 3rd, 0==0) and (4th and 5th, 7==7) indices contract

auto arr = contraction(Indices<10,10,2,3>, Array<double,5,6,4,4>
// should not compile as contraction between 1st and 2nd arguments 
// requested but dimensions don't match 5!=6

// The parameters of Indices really do not matter as long as 
// we can identify contractions. They are typically expressed as enums, I,J,K...

所以本质上,给定 Idx...Dims...两者的大小应该相等,请检查 Idx... 中的哪些值相等,获取它们出现的位置并删除 Dims... 中的相应条目(位置) .这本质上是一个 tensor contraction rule .

数组收缩规则:

  1. 索引的参数个数和数组的维度/秩应该相同,即sizeof...(Idx)==sizeof...(Dims)
  2. Idx之间一对一对应和 Dims即如果我们有 Indices<0,1,2>Array<double,4,5,6> , 0映射到 4 , 1映射到 52映射到 6 .
  3. 如果 Idx 中有相同/相等的值,这意味着收缩,意味着Dims中的相应尺寸应该消失,例如,如果我们有 Indices<0,0,3>Array<double,4,4,6> , 然后 0==0以及这些值映射到的相应维度是 44两者都需要消失,结果数组应该是 Array<double,6>
  4. 如果Idx具有相同的值,但相应的 Dims不匹配,则应触发编译时错误,例如 Indices<0,0,3>Array<double,4,5,6>不可能因为 4!=5 , 同样 Indices<0,1,0>不可能因为 4!=6 , 这导致
  5. 对于不同维度的数组,不可能进行收缩,例如Array<double,4,5,6>不能以任何方式签约。
  6. 允许多对、三胞胎、四胞胎等 Idx只要对应Dims也匹配,例如 Indices<0,0,0,0,1,1,4,3,3,7,7,7>将与 Array<double,6> 签约,假设输入数组是 Array<double,2,2,2,2,3,3,6,2,2,3,3,3> .

我对元编程的了解不足以实现此功能,但我希望我已经明确说明了意图,以便有人指导我朝着正确的方向前进。

最佳答案

一堆 constexpr 函数做实际的检查:

// is ind[i] unique in ind?
template<size_t N>
constexpr bool is_uniq(const int (&ind)[N], size_t i, size_t cur = 0){
    return cur == N ? true : 
           (cur == i || ind[cur] != ind[i]) ? is_uniq(ind, i, cur + 1) : false;
}

// For every i where ind[i] == index, is dim[i] == dimension?
template<size_t N>
constexpr bool check_all_eq(int index, int dimension,
                            const int (&ind)[N], const int (&dim)[N], size_t cur = 0) {
    return cur == N ? true :
           (ind[cur] != index || dim[cur] == dimension) ? 
                check_all_eq(index, dimension, ind, dim, cur + 1) : false;
}

// if position i should be contracted away, return -1, otherwise return dim[i].
// triggers a compile-time error when used in a constant expression on mismatch.
template<size_t N>
constexpr int calc(size_t i, const int (&ind)[N], const int (&dim)[N]){
    return is_uniq(ind, i) ? dim[i] :
           check_all_eq(ind[i], dim[i], ind, dim) ? -1 : throw "dimension mismatch";
}

现在我们需要一种方法来摆脱 -1:

template<class Ind, class... Inds>
struct concat { using type = Ind; };
template<int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
    :  concat<Indices<I1..., I2...>, Inds...> {};

// filter out all instances of I from Is...,
// return the rest as an Indices    
template<int I, int... Is>
struct filter
    :  concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};

使用它们:

template<class Ind, class Arr, class Seq>
struct contraction_impl;

template<class T, int... Ind, int... Dim, size_t... Seq>
struct contraction_impl<Indices<Ind...>, Array<T, Dim...>, std::index_sequence<Seq...>>{
    static constexpr int ind[] = { Ind... };
    static constexpr int dim[] = { Dim... };
    static constexpr int result[] = {calc(Seq, ind, dim)...};

    template<int... Dims>
    static auto unpack_helper(Indices<Dims...>) -> Array<T, Dims...>;

    using type = decltype(unpack_helper(typename filter<-1,  result[Seq]...>::type{}));
};


template<class T, int ... Dims, int ... Idx, 
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
typename contraction_impl<Indices<Idx...>, Array<T,Dims...>, 
                          std::make_index_sequence<sizeof...(Dims)>>::type
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);

除了 make_index_sequence 之外的所有内容都是 C++11。您可以在 SO 上找到它的大量实现。

关于c++ - 根据另一个可变参数包查找可变参数包的收缩,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37276904/

相关文章:

c++ - OMP 并行还原

c++ - 有状态 lambda 问题 - Microsoft 编译器版本 19.16.27024.1

c++ - 如何创建一个可以二进制读/写大块的 std::vector-like 类?

arrays - 不同长度的多维数组

javascript - 数组的数组?

java - 如何用Java创建系统列表?

c++ - 我在尝试计算涉及数组的总和和平均值时遇到问题

c++ - 设置 SO_RCVBUF 会减小窗口缩放因子

c++ - 头文件中的重复类声明

c++ - 哪种重载组合性能最高?