我有一个过滤器,它应该适用于任意张量。对于我的情况,当过滤器在等级 1,2 和 3 张量上工作时就足够了,这些张量分别是列表、矩阵和 3d 矩阵或体积。此外,过滤器可以应用于每个可能的方向。对于列表,这只是一个,对于矩阵存在 2 个可能的方向(即 X 方向和 y 方向),对于体积存在 3 个可能的方向。
在详细介绍之前,让我先问我的问题:我的过滤器布局是否正确,或者我是否忘记了一些重要的事情,这可能会让我以后遇到困难?我对C++模板并不陌生,但并不是如鱼得水。是否可以进一步压缩此布局(也许有办法绕过虚拟 XDirection
类或更短的 Type2Type
)?
过滤器的基本过程是对每个张量秩和每个方向都相同。只有几行代码,函数callKernel
,这是不同的。使重载 operator()
右拨callKernel
函数是下面代码中唯一有趣的部分。由于模板的部分特化不适用于类方法,您可以将模板参数转换为真正的类类型并将其作为伪参数提供给 callKernel
。 .
下面的代码是它们布局到 2 级。它可以用 g++
编译。并且可以试用。
template <class DataType, int Rank>
class Tensor { };
class XDirection;
class YDirection;
template <class TensorType, class Direction>
struct Type2Type {
typedef TensorType TT;
typedef Direction D;
};
template <class TensorType, class Direction>
struct Filter {
Filter(const TensorType &t){}
TensorType operator()(){
/* much code here */
callKernel(Type2Type<TensorType,Direction>());
/* more code */
TensorType result;
return result;
}
void callKernel(Type2Type<Tensor<double,1>, XDirection>) {}
void callKernel(Type2Type<Tensor<double,2>, XDirection>) {}
void callKernel(Type2Type<Tensor<double,2>, YDirection>) {}
};
int main(void) {
Tensor<double, 2> rank_two_tensor;
Filter<Tensor<double,2>,XDirection> f(rank_two_tensor);
f();
}
让我补充一些重要的事情:过滤器逻辑必须在 operator()
中因为您在这里看到的将与需要此结构的英特尔线程构建 block 一起使用。非常重要的是,callKernel
是内联的。据我所知,情况应该是这样。
在此先感谢您提供任何有帮助的和批评性的评论。
最佳答案
首先,对于第一次尝试模板来说还不错。
如果你有最新版本的 GCC,你可以像这样简化,有一个更好的方法可以使用 std::is_same<>
在类型上有条件地执行代码.它将返回 true
如果类型相同。它还使您的意图更加明确。
#include <type_traits>
template <class TensorType, class Direction>
struct Filter {
Filter(const TensorType &t) { }
TensorType operator()(){
/* much code here */
callKernel();
/* more code */
TensorType result;
return result;
}
void callKernel() {
// Check which type is called at compile time (the if will be simplified by the compiler)
if (std::is_same<TensorType, Tensor<double, 2> >::value) {
if (std::is_same<Direction, XDirection>::value) {
// do stuff
} else {
...
}
} else if (...) {
...
}
}
};
编辑:如果您愿意,您甚至可以将其移动到 op()
以确保代码是内联的。
关于c++ - 使用模板实现通用过滤器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/9322514/