我正在尝试定义二分类问题的损失函数。但是,目标标签不是硬标签0,1,而是0~1之间的一个 float 。
Pytorch 中的 torch.nn.CrossEntropy 不支持软标签,所以我想自己写一个交叉熵函数。
我的函数是这样的
def cross_entropy(self, pred, target):
loss = -torch.mean(torch.sum(target.flatten() * torch.log(pred.flatten())))
return loss
def step(self, batch: Any):
x, y = batch
logits = self.forward(x)
loss = self.criterion(logits, y)
preds = logits
# torch.argmax(logits, dim=1)
return loss, preds, y
然而它根本不起作用。
谁能给我一个建议,我的损失函数有没有错误?
最佳答案
好像BCELoss
和健壮的版本 BCEWithLogitsLoss
正在“开箱即用”地处理模糊目标。他们不希望 target
是二进制的“0 到 1 之间的任何数字都可以。
请阅读文档。
关于pytorch - Pytorch 中软标签的交叉熵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70429846/