pytorch - 使用 PyTorch 数据加载器获取文件名和文件路径

标签 pytorch python-3.8

我正在使用 PyTorch 1.8 和 Python 3.8 使用以下代码从文件夹中读取图像:

print(f"PyTorch version: {torch.__version__}")
# PyTorch version: 1.8.1

# Device configuration-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"currently available device: {device}")
# currently available device: cpu


# Define transformations for training and test sets-
transform_train = transforms.Compose(
    [
      # transforms.RandomCrop(32, padding = 4),
      # transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ]
     )

transform_test = transforms.Compose(
    [
      transforms.ToTensor(),
      # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ]
     )

# Define directory containing images-
data_dir = 'My_Datasets/Cat_Dog_data/'

# Define datasets-
train_data = datasets.ImageFolder(data_dir + '/train', 
                                  transform = train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', 
                                 transform = test_transforms)

print(f"number of train images = {len(train_data)} & number of validation images = {len(test_data)}")
# number of train images = 22500 & number of validation images = 2500

print(f"number of training classes = {len(train_data.classes)} & number of validation classes = {len(test_data.classes)}")
# number of training classes = 2 & number of validation classes = 2

# Define data loaders-
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32)
testloader = torch.utils.data.DataLoader(test_data, batch_size = 32)

len(trainloader), len(testloader)
# (704, 79)

# Sanity check-
len(train_data) / 32, len(test_data) / 32

您可以使用“train_loader”遍历火车数据,如下所示:

for img, lab in train_loader:
   print(img.shape, lab.shape)
   pass

但是,我有兴趣获取文件名以及从中读取文件的文件路径。我怎样才能做到这一点?

谢谢!

最佳答案

默认ImageFolder Dataset保存了self.samples中所有图片的路径。您需要做的就是修改 __getitem__ 以返回路径。

关于pytorch - 使用 PyTorch 数据加载器获取文件名和文件路径,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68112479/

相关文章:

python - 在 Python 中 bool(datetime) 可以评估为 False 吗?

python - 如何更改 Raspberry Pi 中的默认 Python 版本

python - 用于 ubuntu 14.04(cuda 8.0、python2.7.6、pip 19.0.1)上 gpu 安装错误的 pytorch - 不支持轮子

android - 如何构建Pytorch Mobile示例HelloWorld应用程序?

python - "ImportError: no suitable image found"使用 BeautifulSoup 和 Python3.8

python - Pathlib 'normalizes' 带有 "$"的 UNC 路径

python - 无法导入模块 'lambda_function' : No module named 'psycopg2. _psycopg aws lambda 函数

python - 在 numpy 中获取 3D 张量的所有 2D 对角线

python - 如何对 pandas、torch 和 numpy 输入使用打字覆盖

python - 在 PyTorch 中,如何通过损失列表中的平均梯度更新神经网络?