我使用带有自定义batch_sampler
的DataLoader
来确保每个批处理都是类平衡的。如何防止迭代器在第一个纪元耗尽自身?
import torch
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
self.x = torch.rand(10, 10)
self.y = torch.Tensor([0] * 5 + [1] * 5)
def __len__(self):
len(self.y)
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
def custom_batch_sampler():
batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
return iter(batch_idx)
def train(loader):
for epoch in range(10):
for batch, (x, y) in enumerate(loader):
print('epoch:', epoch, 'batch:', batch) # stops after first epoch
if __name__=='__main__':
my_dataset = CustomDataset()
my_loader = torch.utils.data.DataLoader(
dataset=my_dataset,
batch_sampler=custom_batch_sampler()
)
train(my_loader)
训练在第一个时期后停止,并且 next(iter(loader))
给出 StopIteration
错误。
epoch: 0 batch: 0
epoch: 0 batch: 1
epoch: 0 batch: 2
epoch: 0 batch: 3
epoch: 0 batch: 4
最佳答案
自定义批量采样器需要是一个Sampler
或一些可迭代的。在每个纪元中,都会从此可迭代生成一个新的迭代器。这意味着您实际上不需要手动创建一个迭代器(它将在第一个纪元后运行并引发 StopIteration
),但您只需提供您的列表,因此如果您删除iter()
:
def custom_batch_sampler():
batch_idx = [[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]]
return batch_idx
关于python - PyTorch:自定义批量采样器在第一个纪元后耗尽,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72034010/