python - 在 PyTorch 中,nll_loss 的输入是什么?

标签 python pytorch

我正在看这里的教程:https://pytorch.org/tutorials/beginner/fgsm_tutorial.html

import torch.nn.functional as F
loss = F.nll_loss(output, target)

上面两行代码中,“目标”到底是什么?他们加载目标数据集,但从不讨论它到底是什么。文档也很难理解。

最佳答案

通过运行以下代码来检查自己:

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
            transforms.ToTensor(),
            ])),
        batch_size=1, shuffle=True)
for data, target in test_loader:
    print(data, target)
    break

这里,data 基本上是灰度 MNIST 图像,target09 之间的标签。

因此,在 loss = F.nll_loss(output, target) 中,output 是模型预测(模型在给出图像/数据时预测的内容),并且 target 是给定图像的实际标签。

此外,在上面的示例中,检查以下行:

output = model(data) # shape [1, 10]
init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability

# If the initial prediction is wrong, don't bother attacking, just move on
if init_pred.item() != target.item():
   continue

# Calculate the loss
loss = F.nll_loss(output, target)

在上面的代码中,只有那些 output-target 对被传递到 F.nll_loss 损失函数,其中模型预测正确。如果无法正确预测标签,则跳过此后的所有操作(包括损失计算)并继续 test_loader 中的下一个示例。

关于python - 在 PyTorch 中,nll_loss 的输入是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57229669/

相关文章:

python - 使用 Cloud Storage 在一个 Google App Engine 应用程序中备份并在另一个应用程序中恢复?

python - Pytorch 几何稀疏邻接矩阵到边索引张量

machine-learning - PyTorch 或 TensorFlow 中可以使用位置相关的卷积滤波器吗?

python - 使用 pytorch 进行多类句子分类(使用 nn.LSTM)

python - 如何实现 PyTorch 数据集以用于 AWS SageMaker?

python - 如何正确处理python中的可选特性

python - PyQt 线程。获取动态输出

python - 语音识别未知值错误

python - pytorch:获取给定 ImageFolder 数据集的类数

python - 将 QuantLib 导入为 ql 错误