我正在使用 PyTorch 1.8 和 Python 3.8 使用以下代码从文件夹中读取图像:
print(f"PyTorch version: {torch.__version__}")
# PyTorch version: 1.8.1
# Device configuration-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"currently available device: {device}")
# currently available device: cpu
# Define transformations for training and test sets-
transform_train = transforms.Compose(
[
# transforms.RandomCrop(32, padding = 4),
# transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
transform_test = transforms.Compose(
[
transforms.ToTensor(),
# transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
)
# Define directory containing images-
data_dir = 'My_Datasets/Cat_Dog_data/'
# Define datasets-
train_data = datasets.ImageFolder(data_dir + '/train',
transform = train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test',
transform = test_transforms)
print(f"number of train images = {len(train_data)} & number of validation images = {len(test_data)}")
# number of train images = 22500 & number of validation images = 2500
print(f"number of training classes = {len(train_data.classes)} & number of validation classes = {len(test_data.classes)}")
# number of training classes = 2 & number of validation classes = 2
# Define data loaders-
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32)
testloader = torch.utils.data.DataLoader(test_data, batch_size = 32)
len(trainloader), len(testloader)
# (704, 79)
# Sanity check-
len(train_data) / 32, len(test_data) / 32
您可以使用“train_loader”遍历火车数据,如下所示:
for img, lab in train_loader:
print(img.shape, lab.shape)
pass
但是,我有兴趣获取文件名以及从中读取文件的文件路径。我怎样才能做到这一点?
谢谢!
最佳答案
默认ImageFolder
Dataset
保存了self.samples
中所有图片的路径。您需要做的就是修改 __getitem__
以返回路径。
关于pytorch - 使用 PyTorch 数据加载器获取文件名和文件路径,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68112479/