pytorch - Best loss function for a multi-class, multi-target classification problem

Tags: pytorch classification multilabel-classification multiclass-classification

I have a classification problem, and I'm not sure how to categorize it. As far as I understand:

A Multiclass classification problem is where you have multiple mutually exclusive classes and each data point in the dataset can only be labelled by one class. For example, in an Image Classification task for fruits, a fruit data point labelled as an apple cannot be an orange and an orange cannot be a banana and so on. Each data point, in this case can only be any one of the fruits of the fruits class and so is labelled accordingly.

Whereas...

A Multilabel classification is a problem where you have multiple sets of mutually exclusive classes of which the data point can be labelled simultaneously. For example, in an Image Classification task for Cars, a car data point labelled as a sedan cannot be a hatchback and a hatchback cannot be an SUV and so on for the type of car. At the same time, the same car data point can be labelled one from VW, Ford, Mercedes, etc. as the car manufacturer. So in this case, the car data point is labelled from two different sets of mutually exclusive classes.

Please correct me if my understanding is wrong.

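Now to my problem. My multiclass classification problem has, say, the classes A, B, C, D and E. Here each data point can have one or more classes from the set, as shown in the left-hand table below: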

|-------------|----------|              |-------------|-----------------|
|      X      |     y    |              |      X      |    One-Hot-Y    |
|-------------|----------|              |-------------|-----------------|
|     DP1     |   A, B   |              |     DP1     | [1, 1, 0, 0, 0] |
|-------------|----------|              |-------------|-----------------|
|     DP2     |   C      |              |     DP2     | [0, 0, 1, 0, 0] |
|-------------|----------|              |-------------|-----------------|
|     DP3     |   B, E   |              |     DP3     | [0, 1, 0, 0, 1] |
|-------------|----------|              |-------------|-----------------|
|     DP4     |   A, C   |              |     DP4     | [1, 0, 1, 0, 0] |
|-------------|----------|              |-------------|-----------------|
|     DP5     |   D      |              |     DP5     | [0, 0, 0, 1, 0] |
|-------------|----------|              |-------------|-----------------|
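The right-hand column can be reproduced with a small sketch. It assumes the labels are stored as plain Python lists per data point; the names CLASSES, multi_hot and dataset are hypothetical. Because a row may contain several 1s, this encoding is often called "multi-hot" rather than one-hot.

```python
# Hypothetical class list matching the example table (A..E).
CLASSES = ["A", "B", "C", "D", "E"]

def multi_hot(labels, classes=CLASSES):
    """Return a 0/1 vector with a 1 at the position of every label present."""
    index = {name: i for i, name in enumerate(classes)}
    vec = [0] * len(classes)
    for label in labels:
        vec[index[label]] = 1
    return vec

# The five data points from the left-hand table.
dataset = {
    "DP1": ["A", "B"],
    "DP2": ["C"],
    "DP3": ["B", "E"],
    "DP4": ["A", "C"],
    "DP5": ["D"],
}

encoded = {dp: multi_hot(labels) for dp, labels in dataset.items()}
for dp, vec in encoded.items():
    print(dp, vec)
```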

I one-hot encode the training labels, as shown in the right-hand table above. My questions are:

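  1. What loss function (preferably in PyTorch) can I use for training the model to optimize for the one-hot encoded output?
  2. What do we call such a classification problem? Multi-label or multi-class?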

Thanks for your answers!

Best answer

What Loss function (preferably in PyTorch) can I use for training the model to optimize for the One-Hot encoded output

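You can use torch.nn.BCEWithLogitsLoss (or MultiLabelSoftMarginLoss, since the two are equivalent: with the default mean reduction they compute the same value) and see how that works out. This is the standard approach; another possibility would be MultiLabelMarginLoss.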
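To see why the first two losses are interchangeable, here is a plain-Python sketch that re-implements the formulas given in the PyTorch documentation for each loss with reduction='mean'. It is not PyTorch's own implementation, and the logits are made-up numbers. BCEWithLogitsLoss averages over every (sample, class) element, while MultiLabelSoftMarginLoss averages over classes within each sample and then over the batch; with a fixed number of classes these are the same mean.

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def bce_term(x, y):
    # -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], using
    # -log(sigmoid(x)) = softplus(-x) and -log(1-sigmoid(x)) = softplus(x).
    return y * softplus(-x) + (1 - y) * softplus(x)

def bce_with_logits_mean(logits, targets):
    """Formula of BCEWithLogitsLoss(reduction='mean'): mean over all elements."""
    terms = [bce_term(x, y)
             for row_x, row_y in zip(logits, targets)
             for x, y in zip(row_x, row_y)]
    return sum(terms) / len(terms)

def multilabel_soft_margin_mean(logits, targets):
    """Formula of MultiLabelSoftMarginLoss(reduction='mean'):
    mean over classes per sample, then mean over the batch."""
    per_sample = [sum(bce_term(x, y) for x, y in zip(row_x, row_y)) / len(row_x)
                  for row_x, row_y in zip(logits, targets)]
    return sum(per_sample) / len(per_sample)

# Made-up raw model scores for two samples and five classes (A..E).
logits = [[2.0, 1.5, -1.0, -2.0, -0.5],
          [-1.0, 0.3, -2.5, -0.2, 1.2]]
# Multi-hot targets (DP1 and DP3 from the question).
targets = [[1, 1, 0, 0, 0],
           [0, 1, 0, 0, 1]]

a = bce_with_logits_mean(logits, targets)
b = multilabel_soft_margin_mean(logits, targets)
print(a, b)  # the two values agree up to floating-point rounding

# At inference each class gets its own independent sigmoid; a 0.5 threshold
# (an assumed choice, often tuned per task) yields a multi-hot prediction.
pred = [[1 if 1 / (1 + math.exp(-x)) > 0.5 else 0 for x in row] for row in logits]
print(pred)
```

In PyTorch itself the model would output raw logits of shape (N, C) with no final sigmoid layer, since BCEWithLogitsLoss applies the sigmoid internally, and the multi-hot targets would be passed as a float tensor of the same shape. MultiLabelMarginLoss, by contrast, expects its target as class indices padded with -1 rather than as a multi-hot vector.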

What do we call such a classification problem? Multi-label or Multi-class?

It is multi-label (because several labels can be present at the same time). In terms of the one-hot encoding:

[1, 1, 0, 0, 0], [0, 1, 0, 0, 1] - multilabel
[0, 0, 1, 0, 0] - multiclass
[1], [0] - binary (special case of multiclass)

A multiclass target cannot contain more than one 1, because all the labels are mutually exclusive.
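The rule above can be turned into a quick check over a whole set of encoded targets. This is a minimal sketch with a hypothetical helper name, problem_kind; it assumes every row has the same length and that the dataset is non-empty.

```python
def problem_kind(targets):
    """Classify a list of encoded target rows using the rule of thumb above."""
    width = len(targets[0])
    if width == 1:
        # A single 0/1 output: binary, the special case of multiclass.
        return "binary"
    if all(sum(row) == 1 for row in targets):
        # Exactly one 1 per row: the classes are mutually exclusive.
        return "multiclass"
    # Some row has zero or several 1s: labels can co-occur.
    return "multilabel"

print(problem_kind([[1, 1, 0, 0, 0], [0, 1, 0, 0, 1]]))  # the question's rows
print(problem_kind([[0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]))
print(problem_kind([[1], [0]]))
```

Note that a two-column one-hot target such as [1, 0] / [0, 1] is reported as "multiclass", which matches the answer's point that binary is a special case of multiclass.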

The original question, "pytorch - Best loss function for a multi-class, multi-target classification problem", is on Stack Overflow: https://stackoverflow.com/questions/64634902/
