我正在使用 CUDA 和 Thrust。我发现输入 thrust::transform [plus/minus/divide]
乏味,所以我只想重载一些简单的运算符。
如果我能做到的话那就太棒了:
thrust::[host/device]_vector<float> host;
thrust::[host/device]_vector<float> otherHost;
thrust::[host/device]_vector<float> result = host + otherHost;
这是 +
的示例片段:
template <typename T>
__host__ __device__ T& operator+(T &lhs, const T &rhs) {
thrust::transform(rhs.begin(), rhs.end(),
lhs.begin(), lhs.end(), thrust::plus<?>());
return lhs;
}
但是,thrust::plus<?>
没有正确重载,或者我没有正确执行......其中之一。 (如果为此重载简单运算符是一个坏主意,请解释原因)。最初,我认为我可以重载 ?
占位符,类似于 typename T::iterator
,但这没有用。
我不知道如何重载 +
具有 vector 类型和 vector 迭代器类型的运算符。这有道理吗?
感谢您的帮助!
最佳答案
这似乎可行,其他人可能有更好的想法:
#include <ostream>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/transform.h>
#include <thrust/functional.h>
#include <thrust/copy.h>
#include <thrust/fill.h>
#define DSIZE 10
template <typename T>
thrust::device_vector<T> operator+(thrust::device_vector<T> &lhs, const thrust::device_vector<T> &rhs) {
thrust::transform(rhs.begin(), rhs.end(),
lhs.begin(), lhs.begin(), thrust::plus<T>());
return lhs;
}
template <typename T>
thrust::host_vector<T> operator+(thrust::host_vector<T> &lhs, const thrust::host_vector<T> &rhs) {
thrust::transform(rhs.begin(), rhs.end(),
lhs.begin(), lhs.begin(), thrust::plus<T>());
return lhs;
}
int main() {
thrust::device_vector<float> dvec(DSIZE);
thrust::device_vector<float> otherdvec(DSIZE);
thrust::fill(dvec.begin(), dvec.end(), 1.0f);
thrust::fill(otherdvec.begin(), otherdvec.end(), 2.0f);
thrust::host_vector<float> hresult1 = dvec + otherdvec;
std::cout << "result 1: ";
thrust::copy(hresult1.begin(), hresult1.end(), std::ostream_iterator<float>(std::cout, " ")); std::cout << std::endl;
thrust::host_vector<float> hvec(DSIZE);
thrust::fill(hvec.begin(), hvec.end(), 5.0f);
thrust::host_vector<float> hresult2 = hvec + hresult1;
std::cout << "result 2: ";
thrust::copy(hresult2.begin(), hresult2.end(), std::ostream_iterator<float>(std::cout, " ")); std::cout << std::endl;
// this line would produce a compile error:
// thrust::host_vector<float> hresult3 = dvec + hvec;
return 0;
}
请注意,无论哪种情况,我都可以为结果指定主机或设备 vector ,因为推力将看到差异并自动生成必要的复制操作。因此,我的模板中的结果 vector 类型(主机、设备)并不重要。
另请注意,模板定义中的 thrust::transform
函数参数不太正确。
关于c++ - 为 Thrust 重载 "+"运算符,有什么想法吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/16906255/