python - img 应该是 PIL 图片。得到了 <class 'torch.Tensor' >

标签 python pytorch

我正在尝试遍历加载程序以检查它是否正常工作,但出现以下错误:

TypeError: img should be PIL Image. Got <class 'torch.Tensor'>

我已经尝试添加 transforms.ToTensor()transforms.ToPILImage(),但它给了我一个错误,要求相反。即,对于 ToPILImage() ,它将要求张量,反之亦然。

# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np

data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'

#Creating transform for training set
train_transforms = transforms.Compose(
[transforms.Resize(255), 
transforms.CenterCrop(224), 
transforms.ToTensor(), 
transforms.RandomHorizontalFlip(), 
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

#Creating transform for test set
test_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224), 
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])

#transforming for all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform = test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform = test_transforms)

#Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32, 
shuffle = True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))

一旦我运行 plt.imshow(images[0]) ,它应该允许我简单地看到图像,如果它工作正常的话。

最佳答案

transforms.RandomHorizo​​ntalFlip() 适用于 PIL.Images,不适用于 torch.Tensor。在上面的代码中,您在 transforms.RandomHorizo​​ntalFlip() 之前应用 transforms.ToTensor(),这会产生张量。

但是,根据 pytorch 官方文档 here

transforms.RandomHorizontalFlip() horizontally flip the given PIL Image randomly with a given probability.

所以,只需更改上面代码中的转换顺序,如下所示:

train_transforms = transforms.Compose([transforms.Resize(255), 
                                       transforms.CenterCrop(224),  
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(), 
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 

关于python - img 应该是 PIL 图片。得到了 <class 'torch.Tensor' >,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57079219/

相关文章:

python-3.x - 如何在当前纪元期间加载/获取下一个纪元的下一批数据?

python - 使用 torch 视觉创建火车变换时出错

python - 在整个天空的网格上绘制 FITS 图像

python - 根据共存规则将列表拆分为多个组

python - 创建副本时如何避免改变原始全局变量

python - 每 100 个元素切片 Python 列表的最 pythonic 方法

python - pytorch的nn.Module如何注册子模块?

python - 梯度计算所需的变量之一已通过就地操作修改

tensorflow - softmax_cross_entropy_with_logits 的 PyTorch 等价

python - 如何在Hadoop流中使用opt解析器作为映射器指定python脚本