machine-learning - 如何使用Keras的ModelCheckpoint继续训练模型

标签 machine-learning callback deep-learning keras checkpoint

我是 Keras 的新用户。我对使用 Keras 的训练过程有疑问。

由于我的服务器的时间限制(每个作业只能在 24 小时以内运行),我必须使用多个 10 epoch 周期来训练我的模型。

在第一个训练周期,10 个 epoch 后,使用 Keras 的 ModelCheckpoint 存储最佳模型的权重。

conf = dict()
conf['nb_epoch'] = 10
callbacks = [
             ModelCheckpoint(filepath='/1st_{epoch:d}_{val_loss:.5f}.hdf5',
             monitor='val_loss', save_best_only=True,
             save_weights_only=False, verbose=0)
            ]   

假设我获得最佳模型:“1st_10_1.00000.hdf5”。接下来,我继续使用 10 个时期训练我的模型,并存储最佳模型的权重,如下所示。

model.load_weights('1st_10_1.00000.hdf5')
model.compile(...)
callbacks = [
             ModelCheckpoint(filepath='/2nd_{epoch:d}_{val_loss:.5f}.hdf5',
             monitor='val_loss', save_best_only=True,
             save_weights_only=False, verbose=0)
            ]

但我有一个问题。第二次训练的第一个时期的 val_loss 为 1.20000,脚本生成模型“2nd_1_1.20000.hdf5”。显然,新的val_loss大于第一次训练的最佳val_loss(1.00000)。第二次训练的后续纪元似乎是基于模型“2nd_1_1.20000.hdf5”而不是“1st_10_1.00000.hdf5”进行训练。

'2nd_1_1.20000.hdf5'
'2nd_1_2.15000.hdf5'
'2nd_1_3.10000.hdf5'
'2nd_1_4.05000.hdf5'
...

我认为不使用第一个训练期的更好结果是一种浪费。任何人都可以指出修复它的方法,或者告诉程序它应该使用之前训练期间的最佳模型的方法?提前谢谢了!

最佳答案

有趣的案例,可能是一个很大的改进...除了创建自己的回调函数之外,我认为 API 目前不支持这样的解决方案。

我认为这不会那么难。您可以将其基于原始的 modelcheckpoint 回调类并进行更改。

这一行: https://github.com/fchollet/keras/blob/master/keras/callbacks.py#L390

它存储 logget 项目的当前最佳值,根据情况在 if 语句中将其初始化为 -inf/inf。

在您的情况下,您必须找到一种方法来读取文件的文件名,进行一些字符串操作,然后添加它。

我建议将其添加为单独的语句..或作为 else if

避免过多地干扰核心代码。

希望有帮助..

关于machine-learning - 如何使用Keras的ModelCheckpoint继续训练模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43546219/

相关文章:

machine-learning - 如何使用全卷积网络在 Cifar10 上获得最先进的结果?

node.js - 使用异步和请求模块限制请求

machine-learning - `caffe' : malloc(): memory corruption when snapshotting to disk

machine-learning - 为什么对于同一问题,binary_crossentropy 和 categorical_crossentropy 给出不同的性能?

machine-learning - 如何标准化数据以输入位于训练数据范围之外的神经网络?

python - 将 pandas 数据帧传递到 FastAPI 以进行 NLP ML

R 中的相对频率

javascript - 我的回调应该放在哪里? (Framework7 Swiper)

callback - 如何在 Dojo 中的另一个函数(非 AJAX)完成后调用一个函数(非 AJAX)?

tensorflow - 如何使用 estimator.export_savemodel() 保存 TensorFlow 模型