我有一个 torch.utils.data.Dataset
对象,我想要一个 DataLoader
或接受 idx 列表并返回批处理的类似对象具有相应 idx 的样本。
示例,我有
list_idxs = [10, 109, 7, 12]
我想做的是:
batch = loader.getbatch(list_idxs)
其中批处理包含:
[样本10、样本109、样本7、样本12]
有没有一种简单而优雅的方式来优化这一点?
最佳答案
如果我正确理解您的问题,您可以让 DataLoader
使用自定义 batch_sampler
返回一系列手动选择的批处理(您甚至不需要通过在本例中它是一个采样器
)。
给定任意数据集
:
>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
... def __getitem__(self, idx):
... return idx
然后您可以定义如下内容:
>>> class MyBatchSampler(Sampler):
... def __init__(self, batches):
... self.batches = batches
...
... def __iter__(self):
... for batch in self.batches:
... yield batch
...
... def __len__(self):
... return len(self.batches)
它只获取包含要包含在每个批处理中的数据集索引的列表。
然后:
>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
... print(batch)
...
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])
应该很容易扩展到您的实际数据集等。
关于python - 如何从给定 pytorch 中的 idx 列表的数据集中获取一批样本?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69121760/