c++ - pytorch C++与alexnet和cv::imread图像

标签 c++ opencv pytorch

我正在尝试使用C++应用程序推断使用alexnet预训练网络的图像分类任务。我已经成功推断出了用python加载网络的狗图像:

alexnet = torchvision.models.alexnet(pretrained=True)
img = Image.open("dog.jpg")
transform = transforms.Compose([
 transforms.Resize(256),                
 transforms.CenterCrop(224),        
 transforms.ToTensor(),                  
 transforms.Normalize(                   
 mean=[0.485, 0.456, 0.406],         
 std=[0.229, 0.224, 0.225]              
 )])
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
alexnet.forward(batch_t)
_, index = torch.max(out, 1)

结果index是208,Labrador_retriever,看起来不错。
然后,我保存要从C++应用程序加载的网络
example = torch.rand(1, 3, 224, 224)
traced_script_module_alex = torch.jit.trace(alexnet, example)
traced_script_module.save("alexnet.pt")

当我加载到C++时,会得到错误的结果:
cv::Mat img = cv::imread("dog.jpg");
cv::resize(img, img, cv::Size(224, 224), cv::INTER_CUBIC);

// Convert the image and label to a tensor.
torch::Tensor img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3 }, torch::kByte);
img_tensor = img_tensor.permute({ 0, 3, 1, 2 }); // convert to CxHxW
img_tensor = img_tensor.to(torch::kFloat);
std::vector<torch::jit::IValue> input;
input.push_back(img_tensor);
torch::jit::script::Module  module = torch::jit::load("alexnet.pt");
at::Tensor output = module.forward(input).toTensor();
std::cout << output.argmax(1) << '\n';
argmax是463,存储桶。
我想我看的不是同一张图片;我想念什么...?

最佳答案

您的C++代码缺少Python代码的这一部分:

transform = transforms.Compose([
 transforms.Resize(256),                
 transforms.CenterCrop(224),        
 transforms.ToTensor(),                  
 transforms.Normalize(                   
 mean=[0.485, 0.456, 0.406],         
 std=[0.229, 0.224, 0.225]              
 )])
img_t = transform(img)

关于c++ - pytorch C++与alexnet和cv::imread图像,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59783791/

相关文章:

c++ - 如何在C++中输入整数而不使用 ">>"

c++ - 检查成员是否存在,可能在基类中,VS2005/08版本

c++ - 使用 C++ 中的自定义元素进行 Const 结构初始化

python - 录制电脑屏幕时 MSS 重复帧

c++ - 在 C++ 中调用函数(或虚函数)是一个 coSTLy 操作

opencv - 是否有访问 OpenCV 元素的模板方法?

python - 错误导入 cv2 : ImportError: numpy. core.multiarray 导入失败

tensorflow - tf.data.experimental.sample_from_datasets 的 PyTorch 替代品

tensorflow - 有与 Tensorflow 等效的 PyTorch 闪电吗?

python - 多个 PyTorch 网络在不同 CPU 上并行运行