python - 如何恢复tensorflow2中的特定检查点(以实现提前停止)?

标签 python tensorflow tensorflow2.0 checkpointing

我使用以下代码在训练模型的循环之外创建一个检查点管理器:

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(object_1=object_1)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)

然后在训练模型时,我使用 ckpt_save_path = ckpt_manager.save() 来保存每个时期后的变量。

鉴于我想实现一种早期停止方法,我需要在特定时期后恢复所有变量并使用这些变量进行预测。如果我使用上面的代码保存变量(希望保存过程是正确的?),我怎样才能在纪元 e 之后恢复变量呢?我知道我可以首先创建相同的变量和对象,然后使用以下代码恢复最新的检查点,但不知道如何恢复特定的检查点(如纪元号 e 之后的变量)而不是最新的。

ckpt.restore(ckpt_manager.latest_checkpoint).assert_consumed()

谢谢

最佳答案

是的,您需要生成带有纪元号的文件名文本字符串。

c_manager = tf.train.CheckpointManager(checkpoint, ...)

if EPOCH == '':
    if c_manager.latest_checkpoint:
        tf.print("-----------Restoring from {}-----------".format(
            c_manager.latest_checkpoint))
        checkpoint.restore(c_manager.latest_checkpoint)
        EPOCH = c_manager.latest_checkpoint.split(sep='ckpt-')[-1]
    else:
        tf.print("-----------Initializing from scratch-----------")
else:    
    checkpoint_fname = CHECKPOINT_SAVE_DIR + 'ckpt-' + str(EPOCH)
    tf.print("-----------Restoring from {}-----------".format(checkpoint_fname))
    checkpoint.restore(checkpoint_fname)

关于python - 如何恢复tensorflow2中的特定检查点(以实现提前停止)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62919208/

相关文章:

python - 如何在python3中对字典元素进行分组

python - TFRecordReader 似乎非常慢,多线程读取不工作

tensorflow - 如何评估 tensorflow 估计器每个时期的测试数据集

python-3.x - 断言失败 : predictions must be >= 0, 条件 x >= y 不保持元素

python - 只能在安装目录中导入 python 模块

python - python pandas 数据框中的定数值积分

python - 使用多个键从字典中获取元素

docker - 我需要在没有官方tensorflow镜像的Docker上使用tensorflow吗?

python - TensorFlow - 返回多维张量的不同子张量

pycharm - PyCharm中使用tensorflow2.0时出现错误 'cannot find reference'