python-3.x - “torchmetrics”不适用于 PyTorchLightning

标签 python-3.x metrics pytorch-lightning

我正在尝试了解如何使用 torchmetrics与 PyTorch 闪电。
但是,我得到了与准确性、F1 分数、精度等相同的输出。

这是代码。

metric_acc = torchmetrics.Accuracy()
metric_f1 = torchmetrics.F1()
metric_pre = torchmetrics.Precision()
metric_rec = torchmetrics.Recall()

n_batches = 3
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))

    acc = metric_acc(preds, target)
    f1 = metric_f1(preds, target)
    pre = metric_pre(preds, target)
    rec = metric_rec(preds, target)
    print(f"Accuracy on batch {i}: {acc}")
    print(f"F1 score on batch {i}: {f1}")
    print(f"pre score on batch {i}: {pre}")
    print(f"rec score on batch {i}: {rec}")
    print('-' * 20)


acc = metric_acc.compute()
f1 = metric_f1.compute()
pre = metric_pre.compute()
rec = metric_rec.compute()
print(f"Accuracy on all data: {acc}")
print(f"f1 score on all data: {f1}")
print(f"pre score on all data: {pre}")
print(f"rec score on all data: {rec}")

结果在这里。

Accuracy on batch 0: 0.10000000149011612
F1 score on batch 0: 0.10000000894069672
pre score on batch 0: 0.10000000149011612
rec score on batch 0: 0.10000000149011612
--------------------
Accuracy on batch 1: 0.30000001192092896
F1 score on batch 1: 0.30000001192092896
pre score on batch 1: 0.30000001192092896
rec score on batch 1: 0.30000001192092896
--------------------
Accuracy on batch 2: 0.4000000059604645
F1 score on batch 2: 0.40000003576278687
pre score on batch 2: 0.4000000059604645
rec score on batch 2: 0.4000000059604645
--------------------
Accuracy on all data: 0.2666666805744171
f1 score on all data: 0.2666666805744171
pre score on all data: 0.2666666805744171
rec score on all data: 0.2666666805744171

Process finished with exit code 0

当我将它与 PyTorchLightning 一起使用时,我得到了相同的结果,所以我用简单的代码尝试并得到了相同的结果。
如果您知道问题或解决方案,请告诉我。
非常感谢。

最佳答案

这样做的原因是,对于多类分类,如果您使用 F1、Precision、ACC 和 Recall with micro(默认),这些是 equivalent metrics并建议你应该使用宏

metric_acc = torchmetrics.Accuracy(average='macro')
metric_f1 = torchmetrics.F1(average='macro')
metric_pre = torchmetrics.Precision(average='macro')
metric_rec = torchmetrics.Recall(average='macro')

关于python-3.x - “torchmetrics”不适用于 PyTorchLightning,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69139618/

相关文章:

python - "plt.figure"有什么意义?

scikit-learn - 如何计算F1分数进行多标签分类?

pytorch-lightning - validation_epoch_end 与 DDP Pytorch Lightning

pytorch - PyTorch Lightning 是否是整个 epoch 的平均指标?

python - PyTorch Lightning 在validation_epoch_end 中将张量移动到正确的设备

python - 如何让程序只有在按下回车键时才继续运行?

python - dict.keys() 返回的键 k 在执行 dict[k] : KeyError on existing key 时导致 KeyError

Python:追加到一个字典值,如果已经存在,该值是一个列表,会添加到所有键的值而不是

c++ - 具有余弦距离的 mlpack 最近邻?

metrics - LOC 计数应该包括测试和评论吗?