PyTorch DataLoader - select a single class from the STL10 dataset

Tags: pytorch torch torchvision

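Is it possible to load only the samples with class = 0 from the STL10 dataset in PyTorch torchvision? I can check the labels in a loop, but I need to receive whole batches of class-0 images.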

import torch
import torchvision
import torchvision.transforms as transforms

batch_size = 25

# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)


# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

for i, (images, labels) in enumerate(train_loader):
    # Only inspects the first label; each batch still mixes all 10 classes
    if labels[0] == 0:
        ...
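Filtering inside the loop does work if you mask the whole batch instead of checking only `labels[0]`, but the number of class-0 samples then varies from batch to batch and most of the loaded images are thrown away. A minimal sketch of that approach, assuming a hypothetical `TensorDataset` of random tensors in place of the real STL10 download:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for STL10: 100 tiny 1x8x8 "images" with labels 0..9
torch.manual_seed(0)
images = torch.rand(100, 1, 8, 8)
labels = torch.arange(100) % 10  # exactly 10 samples per class
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=25, shuffle=True)

batch_sizes = []
for batch_images, batch_labels in loader:
    # Keep only the class-0 samples of this batch with a boolean mask
    mask = batch_labels == 0
    class0_images = batch_images[mask]
    batch_sizes.append(class0_images.shape[0])

# The per-batch count is uneven; only the total (10) is fixed
print(batch_sizes)
```

This is why a sampler that only ever yields class-0 indices is the better fit when you want full, uniformly sized batches.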

Edit based on iacolippo's answer - it works now:

import torch
import torchvision
import torchvision.transforms as transforms

# Set params
batch_size = 25
label_class = 0   # only airplane images

# Return the indices of the samples of a certain class (e.g. airplanes = class 0)
def get_same_index(target, label):
    label_indices = []

    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)

    return label_indices

# STL10 dataset
train_dataset = torchvision.datasets.STL10(root='./data/',
                                           transform=transforms.Compose([
                                               transforms.Grayscale(),
                                               transforms.ToTensor()
                                           ]),
                                           split='train',
                                           download=True)

# Get indices of label_class
train_indices = get_same_index(train_dataset.labels, label_class)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))
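An alternative that is not part of the original answer is to wrap the dataset in `torch.utils.data.Subset`, which exposes only the selected indices as a smaller dataset, and then let the `DataLoader` shuffle it normally. A minimal sketch, assuming a hypothetical `TensorDataset` in place of STL10:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Hypothetical stand-in for STL10: 100 random 1x8x8 images, labels 0..9
labels = torch.arange(100) % 10
dataset = TensorDataset(torch.rand(100, 1, 8, 8), labels)

label_class = 0
indices = [i for i, y in enumerate(labels.tolist()) if y == label_class]

# Subset behaves like a dataset of length len(indices)
class_subset = Subset(dataset, indices)

# With a Subset there is no sampler, so shuffle=True is allowed here
loader = DataLoader(class_subset, batch_size=4, shuffle=True)

seen = []
for _, batch_labels in loader:
    seen.extend(batch_labels.tolist())
```

Functionally this matches the `SubsetRandomSampler` version; the difference is that `len(class_subset)` reports the size of the selected class, which some training loops use for progress or epoch accounting.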

Best answer

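If you only want samples from one class, you can get the indices of the samples belonging to that class from the Dataset instance with something like this: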

def get_same_index(target, label):
    label_indices = []

    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)

    return label_indices
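The same lookup can be done without a Python loop. STL10's `labels` attribute (used in the question's edit) is, in the torchvision versions I know of, a NumPy array, so a vectorized comparison works directly on it. A minimal sketch, where `get_same_index_vectorized` is a hypothetical name and the label array is made up for illustration:

```python
import numpy as np

def get_same_index_loop(target, label):
    # The loop version from the answer
    label_indices = []
    for i in range(len(target)):
        if target[i] == label:
            label_indices.append(i)
    return label_indices

def get_same_index_vectorized(target, label):
    # np.flatnonzero returns the positions where the comparison is True
    return np.flatnonzero(np.asarray(target) == label).tolist()

# Hypothetical label array standing in for train_dataset.labels
labels = np.array([0, 3, 0, 9, 1, 0, 2])
print(get_same_index_vectorized(labels, 0))  # [0, 2, 5]
```

On the 5,000-image STL10 training split either version is fast enough; the vectorized form mainly matters for much larger label arrays.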

Then you can use SubsetRandomSampler to draw samples only from that list of indices:

torch.utils.data.sampler.SubsetRandomSampler(indices)
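Putting the two steps together, the sampler is passed to the `DataLoader` through its `sampler` argument, and every batch then contains only the chosen class. A minimal sketch, assuming a hypothetical `TensorDataset` of random tensors instead of the real STL10 data:

```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Hypothetical stand-in for STL10: 100 random 1x8x8 images, labels 0..9
torch.manual_seed(0)
labels = torch.arange(100) % 10
dataset = TensorDataset(torch.rand(100, 1, 8, 8), labels)

# Indices of the class we want (class 0 here)
indices = (labels == 0).nonzero(as_tuple=True)[0].tolist()

# The sampler draws only from these indices, in a new random order each epoch.
# Do not also pass shuffle=True: DataLoader rejects shuffle together with a sampler.
loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(indices))

batches = [batch_labels for _, batch_labels in loader]
```

With 10 class-0 samples and `batch_size=4`, one epoch yields batches of 4, 4 and 2 samples; pass `drop_last=True` to the `DataLoader` if the short final batch is a problem.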

Source: the original question "PyTorch DataLoader - select a single class from the STL10 dataset" on Stack Overflow: https://stackoverflow.com/questions/51334858/
