目前能够得到正确的预测,但是打印的类别字符串是错误的。
如果只有 2 个类别,那么此代码将完美运行,但现在我使用 3 个类别。
CATEGORIES = ["RGB images score 1", "RGB images score 2", "RGB images score 3"]
prediction = model.predict([prepare('score3.png')])
print(prediction[0]) # will be a list in a list.
print(CATEGORIES[int(prediction[0][1])])
输出
[0. 0. 1.]
RGB images score 1
实际输出应为“RGB 图像得分 3”。然而我得到的是“RGB 图像得分 1”。只有 3 张图像存在此问题。
最佳答案
您正在查看预测[0][1]
,它毫无意义:网络是否预测了2
。
它确实适用于 2 个类别,但不适用于更多类别!您需要找到预测[0]等于1的索引。
您可以使用例如 print(CATEGORIES[int(np.argmax(prediction[0]))])
关于python - 如何打印出正确的预测类别?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58111456/