python - Pytorch:交叉熵的维度是正确的,但对于 MSE 来说有点错误?

标签 python machine-learning image-processing computer-vision pytorch

我正在创建一个程序,它将接受 Fashion MNIST 集作为输入,并且我正在调整我的模型,看看不同的参数将如何改变准确性。

我对模型所做的调整之一是将模型的损失函数从交叉熵更改为 MSE。

# The code above is miscellaneous training data import code

trainloader = torch.utils.data.DataLoader(trainset, batch_size = 64, shuffle = True, num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size = 64, shuffle = True, num_workers=4)

dataiter = iter(trainloader)
images, labels = dataiter.next()
from torch import nn, optim
import torch.nn.functional as F

model = nn.Sequential(nn.Linear(784, 128),
                      nn.ReLU(),
                      nn.Linear(128, 10),
                      nn.LogSoftmax(dim = 1)
                     )
model.to(device)

# Define the loss
criterion = nn.MSELoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr = 0.001)

# Define the epochs
epochs = 5

train_losses, test_losses = [], []

for e in range(epochs):
  running_loss = 0
  for images, labels in trainloader:
    # Flatten Fashion-MNIST images into a 784 long vector
    images = images.to(device)
    labels = labels.to(device)
    images = images.view(images.shape[0], -1)

    # Training pass
    optimizer.zero_grad()

    output = model.forward(images)

    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()

我的模型在使用交叉熵损失时没有任何问题,但当我更改为 MSE 损失时,解释器提示并说我的张量大小不同,因此无法计算。

<class 'torch.Tensor'>
torch.Size([64, 1, 28, 28])
torch.Size([64])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-62-ec6942122f02> in <module>
     44     output = model.forward(images)
     45 
---> 46     loss = criterion(output, labels)
     47     loss.backward()
     48     optimizer.step()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
    429 
    430     def forward(self, input, target):
--> 431         return F.mse_loss(input, target, reduction=self.reduction)
    432 
    433 

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in mse_loss(input, target, size_average, reduce, reduction)
   2213             ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
   2214     else:
-> 2215         expanded_input, expanded_target = torch.broadcast_tensors(input, target)
   2216         ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
   2217     return ret

/opt/conda/lib/python3.7/site-packages/torch/functional.py in broadcast_tensors(*tensors)
     50                 [0, 1, 2]])
     51     """
---> 52     return torch._C._VariableFunctions.broadcast_tensors(tensors)
     53 
     54 

RuntimeError: The size of tensor a (10) must match the size of tensor b (64) at non-singleton dimension 1

我尝试 reshape 张量并创建新数组作为输出数组的占位符,但似乎毫无进展。

为什么交叉熵损失可以正常工作而不会出现任何错误,而 MSE 却不能?

最佳答案

nn.CrossEntropyLoss nn.MSELoss 是完全不同的损失函数,其背后的原理也根本不同。

nn.CrossEntropyLoss 离散标记任务的损失函数。因此,它期望作为输入标签概率的预测和作为地面实况离散标签的目标:x形状是n xc (其中 c 是标签数量)和 y形状为n integer 类型,每个目标采用 {0,...,c-1} 范围内的值.

相比之下, nn.MSELoss 是回归任务的损失函数。因此,它期望预测和目标具有相同的形状和数据类型。也就是说,如果您的预测是 n xc目标的形状也应该是 n xc (不仅仅是交叉熵情况下的 n)。

如果您坚持使用 MSE 损失而不是交叉熵,则需要将当前拥有的目标整数标签(形状 n )转换为 1-hot vectors形状n xc然后才计算您的预测与生成的 one-hot 目标之间的 MSE 损失。

关于python - Pytorch:交叉熵的维度是正确的,但对于 MSE 来说有点错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62422644/

相关文章:

python - 在列表中查找小于或等于键的值

python - Snorkel:我可以在数据集中使用不同的特征来生成标签函数 VS 训练分类器吗?

iphone - 如何为 UIImage 应用 Warping 技术?

python - 从图像中检测水平白线并使用 OpenCV Python 获取它们的坐标

python - 扩展 PIL 的 c 功能

python - 使用 urllib3 进行身份验证

python - numpy 矩阵的数组索引太多

python - 如何使用Tensorflow进行信号处理?

python - 多元时间序列的 LSTM 输入形状?

python - 使用Python在图像中创建 "spotlight"