pytorch - 如何将 Pytorch DataLoader 用于具有多个标签的数据集

标签 pytorch pytorch-dataloader

我想知道如何在 Pytorch 中创建一个支持多种类型标签的 DataLoader。我该怎么做?

最佳答案

您可以为数据集中的每个项目返回标签的 dict,DataLoader 足够智能,可以为您整理它们。即,如果您为每个项目提供一个 dict,DataLoader 将返回一个 dict,其中键是标签类型。访问该标签类型的键会返回该标签类型的整理张量。

见下文:

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class M(Dataset):
    def __init__(self):
        super().__init__()
        self.data = np.random.randn(20, 2)
        print(self.data)

    def __getitem__(self, i):
        return self.data[i], {'label_1':self.data[i], 'label_2':self.data[i]}

    def __len__(self):
        return len(self.data)

ds = M()
dl = DataLoader(ds, batch_size=6)

for x, y in dl:
    print(x, '\n', y)
    print(type(x), type(y))
[[-0.33029911  0.36632142]
 [-0.25303721 -0.11872778]
 [-0.35955625 -1.41633132]
 [ 1.28814629  0.38238357]
 [ 0.72908184 -0.09222787]
 [-0.01777293 -1.81824167]
 [-0.85346074 -1.0319562 ]
 [-0.4144832   0.12125039]
 [-1.29546792 -1.56314292]
 [ 1.22566887 -0.71523568]]
tensor([[-0.3303,  0.3663],
        [-0.2530, -0.1187],
        [-0.3596, -1.4163]], dtype=torch.float64) 
 {'item_1': tensor([[-0.3303,  0.3663],
        [-0.2530, -0.1187],
        [-0.3596, -1.4163]], dtype=torch.float64), 'item_2': tensor([[-0.3303,  0.3663],
        [-0.2530, -0.1187],
        [-0.3596, -1.4163]], dtype=torch.float64)}
<class 'torch.Tensor'> <class 'dict'>
...

关于pytorch - 如何将 Pytorch DataLoader 用于具有多个标签的数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66446881/

相关文章:

python - 如何在下面的代码中正确使用 collat​​e_fn ?

python - pytorch中DataLoader的洗牌顺序

python - pytorch Faster-RCNN 的验证损失

pytorch - ONNX 和 pytorch 的输出不同

python - Pytorch RuntimeError : [enforce fail at CPUAllocator. cpp :56] posix_memalign(&data, gAlignment, nbytes) == 0. 12 vs 0

python-3.x - 如何解决错误 : RuntimeError: received 0 items of ancdata

python - PyTorch 数据集/Dataloader 批处理

python - PyTorch:如何从张量中采样,其中张量中的每个值都有不同的被选择可能性?

pytorch - 如何在 PyTorch 中为特定张量释放 GPU 内存?

python - 使用 PyTorchVideo 加载用于训练视频分类模型的动力学数据集时出错