python - 交叉熵损失与重量手动计算

标签 python pytorch loss-function cross-entropy

嗨,只是在玩代码,我得到了交叉熵损失权重实现的意外结果。

pred=torch.tensor([[8,5,3,2,6,1,6,8,4],[2,5,1,3,4,6,2,2,6],[1,1,5,8,9,2,5,2,8],[2,2,6,4,1,1,7,8,3],[2,2,2,7,1,7,3,4,9]]).float()
label=torch.tensor([[3],[7],[8],[2],[5]],dtype=torch.int64)
weights=torch.tensor([1,1,1,10,1,6,1,1,1],dtype=torch.float32)

对于这种样本变量,pytorch 的交叉熵损失为 4.7894

loss = F.cross_entropy(pred, label, weight=weights,reduction='mean')
> 4.7894

我手动实现了交叉熵损失代码,如下

one_hot = torch.zeros_like(pred).scatter(1, label.view(-1, 1), 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean()

如果没有给出权重值,这种实现与 pytorch 的交叉熵函数给出相同的结果。但是有重量值

one_hot = torch.zeros_like(pred).scatter(1, label.view(-1, 1), 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb)*weights.sum(dim=1).sum()/weights.sum()
> 3.9564

它使用 pytorch 模块给出了不同的损失值(4.7894)。 我可以粗略地估计我对loss的权重的理解这里有一些问题,但是我无法找出这种差异的确切原因。 有人可以帮我解决这个问题吗?

最佳答案

我发现了问题。这很简单...... 我不应该除以总重量。 相反,用 wt.sum() (wt=one_hot*weight) 除法得到了 4.7894

>>> wt = one_hot*weights
>>> loss = -(one_hot * log_prb * weights).sum(dim=1).sum() / wt.sum()
4.7894

分母仅与“相关”权重值有关,而不是整数。

关于python - 交叉熵损失与重量手动计算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68727252/

相关文章:

pytorch - pytorch中变换和目标变换之间的区别?

python - 将 PyTorch 计算机视觉深度学习模型部署到 Windows 桌面应用程序中

python - 类型错误 : setup() got an unexpected keyword argument 'stage'

python - 对于变分自动编码器,重建损失应该计算为图像的总和还是平均值?

python - 将简单的日期时间提升为系统时区

Python 2.7.10 Tkinter 在文本框中的文本周围使用大括号

python - 计算不同长度的 3 维数组的平均值

python - python3 类型转换是否输入了 input() 值?

python - Pytorch 自定义损失函数与 If 语句

python - ValueError : Shapes must be equal rank, 但是是 1 和 0 来自将形状 1 与其他形状合并。对于 'loss/AddN'