PyTorch 数据加载器

标签 pytorch

我正在尝试使用多个torch.utils.data.DataLoader来创建应用了不同转换的数据集。目前,我的代码大致是

d_transforms = [
    transforms.RandomHorizontalFlip(),
    # Some other transforms...
]
loaders = []
for i in range(len(d_transforms)):
    dataset = datasets.MNIST('./data', 
            train=train, 
            download=True, 
            transform=d_transforms[i]
    loaders.append(
        DataLoader(dataset, 
            shuffle=True, 
            pin_memory=True, 
            num_workers=1)
        )

这可行,但速度非常慢。 kernprof显示我的代码中几乎所有时间都花在诸如

之类的行上
x, y = next(iter(train_loaders[i]))

我怀疑这是由于我正在使用 DataLoader 的多个实例,每个实例都有自己的工作程序,该工作程序尝试读取相同的数据文件。

我的问题是,有什么更好的方法来做到这一点?理想情况下,我将torch.utils.data.DataSet子类化并指定采样时要应用的转换,但这似乎不可能,因为__getitem__ 无法接受参数。

最佳答案

__getitem__ 确实接受一个参数,该参数是您要加载的内容的索引。例如。

transform = transforms.Compose(
    [transforms.ToTensor(),
     normalize])

class CountDataset(Dataset):

def __init__(self, file,transform=None):

    self.transform = transform
    #self.vocab = vocab
    with open(file,'rb') as f:
        self.data = pickle.load(f)
    self.y = self.data['answers']
    self.I = self.data['images']


def __len__(self):
    return len(self.y)

def __getitem__(self, idx):
    img_name = self.I[idx]
    label = self.y[Idx]
    fname = '/'.join(img_name.split("/")[-2:]) #/train2014/xx.jpg
    DIR = '/hdd/manoj/VQA/Images/mscoco/'
    img_full_path = os.path.join(DIR,fname)
    img = Image.open(img_full_path).convert("RGB")
    img_tensor = self.transform(img.resize((224,224)))
    return img_tensor,label


testset = CountDataset(file = 'testdat.pkl',
                        transform = transform)


testloader = DataLoader(testset, batch_size=32,
                         shuffle=False, num_workers=4)

您不会在循环中调用数据加载器。

关于PyTorch 数据加载器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47933597/

相关文章:

python - 稀疏张量以减少训练时间

python-3.x - Pytorch:ValueError:预期输入 batch_size (32) 以匹配目标 batch_size (64)

python - 如何将变量设置为属性?

machine-learning - dropout中的in-place是什么意思

python - 二维数组的按行 numpy.isin

python - PyTorch 加载 "\lib\site-packages\torch\lib\shm.dll"或其依赖项之一时出错

python - 在 PyTorch 中,nll_loss 的输入是什么?

python - 如何在 Pytorch 中展平 `nn.Sequential` 中的输入

python - 摆脱 maxpooling 层会导致运行 cuda 内存错误 pytorch

python - 给定输入大小: (128x1x1).计算出的输出大小: (128x0x0).输出大小太小