python - 具有标准 Tensorflow 的 Tensorflow Lite 模型

标签 python tensorflow tensorflow-lite

我正在开发一个 Tensorflow 简单应用程序(想检测人们是否在捕获的图像中)。

我熟悉 Tensorflow 的 python 接口(interface)。我看到 Tensorflow Lite 有不同的简化格式。

我有兴趣在具有基于 PC 的 GPU 的传统 tensorflow Python 程序中使用 Tensorflow Lite 示例中链接的模型(因为不想花时间创建自己的模型)。

https://www.tensorflow.org/lite/models/image_classification/overview

这可能吗?

当我运行以下代码时,我收到

import tensorflow as tf
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
     with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

 load_pb('detect.tflite')

main.py:5:运行时警告:意外的结束组标记:并非所有数据都已转换 graph_def.ParseFromString(f.read())

最佳答案

您可以关注example由 Tensorflow 文档提供。 tflite 模型和标签取自 here 。该代码在普通台式电脑上运行。

import tensorflow as tf
import requests
import io
from PIL import Image
import numpy as np

# load model
interpreter = tf.contrib.lite.Interpreter(model_path="mobilenet_v1_1.0_224_quant.tflite")
interpreter.allocate_tensors()

# get details of model
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# load an image
r = requests.get('https://www.tensorflow.org/lite/models/image_classification/images/dog.png')

# convert the image RGB, see input_details[0].shape
img = Image.open(io.BytesIO(r.content)).convert('RGB')

# resize the image and convert it to a Numpy array
img_data = np.array(img.resize(input_details[0]['shape'][1:3]))

# run the model on the image
interpreter.set_tensor(input_details[0]['index'], [img_data])
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])

# get the labels
with open('labels_mobilenet_quant_v1_224.txt') as f:
    labels = f.readlines()

print(labels[np.argmax(output_data[0])])

West Highland white terrier

enter image description here

关于python - 具有标准 Tensorflow 的 Tensorflow Lite 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57130462/

相关文章:

python - 值错误 : Input 0 of node incompatible with expected float_ref. **

python - 深度排序使用大量 CPU

python - 我的 numpy 数组更紧凑的 __repr__?

python - 在 Pyomo 中,是否可以根据多个表达式编写目标函数或约束?

android - 没有找到 void org.tensorflow.demo.env.ImageUtils 的实现

python - 如何在 tensorflow 中保存训练好的模型?

python - 如何在tensorflow中使用tf.data读取.csv文件?

适用于 TFLite 的 Android Camera X ImageAnalyzer 图像格式

python - 用Python编写音频插件

python - 减小 TFLite 模型大小?