python - 在 Keras 中加载保存的模型(双向 LSTM)

标签 python machine-learning neural-network keras lstm

我在 Keras 中成功训练并保存了双向 LSTM 模型:

model = Sequential()
model.add(Bidirectional(LSTM(N_HIDDEN_NEURONS,
                        return_sequences=True,
                        activation="tanh",
                        input_shape=(SEGMENT_TIME_SIZE, N_FEATURES))))
model.add(Bidirectional(LSTM(N_HIDDEN_NEURONS)))
model.add(Dropout(0.5))
model.add(Dense(N_CLASSES, activation='sigmoid'))
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])

model.fit(X_train, y_train,
          batch_size=BATCH_SIZE,
          epochs=N_EPOCHS,
          validation_data=[X_test, y_test])

model.save('model_keras/model.h5')

但是,当我想加载它时:

model = load_model('model_keras/model.h5')

我收到错误:

ValueError: You are trying to load a weight file containing 3 layers into a model with 0 layers.

我还尝试了不同的方法,例如分别保存和加载模型架构和权重,但它们都不适合我。另外,之前,当我使用普通(单向)LSTM 时,加载模型效果很好。

最佳答案

正如@mpariente所述和 @todayinput_shape 是双向的参数,而不是 LSTM,请参阅 Keras documentation 。我的解决方案:

# Model
model = Sequential()
model.add(Bidirectional(LSTM(N_HIDDEN_NEURONS,
                             return_sequences=True,
                             activation="tanh"), 
                        input_shape=(SEGMENT_TIME_SIZE, N_FEATURES)))
model.add(Bidirectional(LSTM(N_HIDDEN_NEURONS)))
model.add(Dropout(0.5))
model.add(Dense(N_CLASSES, activation='sigmoid'))
model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])

model.fit(X_train, y_train,
          batch_size=BATCH_SIZE,
          epochs=N_EPOCHS,
          validation_data=[X_test, y_test])

model.save('model_keras/model.h5')

然后,要加载,只需执行以下操作:

model = load_model('model_keras/model.h5')

关于python - 在 Keras 中加载保存的模型(双向 LSTM),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51236338/

相关文章:

python - 钩子(Hook)类型转换为 Python 中的字典

machine-learning - 我可以反转 ROC 曲线来绘制真阴性与假阴性率吗?

neural-network - 如何判断我的 self 对弈神经网络过度拟合

pandas - 平均特征后学习算法的准确性下降

validation - 验证集是否用于更新神经网络?

python - paramiko SFTP 挂起

Python-Seaborn : Modifying the heatmap legend

python - 如何从另一个模块更改类对象参数(python)

javascript - 是否可以在 ml5.js 中隐藏视频但保留手部姿势点?

python - 如何使用 python、sklearn 预测未知 X 值的多维时间序列