pytorch - 如何在 Pytorch 中选择数据集的子集?

标签 pytorch dataset subset

我正在尝试运行 https://github.com/menardai/FashionGenAttnGAN在 Google Colab 上的 GPU 上,磁盘大小为 30 GB。代码文件及其数据集文件约为 15 GB。提取此代码后,磁盘剩余空间约为 14 GB。当我尝试运行 Pretrain.py 时,我可以看到字幕正在加载,但突然出现“断言错误”。由于我没有得到任何正确的答案来解释这个错误的原因,我认为这是因为我的 Colab 环境空间不足。我想到的解决方案是编写一些代码来告诉模型仅选择 30% 的训练和测试数据集进行加载。但我不知道该怎么做。有人可以帮我吗?

最佳答案

data是您的总数据,您可以将其划分为您想要的数量,只需编辑valid_size即可。

valid_size=0.3
num_train = len(data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders (combine dataset and sampler)
train_loader = torch.utils.data.DataLoader(data, batch_size=4,
    sampler=train_sampler, num_workers=2)
valid_loader = torch.utils.data.DataLoader(data, batch_size=4, 
    sampler=valid_sampler, num_workers=2)

如果出现内存问题,只需减少batch_size即可。

关于pytorch - 如何在 Pytorch 中选择数据集的子集?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67572438/

相关文章:

python - 修复 torchvision 变换的随机种子

python - PyTorch 中 tensor.permute 和 tensor.view 的区别?

algorithm - 定义重叠元素或不包含在子集中的元素

r - 使用两个不同的状态值子集重复值

numpy - TypeError : can’t convert CUDA tensor to numpy. 首先使用 Tensor.cpu() 将张量复制到主机内存

python - HTTP 错误 : HTTP Error 403: Forbidden on Google Colab

tensorflow - 使用 tensorflow_datasets API 访问已下载的数据集

python - 在字典列表中搜索名称 python

mysql - 如何改进尝试在大型数据库中查找重复条目的 MySql 查询?

python - 如何从n个列表中形成所有可能的组合