我正在尝试使用多个torch.utils.data.DataLoader
来创建应用了不同转换的数据集。目前,我的代码大致是
d_transforms = [
transforms.RandomHorizontalFlip(),
# Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
dataset = datasets.MNIST('./data',
train=train,
download=True,
transform=d_transforms[i]
loaders.append(
DataLoader(dataset,
shuffle=True,
pin_memory=True,
num_workers=1)
)
这可行,但速度非常慢。 kernprof显示我的代码中几乎所有时间都花在诸如
之类的行上x, y = next(iter(train_loaders[i]))
我怀疑这是由于我正在使用 DataLoader
的多个实例,每个实例都有自己的工作程序,该工作程序尝试读取相同的数据文件。
我的问题是,有什么更好的方法来做到这一点?理想情况下,我将torch.utils.data.DataSet
子类化并指定采样时要应用的转换,但这似乎不可能,因为__getitem__
无法接受参数。
最佳答案
__getitem__
确实接受一个参数,该参数是您要加载的内容的索引。例如。
transform = transforms.Compose(
[transforms.ToTensor(),
normalize])
class CountDataset(Dataset):
def __init__(self, file,transform=None):
self.transform = transform
#self.vocab = vocab
with open(file,'rb') as f:
self.data = pickle.load(f)
self.y = self.data['answers']
self.I = self.data['images']
def __len__(self):
return len(self.y)
def __getitem__(self, idx):
img_name = self.I[idx]
label = self.y[Idx]
fname = '/'.join(img_name.split("/")[-2:]) #/train2014/xx.jpg
DIR = '/hdd/manoj/VQA/Images/mscoco/'
img_full_path = os.path.join(DIR,fname)
img = Image.open(img_full_path).convert("RGB")
img_tensor = self.transform(img.resize((224,224)))
return img_tensor,label
testset = CountDataset(file = 'testdat.pkl',
transform = transform)
testloader = DataLoader(testset, batch_size=32,
shuffle=False, num_workers=4)
您不会在循环中调用数据加载器。
关于PyTorch 数据加载器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47933597/