tensorflow - 训练后量化后加载模型时出现问题

标签 tensorflow tensorflow-lite

我已经训练了一个模型并将其转换为 .tflite 模型。我已经使用以下方法完成了训练后量化:

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
tflite_quant_model = converter.convert()

但是当我尝试在 RaspberryPi 上使用模型进行推理时,出现以下错误

Traceback (most recent call last):
File "tf_lite_test.py", line 8, in <module>
interpreter = tf.lite.Interpreter(model_path="converted_from_h5_model_with_quants.tflite")
File "/home/pi/.local/lib/python3.5/site-packages/tensorflow/lite/python/interpreter.py", line 46, in __init__
model_path))
ValueError: Didn't find op for builtin opcode 'CONV_2D' version '2'
Registration failed.

当我将模型转换为 tflite 而不应用任何训练后量化时,我没有收到任何错误。这是我用来隐藏模型而不应用训练后量化的代码。

import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_quant_model = converter.convert()

这是我的模型:

model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(IMG_SHAPE, IMG_SHAPE, 3)),
tf.keras.layers.MaxPooling2D(2, 2),

tf.keras.layers.Conv2D(64, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Conv2D(128, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),

tf.keras.layers.Dropout(0.5),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(3, activation='softmax')
])

如何应用训练后量化并加载模型而不会出现此错误?

最佳答案

也许您需要重建 tflite 运行时。这个型号可能太旧了,无法使用。请参阅此处的说明:https://www.tensorflow.org/lite/guide/build_rpi

关于tensorflow - 训练后量化后加载模型时出现问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56722720/

相关文章:

python - 将 Universal Sentence Encoder 保存到 Tflite 或将其提供给 tensorflow api

python - Tensorflow - 训练期间的 Nan 成本值 - 尝试了通常的修复但没有成功

python - Tensorflow:您必须使用 dtype float 为占位符张量 'Placeholder' 提供一个值 [但该值是一个 float ]

python - 等级> 2的Tensorflow matmul操作不起作用

c++ - 如何在没有 ARM 处理器的 Linux 上安装 TensorflowLite C++?

python - 在 Tensorflow-lite 中输入具有动态尺寸的图像

python - 在 Tensorflow 2 中的每个纪元之后计算每个类的召回率

python - 在 CentOS 7 上构建 Tensorflow 时出错

python - 如何在脚本中加载 tflite 模型?

java - 如何在没有 labelmaps.txt 文件的情况下使用 Tensorflow lite 检测对象地标?