python - 如何从 DataLoader 获取样本的文件名?

标签 python machine-learning pytorch torchvision

我需要用我训练的卷积神经网络的数据测试结果编写一个文件。数据包括语音数据采集。文件格式需要是“文件名,预测”,但我很难提取文件名。我这样加载数据:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

我正尝试按如下方式写入文件:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

os.listdir(TESTH_DATA_PATH + "/all")[i] 的问题在于它与test_loader 加载文件的顺序不同步。我能做什么?

最佳答案

嗯,这取决于你的 Dataset实现。例如,在 torchvision.datasets.MNIST(...)在这种情况下,您无法检索文件名只是因为没有单个样本的文件名(MNIST 样本是 loaded in a different way )。

因为你没有显示你的 Dataset实现,我将告诉您如何使用 torchvision.datasets.ImageFolder(...) 完成此操作(或任何 torchvision.datasets.DatasetFolder(...) ):

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))
f.close()

您可以看到在 __getitem__(self, index) 期间检索了文件的路径。 , 特别是 here .

如果您实现了自己的 Dataset (也许想支持 shufflebatch_size > 1 ),然后我会返回 sample_fname__getitem__(...) 上打电话并做这样的事情:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

这样你就不需要关心 shuffle .如果batch_size大于 1,您需要更改循环的内容以获得更通用的内容,例如:

f = open("test_y", "w")
for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
    outputs = model(images)
    pred = torch.max(outputs, 1)[1]
    f.write("\n".join([
        ", ".join(x)
        for x in zip(map(str, pred.cpu().tolist()), samples_fname)
    ]) + "\n")
f.close()

关于python - 如何从 DataLoader 获取样本的文件名?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56699048/

相关文章:

python - 映射多个数据框的值并填充列

python - 在线性回归中比较 StandardScaler 与 Normalizer 的结果

python - 运行时错误 ("grad can be implicitly created only for scalar outputs")

python - 如何用 Pandas 创建高低开收盘图表

python - 我需要 Apache 吗?

machine-learning - 如果数据集具有多个特征,您如何知道它是否适合线性回归?

python - 如何部署机器学习模型以使用多个特征数据进行预测

pytorch - 在 Pytorch 中裁剪一小批图像——每个图像都不同

neural-network - 如何仅在一个维度(按行或按列)将线性层应用于二维层 - 部分连接层

python - matplotlib 中的关闭步骤