我正在尝试将我使用 Keras 定义的网络转换为 tflite。网络如下:
model = tf.keras.Sequential([
# Embedding
tf.keras.layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[BATCH_SIZE, None]),
# GRU unit
tf.keras.layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
# Fully connected layer
tf.keras.layers.Dense(vocab_size)
])
但是,当我尝试导出到 .tflite 时,似乎由于存在 GRU 层而出现问题。
# Save trained model in .h5 format
keras_file = 'inference.h5'
tf.keras.models.save_model(model, keras_file)
# Load .h5 model with custom loss function
model = load_model('inference.h5', custom_objects={'loss': loss})
# Converting a tf.Keras model to a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
有错误:
ValueError: Input 0 of node sequential_8/gru_2/AssignVariableOp was passed float from sequential_8/gru_2/68029:0 incompatible with expected resource.
有解决这个问题的方法吗?
最佳答案
现在不支持stateful,你可以试试set stateful=False
吗?
关于tensorflow - 使用 TensorFlow v2.2 将 Keras .h5 模型转换为 TFLite .tflite,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61975032/