python - 如何从给定 pytorch 中的 idx 列表的数据集中获取一批样本?

标签 python pytorch dataset dataloader

我有一个 torch.utils.data.Dataset 对象,我想要一个 DataLoader 或接受 idx 列表并返回批处理的类似对象具有相应 idx 的样本。

示例,我有

list_idxs = [10, 109, 7, 12]

我想做的是:

batch = loader.getbatch(list_idxs)

其中批处理包含:

[样本10、样本109、样本7、样本12]

有没有一种简单而优雅的方式来优化这一点?

最佳答案

如果我正确理解您的问题,您可以让 DataLoader 使用自定义 batch_sampler 返回一系列手动选择的批处理(您甚至不需要通过在本例中它是一个采样器)。

给定任意数据集:

>>> from torch.utils.data import DataLoader, Dataset
>>> from torch.utils.data.sampler import Sampler
>>> class MyDataset(Dataset):
...     def __getitem__(self, idx):
...         return idx

然后您可以定义如下内容:

>>> class MyBatchSampler(Sampler):
...     def __init__(self, batches):
...         self.batches = batches
...
...     def __iter__(self):
...         for batch in self.batches:
...             yield batch
...
...     def __len__(self):
...         return len(self.batches)

它只获取包含要包含在每个批处理中的数据集索引的列表。

然后:

>>> dataset = MyDataset()
>>> batch_sampler = MyBatchSampler([[1, 2, 3], [5, 6, 7], [4, 2, 1]])
>>> dataloader = DataLoader(dataset=dataset, batch_sampler=batch_sampler)
>>> for batch in dataloader:
...     print(batch)
... 
tensor([1, 2, 3])
tensor([5, 6, 7])
tensor([4, 2, 1])

应该很容易扩展到您的实际数据集等。

关于python - 如何从给定 pytorch 中的 idx 列表的数据集中获取一批样本?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69121760/

相关文章:

python - 为什么python floor division operator的行为是这样的?

python - 如何从 gpu 内存地址创建 PyCUDA GPUArray?

c++ - 使用 PyTorch、HElib 和 OpenCV 编译 C++ 程序时出现问题

PyTorch C++ - 如何知道 cuDNN 的推荐版本?

c# - Dataset 到数据集中的每个元素

objective-c - 巨大的核心数据对象

python opencv cv2.waitkey错误

python - Rhino.Python 和 .NET 框架

python - C# 相当于 Python 的 ThreadPool

c# - 将数据表加载到数据集中的现有表