cuda - 在 CUDA C 项目中使用推力::max_element

标签 cuda max thrust

在 CUDA C 项目中,我想尝试使用 Thrust 库来查找 float 组内的最大元素。看来 Thrust 函数 Thrust::max_element() 就是我所需要的。我想要使​​用此函数的数组是 cuda 内核的结果(似乎工作正常),因此在调用 Throw::max_element() 时它已经存在于设备内存中。 我对 Thrust 库不是很熟悉,但是在查看了 Thrust::max_element() 的文档并阅读了本网站上类似问题的答案后,我认为我已经掌握了这个过程的工作原理。不幸的是,我得到了错误的结果,而且似乎我没有正确使用库函数。有人可以告诉我我的代码有什么问题吗?

float* deviceArray;
float* max;
int length = 1025;

*max = 0.0f;
size = (int) length*sizeof(float);     

cudaMalloc(&deviceArray, size);
cudaMemset(deviceArray, 0.0f, size);

// here I launch a cuda kernel which modifies deviceArray

thrust::device_ptr<float> d_ptr = thrust::device_pointer_cast(deviceArray);
*max = *(thrust::max_element(d_ptr, d_ptr + length));

我使用以下 header :

#include <thrust/extrema.h>
#include <thrust/device_ptr.h>

即使我确信运行内核后 deviceArray 包含非零值,我仍然得到 *max 的零值。 我使用 nvcc 作为编译器 (CUDA 7.0),并在计算能力为 3.5 的设备上运行代码。

任何帮助将不胜感激。谢谢。

最佳答案

这不是正确的 C 代码:

float* max;
int length = 1025;

*max = 0.0f;

除非您正确地为该指针提供了分配(并将指针设置为等于该分配的地址),否则不允许使用指针(max)存储数据。

除此之外,你的其余代码似乎对我有用:

$ cat t990.cu
#include <thrust/extrema.h>
#include <thrust/device_ptr.h>
#include <iostream>


int main(){

  float* deviceArray;
  float max, test;
  int length = 1025;

  max = 0.0f;
  test = 2.5f;
  int size = (int) length*sizeof(float);

  cudaMalloc(&deviceArray, size);
  cudaMemset(deviceArray, 0.0f, size);
  cudaMemcpy(deviceArray, &test, sizeof(float),cudaMemcpyHostToDevice);

  thrust::device_ptr<float> d_ptr = thrust::device_pointer_cast(deviceArray);
  max = *(thrust::max_element(d_ptr, d_ptr + length));
  std::cout << max << std::endl;
}
$ nvcc -o t990 t990.cu
$ ./t990
2.5
$

关于cuda - 在 CUDA C 项目中使用推力::max_element,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34066520/

相关文章:

c++ - 分解在 Visual C++ 中不起作用的函数

c - 将 CUDA 静态或共享库与 gcc 链接时出现 undefined reference 错误

mysql - 需要帮助更改我的查询以保持最小行完整

c++ - CUB 选择是否有返回的索引

max - PowerPivot DAX : Identify Max & Min Value per Group

c++ - 有没有办法限制 STL::map 容器的最大大小?

c++ - 元组上的推力排序非常慢

cuda - 3D 纹理内存是如何缓存的?

cuda - CUDA中实现float4运算的头文件是哪个?

c++ - 直接从数组读取时越界地址