python - 使用 TF 2.0 将 saving_model 转换为 TFLite 模型

标签 python tensorflow tensorflow2.0 tensorflow-lite

目前我正在致力于将自定义对象检测模型(使用 SSD 和初始网络进行训练)转换为量化的 TFLite 模型。我可以使用以下代码片段(使用 Tensorflow 1.4)将自定义对象检测模型从卡住图转换为量化 TFLite 模型:

converter = tf.lite.TFLiteConverter.from_frozen_graph(args["model"],input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])

converter.allow_custom_ops=True
converter.post_training_quantize=True 
tflite_model = converter.convert()
open(args["output"], "wb").write(tflite_model)

但是,tf.lite.TFLiteConverter.from_frozen_graph 类方法不适用于 Tensorflow 2.0 ( refer this link )。所以我尝试使用 tf.lite.TFLiteConverter.from_saved_model 类方法转换模型。代码片段如下所示:

converter = tf.lite.TFLiteConverter.from_saved_model("/content/") # Path to saved_model directory
converter.optimizations =  [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

上面的代码片段抛出以下错误:

ValueError: None is only supported in the 1st dimension. Tensor 'image_tensor' has invalid shape '[None, None, None, 3]'.

我尝试将 input_shapes 作为参数传递

converter = tf.lite.TFLiteConverter.from_saved_model("/content/",input_shapes={"image_tensor": [1,300,300,3]})

但它抛出以下错误:

TypeError: from_saved_model() got an unexpected keyword argument 'input_shapes'

我错过了什么吗?请大家不吝指正!

最佳答案

我使用tf.compat.v1.lite.TFLiteConverter.from_frozen_graph得到了解决方案。此 compat.v1TF1.x 的功能引入 TF2.x。 完整代码如下:

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph("/content/tflite_graph.pb",input_shapes = {'normalized_input_image_tensor':[1,300,300,3]},
    input_arrays = ['normalized_input_image_tensor'],output_arrays = ['TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1',
    'TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'])

converter.allow_custom_ops=True

# Convert the model to quantized TFLite model.
converter.optimizations =  [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()


# Write a model using the following line
open("/content/uno_mobilenetV2.tflite", "wb").write(tflite_model)

关于python - 使用 TF 2.0 将 saving_model 转换为 TFLite 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59679645/

相关文章:

python - 查看 celery 任务是否存在

python - 将多个列表转换为单个字典

python - 为什么在tensorflow中构建resnet模型时使用fixed padding

tensorflow2.0 - 使用 tf.data 与联合核心 API 和远程执行客户端上的数据进行线性回归

python - Tensorflow - 训练期间的 Nan 成本值 - 尝试了通常的修复但没有成功

python - NotFoundError : No registered 'PyFunc' OpKernel for 'CPU' devices compatible with node {{node PyFunc}} . 已注册:<没有注册的内核>

python - 创建新表(模型)并将其与 Django 中的 auth_user 相关联

python - pymssql 与 pyodbc 与 adodbapi 与...

python - 如何根据带索引的张量过滤tensorflow的Tensor?

machine-learning - tensorflow.python.framework.errors_impl.InvalidArgumentError : Negative dimension size caused by subtracting 2 from 1 for 'max_pooling2d_1/MaxPool'