PyTorch:在 random_split 之后对训练数据应用数据增强

标签 pytorch data-augmentation dataloader pytorch-dataloader

我的数据集没有用于训练和测试的单独文件夹。我想在分割后仅对训练数据应用数据增强和转换

 train_data, valid_data = D.random_split(dataset, lengths=[train_size, valid_size])

有谁知道如何实现这一目标吗?我有一个带有初始化和 getitem 的自定义数据集。训练和验证数据集进一步传递到 DataLoader。

最佳答案

您可以拥有一个仅用于转换的自定义数据集:

class TrDataset(Dataset):
  def __init__(self, base_dataset, transformations):
    super(TrDataset, self).__init__()
    self.base = base_dataset
    self.transformations = transformations

  def __len__(self):
    return len(self.base)

  def __getitem__(self, idx):
    x, y = self.base[idx]
    return self.transformations(x), y

一旦有了这个Dataset包装器,您就可以对训练集和验证集进行不同的转换:

raw_train_data, raw_valid_data = D.random_split(dataset, lengths=[train_size, valid_size])
train_data = TrDataset(raw_train_data, train_transforms)
valid_data = TrDataset(raw_valid_data, val_transforms)

关于PyTorch:在 random_split 之后对训练数据应用数据增强,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/74920920/

相关文章:

python - Hogwild 的 PyTorch 多处理错误

python - 为什么清除对象后GPU中的内存还在使用?

python - 如何将连续的行与它们之间不断增加的重叠结合起来(就像滚动窗口一样)?

numpy - 在pytorch中加载多个.npy文件(大小> 10GB)

python - 3D CNN 在图像序列上的输入形状应该是什么?

Python Dataset Class + PyTorch Dataloader : Stuck at __getitem__, 测试时如何获取索引、标签等?

python - LeakyReLU 中的 "negative"斜率在哪里?

python - 通过对最后 4 层求和来嵌入 BERT 句子