c++ - 使用元组累加器减少推力

标签 c++ c++11 templates thrust

我想使用thrust::reducethrust::host_vector上的thrust::tuple<double,double> 。因为没有预定义thrust::plus<thrust::tuple<double,double>>我自己编写并使用了 thrust::reduce 的变体有四个参数。

因为我是一个好公民,所以我放置了我的自定义版本 plus在我自己的命名空间中,我简单地未定义主模板并将其专门用于 thrust::tuple<T...> .

#include <iostream>
#include <tuple>

#include <thrust/host_vector.h>
#include <thrust/reduce.h>
#include <thrust/tuple.h>

namespace thrust_ext {

namespace detail {

// https://stackoverflow.com/a/24481400
template <size_t ...I>
struct index_sequence {};

template <size_t N, size_t ...I>
struct make_index_sequence : public make_index_sequence<N - 1, N - 1, I...> {};

template <size_t ...I>
struct make_index_sequence<0, I...> : public index_sequence<I...> {};

template < typename... T, size_t... I >
__host__ __device__ thrust::tuple<T...> plus(thrust::tuple<T...> const &lhs,
                                             thrust::tuple<T...> const &rhs,
                                             index_sequence<I...>) {
    return {thrust::get<I>(lhs) + thrust::get<I>(rhs) ...};
}

} // namespace detail

template < typename T >
struct plus;

template < typename... T >
struct plus < thrust::tuple<T...> > {
    __host__ __device__ thrust::tuple<T...> operator()(thrust::tuple<T...> const &lhs,
                                                       thrust::tuple<T...> const &rhs) const {
        return detail::plus(lhs,rhs,detail::make_index_sequence<sizeof...(T)>{});
    }
};

} //namespace thrust_ext

int main() {
    thrust::host_vector<thrust::tuple<double,double>> v(10, thrust::make_tuple(1.0,2.0));

    auto r = thrust::reduce(v.begin(), v.end(),
                            thrust::make_tuple(0.0,0.0),
                            thrust_ext::plus<thrust::tuple<double,double>>{});

    std::cout << thrust::get<0>(r) << ' ' << thrust::get<1>(r) << '\n';
}

但是,这无法编译。错误消息非常长,请参阅 this Gist 。该错误消息表明问题出在 thrust::reduce 的某些实现细节中。 。另外,如果我替换 thrust::tuplestd::tuple它按预期编译并运行。

我正在使用 Thrust 1.8.1 和 Clang 6。

最佳答案

正如您在错误消息中看到的那样,thrust::tuple<double,double>实际上是 thrust::tuple<double, double, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>

这是一个基于默认模板参数的 C++03 风格的“可变模板”,这意味着 sizeof...(T)将计算所有 null_type s 并产生错误的大小(始终为 10)。

您需要使用thrust::tuple_size检索实际大小。

关于c++ - 使用元组累加器减少推力,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48981446/

相关文章:

c++ - 比较变体内容的函数无法编译

c++ - VS 2013 模板不编译

c++ - 成员模板,来自 ISO C++ 标准的声明?

c++ - 晦涩的 C++ 语法

c++ - 以成员函数启动线程

C++11奇怪的大括号初始化行为

c++ - 为什么 C++11 中仍然需要 "using"指令来从基类中引入在派生类中重载的方法

c++ - float 类型的默认参数 = 乱码

c++ - 在 .cpp 文件中实现 operator== 的正确方法

常量数据持有者类的 C++ 模板