比如我想用AVX2实现一个矩阵乘法模板函数。 (假设“矩阵”是一个实现良好的模板类)
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
if (typeid(T).name() == typeid(float).name()) {
//using __m256 to store float
//using __m256_load_ps __m256_mul_ps __m256_add_ps
} else if (typeid(T).name() == typeid(double).name()) {
//using __m256d to store double
//using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
} else {
//...
}
}
由于没有数据类型的“变量”,程序无法确定它应该使用__m256 还是__m256d 或其他任何东西,从而使代码非常长且笨拙。还有其他方法可以避免这种情况吗?
最佳答案
在 C++17 及更高版本中,您可以使用 if constexpr
:
#include <type_traits>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
if constexpr (std::is_same_v<T, float>) {
//using __m256 to store float
//using __m256_load_ps __m256_mul_ps __m256_add_ps
} else if constexpr (std::is_same_v<T, double>) {
//using __m256d to store double
//using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
} else {
//...
}
}
否则,只需使用重载:
Matrix<float> matmul(const Matrix<float>& mat1, const Matrix<float>& mat2) {
//using __m256 to store float
//using __m256_load_ps __m256_mul_ps __m256_add_ps
}
Matrix<double> matmul(const Matrix<double>& mat1, const Matrix<double>& mat2) {
//using __m256d to store double
//using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
}
...
关于c++ - 如何针对程序相似的不同数据类型专门化模板函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74773459/