python - 如何从管道中获取预测标签和百分比?

标签 python huggingface-transformers

我正在使用以下 Hugging Face 转换器代码。

from transformers import pipeline
classifier = pipeline("sentiment-analysis",model='bhadresh-savani/distilbert-base-uncased-emotion', return_all_scores=True)
prediction = classifier("I love using transformers. The best part is wide range of support and its easy to use", )
print(prediction)

结果是:

[[{'label': 'sadness', 'score': 0.010154823772609234}, {'label': 'joy', 'score': 0.5637667179107666}, {'label': 'love', 'score': 0.4066571295261383}, {'label': 'anger', 'score': 0.01734882965683937}, {'label': 'fear', 'score': 0.0011737244203686714}, {'label': 'surprise', 'score': 0.0008987095206975937}]]

我想知道如何与标签一起获得最高分,但我不确定如何迭代这个对象,或者是否有一个简单的表达式方法。

最佳答案

您正在使用 TextClassificationPipeline 。当您__call__管道时,如果return_all_scores=False,您将获得一个dict列表,或者一个dict列表如果 return_all_scores=True 按照documentation .

您可以将return_all_scores设置为False(默认值),然后访问您想要的值 - 在这种情况下,您只会获得分数和文本最高分标签:

from transformers import pipeline
classifier = pipeline("sentiment-analysis",model='bhadresh-savani/distilbert-base-uncased-emotion', return_all_scores=False)
prediction = classifier("I love using transformers. The best part is wide range of support and its easy to use")
print(prediction[0]["label"], prediction[0]["score"])

或者,如果您想要所有标签的分数,然后仅访问得分最高的标签 (return_all_scores=True):

from transformers import pipeline
classifier = pipeline("sentiment-analysis",model='bhadresh-savani/distilbert-base-uncased-emotion', return_all_scores=True)
prediction = classifier("I love using transformers. The best part is wide range of support and its easy to use")
print(max(prediction[0], key=lambda k: k["score"]))

关于python - 如何从管道中获取预测标签和百分比?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72014538/

相关文章:

python - 了解 gpu 使用 huggingface 分类 - 总优化步骤

python - 对于语言模型微调(BERT 通过 Huggingface Transformers),输入文件的格式究竟应该如何设置?

nlp - 在 HuggingFace 库中基于 BERT 的模型中,merge.txt 文件是什么意思?

python - 如何在 VS Code 中安装外部 python 库?

python - 如何用Python接收Watson Speech to Text SDK的全部输出?

python - 无法为 prange 中的数组赋值

Python:将维度附加到二维数组

Python-从请求中获取十进制值并将其插入MySQL

python - 如何组合两个标记化的 bert 序列

huggingface-transformers - Rasa 与 HuggingFace 集成管道