python - 如何打印出正确的预测类别?

标签 python python-3.x list tensorflow

目前能够得到正确的预测,但是打印的类别字符串是错误的。

如果只有 2 个类别,那么此代码将完美运行,但现在我使用 3 个类别。

CATEGORIES = ["RGB images score 1", "RGB images score 2", "RGB images score 3"]

prediction = model.predict([prepare('score3.png')])

print(prediction[0])  # will be a list in a list.

print(CATEGORIES[int(prediction[0][1])])

输出

[0. 0. 1.]

RGB images score 1

实际输出应为“RGB 图像得分 3”。然而我得到的是“RGB 图像得分 1”。只有 3 张图像存在此问题。

最佳答案

您正在查看预测[0][1],它毫无意义:网络是否预测了2

它确实适用于 2 个类别,但不适用于更多类别!您需要找到预测[0]等于1的索引。

您可以使用例如 print(CATEGORIES[int(np.argmax(prediction[0]))])

关于python - 如何打印出正确的预测类别?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58111456/

相关文章:

python - 如何从 JSON 文件中删除反斜杠

python - 如何根据另一列填充 nan 值

python - 如何在 Python 中仅对列表中的几个值进行排序

python - 对词典列表进行排序和分组

python - 使用 Python 和 Mutagen 为 MP4 文件设置封面时遇到问题

python - 如何将 multiprocessing.Pool 实例传递给 apply_async 回调函数?

python - 给定需要添加的空格数量,通过在单词之间添加空格来格式化字符串

c++ - 从 std::list 中移除对象

python - 在一列内检测行范围内的异常值

Python Beautifulsoup Find_all 除了