我正在看这里的教程:https://pytorch.org/tutorials/beginner/fgsm_tutorial.html
import torch.nn.functional as F
loss = F.nll_loss(output, target)
上面两行代码中,“目标”到底是什么?他们加载目标数据集,但从不讨论它到底是什么。文档也很难理解。
最佳答案
通过运行以下代码来检查自己:
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
transforms.ToTensor(),
])),
batch_size=1, shuffle=True)
for data, target in test_loader:
print(data, target)
break
这里,data
基本上是灰度 MNIST 图像,target
是 0
和 9
之间的标签。
因此,在 loss = F.nll_loss(output, target)
中,output
是模型预测(模型在给出图像/数据时预测的内容),并且 target
是给定图像的实际标签。
此外,在上面的示例中,检查以下行:
output = model(data) # shape [1, 10]
init_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
# If the initial prediction is wrong, don't bother attacking, just move on
if init_pred.item() != target.item():
continue
# Calculate the loss
loss = F.nll_loss(output, target)
在上面的代码中,只有那些 output-target
对被传递到 F.nll_loss
损失函数,其中模型预测正确。如果无法正确预测标签,则跳过此后的所有操作(包括损失计算)并继续 test_loader
中的下一个示例。
关于python - 在 PyTorch 中,nll_loss 的输入是什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57229669/