tensorflow - 使用 TensorFlow v2.2 将 Keras .h5 模型转换为 TFLite .tflite

标签 tensorflow keras neural-network tensorflow-lite

我正在尝试将我使用 Keras 定义的网络转换为 tflite。网络如下:

model = tf.keras.Sequential([
        # Embedding
        tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None]),
        # GRU unit
        tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        # Fully connected layer
        tf.keras.layers.Dense(vocab_size)
        ])

但是,当我尝试导出到 .tflite 时,似乎由于存在 GRU 层而出现问题。

# Save trained model in .h5 format
keras_file = 'inference.h5'
tf.keras.models.save_model(model, keras_file)

# Load .h5 model with custom loss function
model = load_model('inference.h5', custom_objects={'loss': loss})

# Converting a tf.Keras model to a TensorFlow Lite model.
converter    = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

有错误:

ValueError: Input 0 of node sequential_8/gru_2/AssignVariableOp was passed float from sequential_8/gru_2/68029:0 incompatible with expected resource.

有解决这个问题的方法吗?

最佳答案

现在不支持stateful,你可以试试set stateful=False吗?

关于tensorflow - 使用 TensorFlow v2.2 将 Keras .h5 模型转换为 TFLite .tflite,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61975032/

相关文章:

python - Keras 自定义损失函数内的 TensorFlow session

python - 在多个 h5 文件上训练 ANN Keras(基于 Tensorflow)模型

tensorflow - 网格的一部分作为 cnn 的输入

python - CNN OCR 机器可读区

tensorflow - 跳过 TFRecordDataset.map() 中的数据集条目

python - 将列表输入到 TensorFlow 中的 feed_dict 的问题

python - reshape np 数组以进行深度学习

neural-network - Keras 中 add_loss 函数的目的是什么?

python - 如何解释 keras "predict_generator "输出?

python - 从目录中读取 pandas 中的多个文件 csv 并将它们存储在列表数组中,每个文件作为一个观察