我正在研究一个静态多维数组收缩框架,我遇到了一个有点难以解释的问题,但我会尽力而为。假设我们有一个 N
维数组类
template<typename T, int ... dims>
class Array {}
可以实例化为
Array<double> scalar;
Array<double,4> vector_of_4s;
Array<float,2,3> matrix_of_2_by_3;
// and so on
现在我们有另一个类叫做 Indices
template<int ... Idx>
struct Indices {}
我有一个函数 contraction
现在谁的签名应该如下所示
template<T, int ... Dims, int ... Idx,
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
Array<T,apply_to_dims<Dims...,do_contract<Idx...>>>
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a)
我可能没有得到正确的语法,但我基本上想要返回的 Array
具有基于 Indices
条目的维度.让我举例说明什么是contraction
可以执行。请注意,在此上下文中,收缩 表示删除索引列表中参数相等的维度。
auto arr = contraction(Indices<0,0>, Array<double,3,3>)
// arr is Array<double> as both indices contract 0==0
auto arr = contraction(Indices<0,1>, Array<double,3,3>)
// arr is Array<double,3,3> as no contraction happens here, 0!=1
auto arr = contraction(Indices<0,1,0>, Array<double,3,4,3>)
// arr is Array<double,4> as 1st and 3rd indices contract 0==0
auto arr = contraction(Indices<0,1,0,7,7,2>, Array<double,3,4,3,5,5,6>)
// arr is Array<double,4,6> as (1st and 3rd, 0==0) and (4th and 5th, 7==7) indices contract
auto arr = contraction(Indices<10,10,2,3>, Array<double,5,6,4,4>
// should not compile as contraction between 1st and 2nd arguments
// requested but dimensions don't match 5!=6
// The parameters of Indices really do not matter as long as
// we can identify contractions. They are typically expressed as enums, I,J,K...
所以本质上,给定 Idx...
和 Dims...
两者的大小应该相等,请检查 Idx...
中的哪些值相等,获取它们出现的位置并删除 Dims...
中的相应条目(位置) .这本质上是一个 tensor contraction rule .
数组收缩规则:
- 索引的参数个数和数组的维度/秩应该相同,即
sizeof...(Idx)==sizeof...(Dims)
Idx
之间一对一对应和Dims
即如果我们有Indices<0,1,2>
和Array<double,4,5,6>
,0
映射到4
,1
映射到5
和2
映射到6
.- 如果
Idx
中有相同/相等的值,这意味着收缩,意味着Dims
中的相应尺寸应该消失,例如,如果我们有Indices<0,0,3>
和Array<double,4,4,6>
, 然后0==0
以及这些值映射到的相应维度是4
和4
两者都需要消失,结果数组应该是Array<double,6>
- 如果
Idx
具有相同的值,但相应的Dims
不匹配,则应触发编译时错误,例如Indices<0,0,3>
和Array<double,4,5,6>
不可能因为4!=5
, 同样Indices<0,1,0>
不可能因为4!=6
, 这导致 - 对于不同维度的数组,不可能进行收缩,例如
Array<double,4,5,6>
不能以任何方式签约。 - 允许多对、三胞胎、四胞胎等
Idx
只要对应Dims
也匹配,例如Indices<0,0,0,0,1,1,4,3,3,7,7,7>
将与Array<double,6>
签约,假设输入数组是Array<double,2,2,2,2,3,3,6,2,2,3,3,3>
.
我对元编程的了解不足以实现此功能,但我希望我已经明确说明了意图,以便有人指导我朝着正确的方向前进。
最佳答案
一堆 constexpr
函数做实际的检查:
// is ind[i] unique in ind?
template<size_t N>
constexpr bool is_uniq(const int (&ind)[N], size_t i, size_t cur = 0){
return cur == N ? true :
(cur == i || ind[cur] != ind[i]) ? is_uniq(ind, i, cur + 1) : false;
}
// For every i where ind[i] == index, is dim[i] == dimension?
template<size_t N>
constexpr bool check_all_eq(int index, int dimension,
const int (&ind)[N], const int (&dim)[N], size_t cur = 0) {
return cur == N ? true :
(ind[cur] != index || dim[cur] == dimension) ?
check_all_eq(index, dimension, ind, dim, cur + 1) : false;
}
// if position i should be contracted away, return -1, otherwise return dim[i].
// triggers a compile-time error when used in a constant expression on mismatch.
template<size_t N>
constexpr int calc(size_t i, const int (&ind)[N], const int (&dim)[N]){
return is_uniq(ind, i) ? dim[i] :
check_all_eq(ind[i], dim[i], ind, dim) ? -1 : throw "dimension mismatch";
}
现在我们需要一种方法来摆脱 -1
:
template<class Ind, class... Inds>
struct concat { using type = Ind; };
template<int... I1, int... I2, class... Inds>
struct concat<Indices<I1...>, Indices<I2...>, Inds...>
: concat<Indices<I1..., I2...>, Inds...> {};
// filter out all instances of I from Is...,
// return the rest as an Indices
template<int I, int... Is>
struct filter
: concat<typename std::conditional<Is == I, Indices<>, Indices<Is>>::type...> {};
使用它们:
template<class Ind, class Arr, class Seq>
struct contraction_impl;
template<class T, int... Ind, int... Dim, size_t... Seq>
struct contraction_impl<Indices<Ind...>, Array<T, Dim...>, std::index_sequence<Seq...>>{
static constexpr int ind[] = { Ind... };
static constexpr int dim[] = { Dim... };
static constexpr int result[] = {calc(Seq, ind, dim)...};
template<int... Dims>
static auto unpack_helper(Indices<Dims...>) -> Array<T, Dims...>;
using type = decltype(unpack_helper(typename filter<-1, result[Seq]...>::type{}));
};
template<class T, int ... Dims, int ... Idx,
typename std::enable_if<sizeof...(Dims)==sizeof...(Idx),bool>::type=0>
typename contraction_impl<Indices<Idx...>, Array<T,Dims...>,
std::make_index_sequence<sizeof...(Dims)>>::type
contraction(const Indices<Idx...> &idx, const Array<T,Dims...> &a);
除了 make_index_sequence
之外的所有内容都是 C++11。您可以在 SO 上找到它的大量实现。
关于c++ - 根据另一个可变参数包查找可变参数包的收缩,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37276904/