python - 了解多类分类的 tf.keras.metrics.Precision 和 Recall

标签 python tensorflow machine-learning keras tf.keras

我正在为多类分类问题构建模型。所以我想使用召回率和精度来评估模型性能。 我的数据集中有 4 个类,它在 one hot 中提供。代表。

我正在阅读PrecisionRecall tf.keras文档,并有一些问题:

  1. 在计算多类分类的 Precision 和 Recall 时,我们如何取所有标签的平均值,即全局精度和 Recall?是用 macro 计算的或micro因为它没有在文档中指定,如Sikit learn .
  2. 如果我想分别计算每个标签的精度和召回率,我可以使用参数 class_id对于每个标签要做one_vs_restbinary分类。就像我在下面的代码中所做的那样?
  3. 我可以使用参数 top_k值为 top_k=2在这里会有帮助还是不适合我仅对 4 个类别进行分类?
  4. 当我测量每个类别的表现时,当我设置 top_k=1 时,可能会有什么差异?并且不设置 top_k总体而言?
model.compile(
      optimizer='sgd',
      loss=tf.keras.losses.CategoricalCrossentropy(),
      metrics=[tf.keras.metrics.CategoricalAccuracy(),
               ##class 0
               tf.keras.metrics.Precision(class_id=0,top_k=2), 
               tf.keras.metrics.Recall(class_id=0,top_k=2),
              ##class 1
               tf.keras.metrics.Precision(class_id=1,top_k=2), 
               tf.keras.metrics.Recall(class_id=1,top_k=2),
              ##class 2
               tf.keras.metrics.Precision(class_id=2,top_k=2), 
               tf.keras.metrics.Recall(class_id=2,top_k=2),
              ##class 3
               tf.keras.metrics.Precision(class_id=3,top_k=2), 
               tf.keras.metrics.Recall(class_id=3,top_k=2),
])

对此功能的任何澄清将不胜感激。 提前致谢

最佳答案

<强>3。我可以使用参数 top_k 和值 top_k=2 在这里会有帮助吗?或者它不适合我仅对 4 个类进行分类?

根据描述,如果使用此参数,它只会计算 top_k(使用 _filter_top_k 函数)预测,并将其他预测变为 False

示例来自官方文档链接:https://www.tensorflow.org/api_docs/python/tf/keras/metrics/Precision

您可能还想阅读原始代码: https://github.com/keras-team/keras/blob/07e13740fd181fc3ddec7d9a594d8a08666645f6/keras/utils/metrics_utils.py#L487 当top_k=2时,它将计算y_true[:2]和y_pred[:2]的精度

m = tf.keras.metrics.Precision(top_k=2)
m.update_state([0, 0, 1, 1], [1, 1, 1, 1])
m.result().numpy()
0.0

正如我们在示例中看到的注释,它只会计算 y_true[:2] 和 y_pred[:2],这意味着精度将仅计算前 2 个预测(还将 y_pred 的其余部分转为 0) )。

如果你想使用4类分类,class_id参数可能就足够了。

4.当我测量每个类别的性能时,设置top_k=1和不设置top_koverall时会有什么区别? 如果您未设置 top_k 值,该函数将计算模型所做的所有预测的精度。如果你想衡量性能。

Top k 可能适用于其他模型,不适用于分类模型

关于python - 了解多类分类的 tf.keras.metrics.Precision 和 Recall,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72720129/

相关文章:

python - 无法使用 pickle 和多个模块加载文件

string - 学习垃圾邮件发送者的名字

python-3.x - 使用自然语言处理从段落中提取特定类型的单词

python - 返回预测 wav2vec fairseq

python - B 和 C 不工作(Python3)

python - 当包含有意义的空格时,如何编写与 re.VERBOSE 一起使用的模式?

python - 在 Python 中查找图的树分解

python - tensorflow 多 GPU 训练

tensorflow - 如何在 TensorFlow 中查找第一个匹配元素的索引

python - 无法在pycharm中使用tensorflow