python - 获取 pytorch 数据集的子集

标签 python machine-learning neural-network torch pytorch

我有一个网络,我想在某些数据集上进行训练(例如,CIFAR10)。我可以通过

创建数据加载器对象
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

我的问题如下:假设我想进行几次不同的训练迭代。假设我想首先在奇数位置的所有图像上训练网络,然后在偶数位置的所有图像上训练网络等等。为此,我需要能够访问这些图像。不幸的是,trainset 似乎不允许这样的访问。也就是说,尝试执行 trainset[:1000] 或更一般的 trainset[mask] 将引发错误。

我可以代替

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]

然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改了 trainset.train_data 所以我需要重新定义 trainset ).有什么办法可以避免吗?

理想情况下,我希望有一些“等同于”的东西

trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)

最佳答案

torch.utils.data.Subset更简单,支持 shuffle,并且不需要编写自己的采样器:

import torchvision
import torch

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=None)

evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)

trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)

关于python - 获取 pytorch 数据集的子集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47432168/

相关文章:

python - 日期时间 - 10 小时

python - 如何在 Zenoss 映射中导入 DMD?

tensorflow - keras.model.predict 引发 ValueError : Error when checking input

matlab - Netlab的函数mlperr计算均方误差吗?

python - 如何使用自动编码器可视化降维? ( python | tensorflow )

python - geopy.distance.vincenty 给出了不同的弧度和度数结果

python - 匹配直到可选的未转义字符/序列之一或换行符

python - 根据文档,PyGAD 未接收整数参数

machine-learning - 时间序列的 LSTM

python - 如何连接 "Jagged"张量