python - 如何在 PyTorch 中使用具有焦点损失的类权重用于多类分类的不平衡数据集

标签 python machine-learning deep-learning neural-network pytorch

我正在研究语言任务的多类分类(4 个类),我正在使用 BERT 模型进行分类任务。我正在关注 this blog as reference . 我的 BERT 微调模型返回 nn.LogSoftmax(dim=1) .
我的数据非常不平衡,所以我使用了 sklearn.utils.class_weight.compute_class_weight计算类的权重并使用损失中的权重。

class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float)
cross_entropy  = nn.NLLLoss(weight=weights) 

我的结果不太好,所以我想用 Focal Loss 做实验。并有一个 Focal Loss 代码。
class FocalLoss(nn.Module):
  def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
    super(FocalLoss, self).__init__()
    self.alpha = alpha
    self.gamma = gamma
    self.logits = logits
    self.reduce = reduce

  def forward(self, inputs, targets):
    BCE_loss = nn.CrossEntropyLoss()(inputs, targets)

    pt = torch.exp(-BCE_loss)
    F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

    if self.reduce:
      return torch.mean(F_loss)
    else:
      return F_loss
我现在有3个问题。首先也是最重要的是
  • 我应该使用带有焦点损失的类权重吗?
  • 如果我必须在此内部实现权重 Focal Loss , 我可以用 weights里面的参数 nn.CrossEntropyLoss()
  • 如果这个工具不正确,那么这个工具的正确代码应该是什么,包括权重(如果可能)
  • 最佳答案

    您可以通过以下方式找到问题的答案:

  • 焦点损失会自动处理类别不平衡,因此焦点损失不需要权重。 alpha 和 gamma 因子处理焦点损失方程中的类不平衡。
  • 不需要额外的权重,因为焦点损失使用 alpha 和 gamma 调制因子处理它们
  • 根据焦点损失公式,您提到的实现是正确的,但是我无法使我的模型与此版本收敛,因此,我使用了 the following implementation from mmdetection framework
  •     pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        focal_weight = (alpha * target + (1 - alpha) *
                        (1 - target)) * pt.pow(gamma)
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * focal_weight
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss
    
    您也可以尝试使用 another focal loss version available

    关于python - 如何在 PyTorch 中使用具有焦点损失的类权重用于多类分类的不平衡数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64751157/

    相关文章:

    python - 我在 virtualenv 中安装的 Django 缺少管理模板文件夹

    python - SyntaxError : multiple statements found in python. 有人有这方面的经验吗?

    python - Hyperopt 参数空间 : TypeError: int() argument must be a string or a number, 不是 'Apply'

    r - 如何在R中计算决策树规则

    python - 只能使用 TensorFlow 中处理梯度的代码示例来实现类似优化器的梯度下降吗?

    Python 3 错误 : ValueError: invalid literal for int() with base 10: ''

    python - 设置 matplotlib 视频动画大小 (1920x1080)

    machine-learning - 多层网络预测简单函数

    python - pytorch中当输入参数超过两个时如何使用forward()方法

    python - `for` 循环到PyTorch中的多维数组