python - 在 Tensorflow-lite 中输入具有动态尺寸的图像

标签 python tensorflow keras tensorflow-lite

我有一个 tensorflow 模型,它接受不同大小的输入图像:

inputs = layers.Input(shape=(128,None,1), name='x_input')

<tf.Tensor 'x_input:0' shape=(?, 128, ?, 1) dtype=float32>

当我将此模型转换为 tensorflow-lite 时,它​​会提示:
converter = tf.lite.TFLiteConverter.from_frozen_graph(
  graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert() 

ValueError: None is only supported in the 1st dimension.
Tensor 'x_input_1' has invalid shape '[None, 128, None, 1]'.

我无法将图像缩放到固定大小。我看到的唯一解决方案是将图像填充到某个最大尺寸并在图中使用该图像,但这似乎非常浪费。有没有其他方法可以使 tensorflow-lite 与动态图像尺寸一起工作?这种限制有什么理由吗?谢谢。

最佳答案

是的,你可以使用动态张量在 TF-精简版中。不能直接将shape设置为[None, 128, None, 1]的原因是因为这样,以后可以轻松支持更多的语言。此外,它充分利用了静态内存分配方案。对于旨在用于具有低计算能力的小型设备的框架,这是一个明智的设计选择。
以下是如何动态设置张量大小的步骤:

0. 卡住

似乎您正在从卡住的 GraphDef 转换,即 *.pb文件。假设您的卡住模型具有输入形状 [None, 128, None, 1] .

1.转换步骤。

在此步骤中,将输入大小设置为 任何有效的可以被您的模型接受。例如:

tflite_convert \
  --graph_def_file='model.pb' \
  --output_file='model.tflite' \
  --input_shapes=1,128,80,1 \     # <-- here, you set an
                                  #     arbitrary valid shape
  --input_arrays='input' \         
  --output_arrays='Softmax'

2.推理步骤

诀窍是使用函数interpreter::resize_tensor_input(...)在推理期间实时使用 TF-Lite API。我将提供它的python实现。 Java 和 C++ 实现应该相同(因为它们具有相似的 API):

from tensorflow.contrib.lite.python import interpreter

# Load the *.tflite model and get input details
model = Interpreter(model_path='model.tflite')
input_details = model.get_input_details()

# Your network currently has an input shape (1, 128, 80 , 1),
# but suppose you need the input size to be (2, 128, 200, 1).
model.resize_tensor_input(
    input_details[0]['index'], (2, 128, 200, 1))
model.allocate_tensors()

就是这样。您现在可以将该模型用于形状为 (2, 128, 200, 1) 的图像。 ,只要您的网络架构允许这样的输入形状。请注意,您将不得不做 model.allocate_tensors()每次你做这样的 reshape ,所以它会非常低效。是强烈推荐避免在您的程序中过多地使用此功能。

关于python - 在 Tensorflow-lite 中输入具有动态尺寸的图像,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55701663/

相关文章:

python - 针对住房回归示例关闭 Cloud ML 预测

python - Keras 损失一直很低,但准确性开始很高然后下降

keras - 骰子系数大于 1

java - 创建一个查找并匹配用户输入的网页爬虫

Python/Mongoengine - 保存到数据库时缺少时区?

python - 在 while 循环中使用 datetime.time()

python - 将相同的权重加载到新图中的多个变量

tensorflow - Keras image_dataset_from_directory 未找到图像

tensorflow - 在 go 中使用 tensorflow hub

python - 我怎么知道 Keras 模型中是否加载了权重?