python - Tensorflow Keras 无法从检查点文件正确恢复初始时期的训练

标签 python tensorflow keras callback checkpoint

我正在 tensorflow 中加载 keras 模型以恢复训练。我想从我停止的纪元开始继续训练,以便纪元编号是唯一的,并且我可以跟踪纪元数量。该模型是从回调创建的检查点文件加载的,该文件保存了最高的准确性。当我在 model.fit() 中恢复训练时,我将“初始纪元”设置为 52,并将“纪元”设置为 52+5。然而,它从 epoch 1/57 而不是 53/57 开始训练,并将继续训练到 57,即使我只想要 5 个 epoch。我是否加载了错误的东西?训练恢复为“正常”,准确度是我上次中断的位置,但纪元编号不会从我想要的位置继续,而是继续从 1 重新开始。

我尝试在加载检查点文件时删除检查点回调初始化,但是由于未定义“回调列表”,因此会生成名称错误。

model = load_model('my_model.hdf5')
checkpoint = ModelCheckpoint(cp_filepath, monitor='acc', 
verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]

bs=32 #batch size
epoch count=52
cur_epochs=5
model.fit(
    training_set,
    steps_per_epoch=len(training_set)//bs,
    inital_epoch=epoch_count,
    epochs=cur_epochs+epoch_count,
    validation_data=test_set,
    validation_steps=len(test_set)//bs,
    callbacks=callbacks_list, 
    shuffle=True,
    verbose=1
    )

当从保存的文件恢复时,我期望看到 epoch 53/57 和 5 epoch 的训练。 我得到了 1/57 epoch 和 57 epoch 的训练

最佳答案

有同样的问题, 为了解决这个问题,我修改了ModelCheckpoint(回调)类。

我在 on_epoch_begin 回调函数中添加并保存了 epoch 的专用 tensorflow 检查点。

The network doesn't store its training progress with respect to training data - this is not part of its state, because at any point you could decide to change what data set to feed it.

class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):

    def __init__(self,filepath, monitor='val_loss', verbose=1, 
                 save_best_only=True, save_weights_only=True, 
                 mode='auto', ):

        super(EpochModelCheckpoint, self).__init__(filepath=filepath,monitor=monitor,
             verbose=verbose,save_best_only=save_best_only,
             save_weights_only=save_weights_only, mode=mode)

        self.ckpt = tf.train.Checkpoint(completed_epochs=tf.Variable(0,trainable=False,dtype='int32'))
        ckpt_dir = f'{os.path.dirname(filepath)}/tf_ckpts'
        self.manager = tf.train.CheckpointManager(self.ckpt, ckpt_dir, max_to_keep=3)

    def on_epoch_begin(self,epoch,logs=None):        
        self.ckpt.completed_epochs.assign(epoch)
        self.manager.save()
        print( f"Epoch checkpoint {self.ckpt.completed_epochs.numpy()}  saved to: {self.manager.latest_checkpoint}" ) 
        print(logs)

def callbacks(checkpoint_dir, model_name):

    best_model = os.path.join(checkpoint_dir, '{}_best.hdf5'.format(model_name))
    save_best = EpochModelCheckpoint( best_model  )
    return [ save_best ]

def train():

    ...

    model = get_compiled_model()
    checkpoint_dir = "./checkpoint_dir"
    model_name = "my_model"
    # Init checkpoint value
    ckpt = tf.train.Checkpoint(completed_epochs=tf.Variable(0,trainable=False,dtype='int32'))
    manager = tf.train.CheckpointManager(ckpt, f'{checkpoint_dir}/tf_ckpts', max_to_keep=3)    

    best_weights = os.path.join(checkpoint_dir, f'{model_name}_best.hdf5') 
    if os.path.exists(best_weights):
        print(f'Loading model {best_weights}')
        model.load_weights(best_weights)

        # Restore last Epoch
        ckpt.restore(manager.latest_checkpoint)
        if manager.latest_checkpoint:
            print(f"Restored epoch ckpt from {manager.latest_checkpoint}, value is ",ckpt.completed_epochs.numpy())
        else:
            print("Initializing from scratch.")

     ...
    # Set initial_epoch in the model fit to last seen Epoch
    completed_epochs=ckpt.completed_epochs.numpy()
    history = model.fit(
        x=train_ds,
        epochs=cfg.epochs,
        steps_per_epoch=cfg.steps,
        callbacks=callbacks( checkpoint_dir, model_name ),        
        validation_data=val_ds,
        validation_steps=cfg.val_steps,
        initial_epoch=completed_epochs )

关于python - Tensorflow Keras 无法从检查点文件正确恢复初始时期的训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56230510/

相关文章:

python - 如何在Python中创建嵌套的生成器结构?

python - Docker jupyter notebook 使用容器 id 作为 ip,localhost

python - 按已知子字符串拆分字符串

python - 模块 'tensorflow' 没有属性 'log'

python - 训练时调整 Tensorflow nce_loss 中的样本数

javascript - python的 Mechanize 和形式: javascript string returned

tensorflow - 如何使用复值权重进行反向传播

python - ValueError : Error when checking input: expected dense_1_input to have shape (24, )但得到形状为(1,)的数组

python - 使用双向包装器时,如何在 LSTM 层中同时获得最终隐藏状态和序列

python - Keras/Tensorflow - 类型错误 : __init__() got an unexpected keyword argument 'rescale'