python - 加载带有标签的 Tensorflow 模型

标签 python tensorflow

在本教程之后,我使用 model.save('model') 存储了一个模型:
https://towardsdatascience.com/keras-transfer-learning-for-beginners-6c9b8b7143e
标签取自目录本身。
现在我想加载它并使用以下代码对图像进行预测:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.preprocessing import image

new_model = keras.models.load_model('./model/')

# Check its architecture
new_model.summary()

with image.load_img('testpics/mypic.jpg') as img: # , target_size=(32,32)) 
    img  = image.img_to_array(img)
    img  = img.reshape((1,) + img.shape)
    # img  = img/255
    # img = img.reshape(-1,784)
    img_class=new_model.predict(img) 
    prediction = img_class[0]
    classname = img_class[0]
    print("Class: ",classname)
可悲的是,输出只是

Class: [1.3706615e-03 2.9885881e-03 1.6783881e-03 3.0293325e-03 2.9168031e-03 7.2344812e-04 2.0196944e-06 2.0119224e-02 2.2996603e-04 1.1960276e-05 3.0794670e-04 6.0808496e-05 1.4892215e-05 1.5410941e-02 1.2452166e-04 8.2580920e-09 2.4049083e-02 3.1140331e-05 7.4609083e-01 1.5793210e-01 2.4283256e-03 1.5755130e-04 2.4227127e-03 2.2325735e-07 7.2101393e-06 7.6298704e-03 2.0922457e-04 1.2269774e-03 5.5882465e-06 2.4516811e-04 8.5745640e-03]


而且我不知道如何重新加载标签......有人可以帮我吗:/?
Here is a screenshot of my saved files

最佳答案

该模型不包含标签名称。因此无法以这种方式检索它。您必须在训练时保存标签,然后才能在预测阶段加载和使用它们。
我使用 pickle 将标签作为序列化数组存储在文件中。然后您可以加载它们并使用预测的 argmax 作为数组索引。
下面是训练阶段:

CLASS_NAMES = ['ClassA', 'ClassB'] # should be dynamic
f = open('labels.pickle', "wb")
f.write(pickle.dumps(CLASS_NAMES))
f.close()
在预测中:
CLASS_NAMES = pickle.loads(open('labels.pickle', "rb").read())
predictions = model.predict(predict_image)
result = CLASS_NAMES[predictions.argmax(axis=1)[0]]

关于python - 加载带有标签的 Tensorflow 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62757103/

相关文章:

tensorflow - Keras 序列模型到 Tensorflow EstimatorSpec 的准确性下降

python - 创建一个内容大于 2GB 的张量原型(prototype)

python - 如何禁用悬停图坐标的科学记数法

python - Django 表单,ModelMultipleChoiceField 上显示错误

tensorflow - 如何在 tensorflow 2.0 中使用层列表?

c++ - 未能 Bazel 以 tensorflow 作为依赖项构建 C++ 项目

python - Tensorflow:尝试迁移学习时出错:无效的 JPEG 数据或裁剪窗口

python - django 重定向错误。似乎没有任何效果

python - 读取文件的最后 n 行(尾部)而不逐行读取?

python - 比较日期时间的异常值