python - 如何在 Keras 中从 HDF5 文件加载模型?

标签 python machine-learning keras data-science

如何在 Keras 中从 HDF5 文件加载模型?

我尝试了什么:

model = Sequential()

model.add(Dense(64, input_dim=14, init='uniform'))
model.add(LeakyReLU(alpha=0.3))
model.add(BatchNormalization(epsilon=1e-06, mode=0, momentum=0.9, weights=None))
model.add(Dropout(0.5))

model.add(Dense(64, init='uniform'))
model.add(LeakyReLU(alpha=0.3))
model.add(BatchNormalization(epsilon=1e-06, mode=0, momentum=0.9, weights=None))
model.add(Dropout(0.5))

model.add(Dense(2, init='uniform'))
model.add(Activation('softmax'))


sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='binary_crossentropy', optimizer=sgd)

checkpointer = ModelCheckpoint(filepath="/weights.hdf5", verbose=1, save_best_only=True)
model.fit(X_train, y_train, nb_epoch=20, batch_size=16, show_accuracy=True, validation_split=0.2, verbose = 2, callbacks=[checkpointer])

以上代码成功地将最佳模型保存到名为 weights.hdf5 的文件中。然后我想做的是加载该模型。下面的代码显示了我是如何尝试这样做的:

model2 = Sequential()
model2.load_weights("/Users/Desktop/SquareSpace/weights.hdf5")

这是我得到的错误:

IndexError                                Traceback (most recent call last)
<ipython-input-101-ec968f9e95c5> in <module>()
      1 model2 = Sequential()
----> 2 model2.load_weights("/Users/Desktop/SquareSpace/weights.hdf5")

/Applications/anaconda/lib/python2.7/site-packages/keras/models.pyc in load_weights(self, filepath)
    582             g = f['layer_{}'.format(k)]
    583             weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
--> 584             self.layers[k].set_weights(weights)
    585         f.close()
    586 

IndexError: list index out of range

最佳答案

如果你在 HDF5 文件中存储了完整的模型,而不仅仅是权重,那么它就像

from keras.models import load_model
model = load_model('model.h5')

关于python - 如何在 Keras 中从 HDF5 文件加载模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/35074549/

相关文章:

python - django 休息细节_路线测试

machine-learning - 用于生成的神经网络?

machine-learning - 反转 dropout 如何补偿 dropout 的影响并保持期望值不变?

java - 如何在Java中创建一个包含带引号的字符串的字符串?

python - 输入层的 TensorFlow Keras 维度误差

machine-learning - 线性堆叠层等于多元线性回归吗?

python - Peewee 按需添加列

python - 将每列的最后 N 个正值提取到新的数据框中

python - 使用应用程序的 sql 文件夹内的 sql 文件将初始数据提供到 sql 表中在 django1.9 中不起作用

python - Keras `fit_generator` 验证准确度低,但 `fit` 验证准确度低