python - PyTorch:自定义批量采样器在第一个纪元后耗尽

标签 python pytorch pytorch-dataloader

我使用带有自定义batch_samplerDataLoader来确保每个批处理都是类平衡的。如何防止迭代器在第一个纪元耗尽自身?

import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.x = torch.rand(10, 10)
        self.y = torch.Tensor([0] * 5 + [1] * 5)
        
    def __len__(self):
        len(self.y)
        
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return iter(batch_idx)

def train(loader):
    for epoch in range(10):
        for batch, (x, y) in enumerate(loader):
            print('epoch:', epoch, 'batch:', batch) # stops after first epoch

if __name__=='__main__':
    my_dataset = CustomDataset()
    my_loader = torch.utils.data.DataLoader(
        dataset=my_dataset,
        batch_sampler=custom_batch_sampler()
    )
    train(my_loader)

训练在第一个时期后停止,并且 next(iter(loader)) 给出 StopIteration 错误。

epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4

最佳答案

自定义批量采样器需要是一个Sampler或一些可迭代的。在每个纪元中,都会从此可迭代生成一个新的迭代器。这意味着您实际上不需要手动创建一个迭代器(它将在第一个纪元后运行并引发 StopIteration),但您只需提供您的列表,因此如果您删除iter():

def custom_batch_sampler():
    batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
    return batch_idx

关于python - PyTorch:自定义批量采样器在第一个纪元后耗尽,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72034010/

相关文章:

multithreading - 在多处理过程中使用队列是否使用酸洗?

python - 强制外键最多被Django中的另一个表引用

python - sql查询中的python列表作为参数

mongodb - 使用 Mongo DB 的 PyTorch DataLoader

python - 类型错误 : conv2d(): argument 'input' (position 1) must be Tensor, 不是字符串

python - 为什么单个 10x10x3 的 Conv2d 占用 850mb gpu

python - 重新连接远程 Jupyter Notebook 并获取当前单元格输出

python - 转换连接到 XBee 的温度传感器的读数

pytorch - 未找到版本 `GLIBC_2.28'

python - PyTorch:检查模型准确性导致 "TypeError: ' bool' 对象不可迭代。”