pytorch - Pytorch 中软标签的交叉熵

标签 pytorch cross-entropy

我正在尝试定义二分类问题的损失函数。但是,目标标签不是硬标签0,1,而是0~1之间的一个 float 。

Pytorch 中的 torch.nn.CrossEntropy 不支持软标签,所以我想自己写一个交叉熵函数。

我的函数是这样的

def cross_entropy(self, pred, target):
    loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
    return loss

def step(self, batch: Any):
    x, y = batch
    logits = self.forward(x)
    loss = self.criterion(logits, y)
    preds = logits
    # torch.argmax(logits, dim=1)
    return loss, preds, y

然而它根本不起作用。

谁能给我一个建议,我的损失函数有没有错误?

最佳答案

好像BCELoss和健壮的版本 BCEWithLogitsLoss正在“开箱即用”地处理模糊目标。他们不希望 target 是二进制的“0 到 1 之间的任何数字都可以。
请阅读文档。

关于pytorch - Pytorch 中软标签的交叉熵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70429846/

相关文章:

python - pytorch中引入nn.Parameter的目的

python - 如何使用 PyTorch 多处理?

deep-learning - VGG16 预测的随机性

python - 在 PyTorch 中使用 module.to() 移动成员张量

pytorch - 有谁知道为什么 SHAP 的 Deep Explainer 在 ResNet-50 预训练模型上失败?

torch - pytorch中的交叉熵损失 nn.CrossEntropyLoss()

keras - 如何在 keras 中计算非 0 或 1 的目标值的交叉熵

neural-network - Tensorflow:具有交叉熵的缩放 logits

python - Pytorch 闪电指标 : ValueError: preds and target must have same number of dimensions, 或 preds 的一个附加维度

python - Pytorch-运行时错误 : Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward