我有一个机器学习模型,可以进行多标签文本分类。我有一个预测器对象,它成功预测我用作输入的文本字符串的分类。它将其预测分配给单个预测,作为列表,如下所示:
[('unrelated', 0.9684208035469055), ('curated', 0.02895800955593586)]
我觉得这可能很简单,但本质上我只需要 为策划的比赛创建一个阈值。
因此,如果 curated 的置信度高于 0.90 或类似值,我可以打印一份声明。
但是,我不知道如何指定这个条件。
它是一个列表对象,所以我尝试指定索引。然而,每个索引都会输出['label',confidence]
。此外,索引的顺序根据置信度进行切换。它始终首先显示最高级别的置信度标签。因此指定索引号不会有太大帮助,因为它会改变。
single_prediction = predictor.predict(result)
df.at[0,'prediction'] = single_prediction
if single_prediction[0] >= .95:
print('this is a match')
print(single_prediction)
最佳答案
您可以使用列表理解来做到这一点:
results = [ [('curated', 0.6), ('unrelated', 0.4)],
[('unrelated', 0.55), ('curated', 0.45)],
[('unrelated', 0.7), ('curated', 0.3)]]
threshold = 0.4
for result in results:
if [x[1] for x in result if x[0] == 'curated'][0] > threshold:
print(result)
<小时/>
输出:
[('curated', 0.6), ('unrelated', 0.4)]
[('unrelated', 0.55), ('curated', 0.45)]
关于python - 如何为输出预测设置条件阈值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56840085/