python - Pytorch:在 dataloader.dataset 上使用 torch.utils.random_split() 后,数据中缺少批量大小

标签 python python-3.x deep-learning pytorch

我使用 random_split() 将数据分为训练和测试,我观察到,如果在创建数据加载器后进行随机分割,则从数据加载器获取一批数据时会丢失批量大小。

import torch
from torchvision import transforms, datasets
from torch.utils.data import random_split

# Normalize the data
transform_image = transforms.Compose([
  transforms.Resize((240, 320)),
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

data = '/data/imgs/train'

def load_dataset():
  data_path = data
  main_dataset = datasets.ImageFolder(
    root = data_path,
    transform = transform_image
  )

  loader = torch.utils.data.DataLoader(
    dataset = main_dataset,
    batch_size= 64,
    num_workers = 0,
    shuffle= True
  )

  # Dataset has 22424 data points
  trainloader, testloader = random_split(loader.dataset, [21000, 1424])

  return trainloader, testloader

trainloader, testloader = load_dataset()

现在从训练和测试加载器获取一批图像:

images, labels = next(iter(trainloader))
images.shape
# %%
len(trainloader)

# %%
images_test, labels_test = next(iter(testloader))
images_test.shape

# %%
len(testloader)

我得到的输出没有训练或测试批处理的批处理大小。输出暗淡应该是 [批处理 x channel x 高 x 宽],但我得到 [ channel x 高 x 宽]。

输出:

enter image description here

但是,如果我从数据集创建拆分,然后使用拆分创建两个数据加载器,我会在输出中获得批量大小。

def load_dataset():
    data_path = data
    main_dataset = datasets.ImageFolder(
      root = data_path,
      transform = transform_image
    )
    # Dataset has 22424 data points
    train_data, test_data = random_split(main_dataset, [21000, 1424])

    trainloader = torch.utils.data.DataLoader(
      dataset = train_data,
      batch_size= 64,
      num_workers = 0,
      shuffle= True
    )

    testloader = torch.utils.data.DataLoader(
      dataset = test_data,
      batch_size= 64,
      num_workers= 0,
      shuffle= True
    )

    return trainloader, testloader

trainloader, testloader = load_dataset()

运行相同的 4 个命令来获取单个训练和测试批处理:

enter image description here

第一种方法是错误的吗?虽然长度显示数据已经被分割了。那么为什么我看不到批量大小?

最佳答案

第一种方法是错误的。

DataLoader实例返回批量的项目。 Dataset就像实例没有一样。

当您调用make_split时你通过了loader.dataset这只是对 main_dataset 的引用(不是 DataLoader )。结果是trainloadertestloaderDataset不是DataLoader s。事实上你丢弃了 loader这是你唯一的DataLoader当您从 load_dataset 返回时.

第二个版本是您应该做的以获得两个单独的 DataLoader s。

关于python - Pytorch:在 dataloader.dataset 上使用 torch.utils.random_split() 后,数据中缺少批量大小,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59322580/

相关文章:

python - 为什么设置窗口图标时没有定义 .ico 文件?

python - 如何使用Tensorflow进行信号处理?

python - 我正在尝试使用预训练网络对花朵进行分类,但由于某种原因它无法训练

python - PyAudio 在 ubuntu 上不工作并破坏声音

python - 无法使用 anaconda 的 conda 包更新到 python 3.5

python - 在 Jupyter Python Notebook 中显示所有数据框列

python - 如何保存棋盘游戏? Python

python - 在进行逻辑回归时,如何解决Python中的值错误?

python - 我请求的页面不代表与BeautifulSoup4相同的结构

neural-network - 用于 Caffe 的 LSTM 模块