c++ - 有效使用 enable_if 和 C++ 模板来避免类特化

标签 c++ templates c++11 template-specialization partial-specialization

我无法编译我的代码。 clang、g++ 和 icpc 都给出了不同的错误信息,

在进入问题本身之前先了解一下背景:

我现在正在研究用于处理矩阵的模板类层次结构。有数据类型(浮点型或 double 型)和“实现策略”的模板参数——目前这包括带有循环的常规 C++ 代码和英特尔 MKL 版本。以下是简要摘要(请忽略此处缺少前向引用等——这与我的问题无关):

// Matrix.h

template <typename Type, typename IP>
class Matrix : public Matrix_Base<Type, IP>;

template <typename Matrix_Type>
class Matrix_Base{
    /* ... */

    // Matrix / Scalar addition
    template <typename T>
    Matrix_Base& operator+=(const T value_) { 
      return Implementation<IP>::Plus_Equal(
          static_cast<Matrix_Type&>(*this), value_);

    /* More operators and rest of code... */
    };

struct CPP;
struct MKL;

template <typename IP>
struct Implementation{
/* This struct contains static methods that do the actual operations */

我现在遇到的麻烦与实现类的实现有关(没有双关语意)。我知道我可以使用实现模板类的特化来特化 template <> struct Implementation<MKL>{/* ... */};然而,这将导致大量代码重复,因为有许多运算符(例如矩阵标量加法、减法……),通用版本和专用版本都使用相同的代码。

因此,相反,我认为我可以摆脱模板特化,只使用 enable_if 为那些在使用 MKL(或 CUDA 等)时具有不同实现的运算符提供不同的实现。

事实证明,这比我原先预期的更具挑战性。第一个——operator += (T value_)工作正常。我添加了一个检查只是为了确保参数是合理的(如果它是我的麻烦的根源,这可以消除,我对此表示怀疑)。

template <class Matrix_Type, typename Type, typename enable_if< 
    std::is_arithmetic<Type>::value  >::type* dummy = nullptr>
static Matrix_Type& Plus_Equal(Matrix_Type& matrix_, Type value_){
    uint64_t total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;
    //y := A + b

    #pragma parallel 
    for (uint64_t i = 0; i < total_elements; ++i)
        matrix_.Data[i] += value_; 

    return matrix_;
}

但是,我真的很难弄清楚如何处理 operator *=(T value_) .这是因为 floatdouble MKL 有不同的实现,但在一般情况下并非如此。

这是声明。请注意,第三个参数是一个虚拟参数,是我强制函数重载的尝试,因为我不能使用部分模板函数特化:

template <class Matrix_Type, typename U, typename Type = 
    typename internal::Type_Traits< Matrix_Type>::type, typename  enable_if<
    std::is_arithmetic<Type>::value >::type* dummy = nullptr>

static Matrix_Type& Times_Equal(Matrix_Type& matrix_, U value_, Type dummy_ = 0.0);

一般情况下的定义。 :

template<class IP>
template <class Matrix_Type, typename U, typename Type,  typename enable_if<
    std::is_arithmetic<Type>::value >::type* dummy>
Matrix_Type& Implementation<IP>::Times_Equal(Matrix_Type& matrix_, U value_, Type){

    uint64_t total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;

    //y := A - b
    #pragma parallel
    for (uint64_t i = 0; i < total_elements; ++i)
        matrix_.Data[i] *= value_;

    return matrix_;
}

当我尝试为 MKL 实现特化时,麻烦就来了:

template<>
template <class Matrix_Type, typename U, typename Type, typename enable_if<
    std::is_arithmetic<Type>::value >::type* dummy>
Matrix_Type& Implementation<implementation::MKL>::Times_Equal(
    Matrix_Type& matrix_, 
    U value_,
    typename enable_if<std::is_same<Type,float>::value,Type>::type)
{

    float value = value_;

    MKL_INT total_elements = matrix_.actual_dims.first * matrix_.actual_dims.second;
    MKL_INT const_one = 1;

    //y := a * b
    sscal(&total_elements, &value, matrix_.Data, &const_one);
    return matrix_;
}

这给我一个 clang 错误:

_错误:“Times_Equal”的外联定义与“实现”中的任何声明不匹配_

在 g++ 中(稍微缩短)

_错误:“Matrix_Type& Implementation::Times_Equal(...)”的模板 ID ‘Times_Equal<>’与任何模板声明都不匹配。

如果我将第 3 个参数更改为 Type,而不是使用 enable_if,代码将编译得很好。但是当我这样做时,我看不到如何为 float 和 double 单独实现。

如有任何帮助,我们将不胜感激。

最佳答案

我认为使用 std::enable_if 来实现这将非常乏味,因为一般情况下必须使用 enable_if 来实现,如果它不适合其中一个特化,则将其打开。

特别针对您的代码,我认为编译器无法在您的 MKL 特化中推断出 Type,因为它隐藏在 std::enable_if 中,因此这种特化永远不会得到打电话。

而不是使用 enable_if 你也许可以做这样的事情:

#include<iostream>

struct CPP {};
struct MKL {};

namespace Implementation
{
   //
   // general Plus_Equal
   //
   template<class Type, class IP>
   struct Plus_Equal
   {
      template<class Matrix_Type>
      static Matrix_Type& apply(Matrix_Type& matrix_, Type value_)
      {
         std::cout << " Matrix Plus Equal General Implementation " << std::endl;
         // ... do general Plus_Equal ...
         return matrix_;
      }
   };

   //
   // specialized Plus_Equal for MKL with Type double
   //
   template<>
   struct Plus_Equal<double,MKL>
   {
      template<class Matrix_Type>
      static Matrix_Type& apply(Matrix_Type& matrix_, double value_)
      {
         std::cout << " Matrix Plus Equal MKL double Implementation " << std::endl;
         // ... do MKL/double specialized Plus_Equal ...
         return matrix_;
      }
   };
} // namespace Implementation

template <typename Type, typename IP, typename Matrix_Type>
class Matrix_Base
{  
   public:
   // ... matrix base implementation ...

   // Matrix / Scalar addition
   template <typename T>
   Matrix_Base& operator+=(const T value_) 
   { 
      return Implementation::Plus_Equal<Type,IP>::apply(static_cast<Matrix_Type&>(*this), value_);
   }

   // ...More operators and rest of code...
};

template <typename Type, typename IP>
class Matrix : public Matrix_Base<Type, IP, Matrix<Type,IP> >
{
   // ... Matrix implementation ...
};

int main()
{
   Matrix<float ,MKL> f_mkl_mat;
   Matrix<double,MKL> d_mkl_mat;

   f_mkl_mat+=2.0; // will use general plus equal
   d_mkl_mat+=2.0; // will use specialized MKL/double version

   return 0;
}

这里我使用了类特化而不是 std::enable_if。我发现你的示例中的IPTypeMatrix_Type 类型非常不一致,所以我希望我在这里正确使用它们。

关于 std::enable_if 的注释。我会使用表格

template<... , typename std::enable_if< some bool >::type* = nullptr> void func(...);

结束

template<... , typename = std::enable_if< some bool >::type> void func(...);

因为它使您能够执行一些其他形式无法执行的函数重载。

希望你能使用它:)

编辑 20/12-13:重新阅读我的帖子后,我发现我应该明确执行 CRTP(奇怪的重复模板模式),我在上面的代码中添加了它。我将 TypeIP 都传递给 Matrix_Base。如果您觉得这很乏味,可以提供一个矩阵特征类,Matrix_Base 可以从中取出它们。

template<class A>
struct Matrix_Traits;

// Specialization for Matrix class
template<class Type, class IP>
struct Matrix_Traits<Matrix<Type,IP> >
{
   using type = Type;
   using ip   = IP;
};

然后 Matrix_Base现在只接受一个模板参数,即矩阵类本身,并从 traits 类获取类型

template<class Matrix_Type>
class Matrix_Base
{
   // Matrix / Scalar addition
   template <typename T>
   Matrix_Base& operator+=(const T value_) 
   { 
      // We now get Type and IP from Matrix_Traits
      return Implementation::Plus_Equal<typename Matrix_Traits<Matrix_Type>::type
                                      , typename Matrix_Traits<Matrix_Type>::ip
                                      >::apply(static_cast<Matrix_Type&>(*this), value_);
   }
};

关于c++ - 有效使用 enable_if 和 C++ 模板来避免类特化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/20664189/

相关文章:

c++ - 如何从其他类访问不同类中的变量

c++ - 在 SWIG 中包装 C++ 结构模板

c++ - C++中 'this'指针的用例

security - 安全审核文件。 ISO 27001

c++ - 将迭代器传递给函数

C++11 constexpr 函数编译器错误与三元条件运算符 (? :)

c++ - XPATH在C++ Boost中使用

c++ - 为什么 OpenGL GLFW 渲染形状不起作用?

map 和 multimap 之间的 C++ 模板特化

c++ - C++ 中的 std::move 和按引用传递有什么区别?