我使用 random_split() 将数据分为训练和测试,我观察到,如果在创建数据加载器后进行随机分割,则从数据加载器获取一批数据时会丢失批量大小。
import torch
from torchvision import transforms, datasets
from torch.utils.data import random_split
# Normalize the data
transform_image = transforms.Compose([
transforms.Resize((240, 320)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
data = '/data/imgs/train'
def load_dataset():
data_path = data
main_dataset = datasets.ImageFolder(
root = data_path,
transform = transform_image
)
loader = torch.utils.data.DataLoader(
dataset = main_dataset,
batch_size= 64,
num_workers = 0,
shuffle= True
)
# Dataset has 22424 data points
trainloader, testloader = random_split(loader.dataset, [21000, 1424])
return trainloader, testloader
trainloader, testloader = load_dataset()
现在从训练和测试加载器获取一批图像:
images, labels = next(iter(trainloader))
images.shape
# %%
len(trainloader)
# %%
images_test, labels_test = next(iter(testloader))
images_test.shape
# %%
len(testloader)
我得到的输出没有训练或测试批处理的批处理大小。输出暗淡应该是 [批处理 x channel x 高 x 宽],但我得到 [ channel x 高 x 宽]。
输出:
但是,如果我从数据集创建拆分,然后使用拆分创建两个数据加载器,我会在输出中获得批量大小。
def load_dataset():
data_path = data
main_dataset = datasets.ImageFolder(
root = data_path,
transform = transform_image
)
# Dataset has 22424 data points
train_data, test_data = random_split(main_dataset, [21000, 1424])
trainloader = torch.utils.data.DataLoader(
dataset = train_data,
batch_size= 64,
num_workers = 0,
shuffle= True
)
testloader = torch.utils.data.DataLoader(
dataset = test_data,
batch_size= 64,
num_workers= 0,
shuffle= True
)
return trainloader, testloader
trainloader, testloader = load_dataset()
运行相同的 4 个命令来获取单个训练和测试批处理:
第一种方法是错误的吗?虽然长度显示数据已经被分割了。那么为什么我看不到批量大小?
最佳答案
第一种方法是错误的。
仅DataLoader
实例返回批量的项目。 Dataset
就像实例没有一样。
当您调用make_split
时你通过了loader.dataset
这只是对 main_dataset
的引用(不是 DataLoader
)。结果是trainloader
和testloader
是Dataset
不是DataLoader
s。事实上你丢弃了 loader
这是你唯一的DataLoader
当您从 load_dataset
返回时.
第二个版本是您应该做的以获得两个单独的 DataLoader
s。
关于python - Pytorch:在 dataloader.dataset 上使用 torch.utils.random_split() 后,数据中缺少批量大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59322580/