我使用以下代码在训练模型的循环之外创建一个检查点管理器:
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(object_1=object_1)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)
然后在训练模型时,我使用 ckpt_save_path = ckpt_manager.save() 来保存每个时期后的变量。
鉴于我想实现一种早期停止方法,我需要在特定时期后恢复所有变量并使用这些变量进行预测。如果我使用上面的代码保存变量(希望保存过程是正确的?),我怎样才能在纪元 e 之后恢复变量呢?我知道我可以首先创建相同的变量和对象,然后使用以下代码恢复最新的检查点,但不知道如何恢复特定的检查点(如纪元号 e 之后的变量)而不是最新的。
ckpt.restore(ckpt_manager.latest_checkpoint).assert_consumed()
谢谢
最佳答案
是的,您需要生成带有纪元号的文件名文本字符串。
c_manager = tf.train.CheckpointManager(checkpoint, ...)
if EPOCH == '':
if c_manager.latest_checkpoint:
tf.print("-----------Restoring from {}-----------".format(
c_manager.latest_checkpoint))
checkpoint.restore(c_manager.latest_checkpoint)
EPOCH = c_manager.latest_checkpoint.split(sep='ckpt-')[-1]
else:
tf.print("-----------Initializing from scratch-----------")
else:
checkpoint_fname = CHECKPOINT_SAVE_DIR + 'ckpt-' + str(EPOCH)
tf.print("-----------Restoring from {}-----------".format(checkpoint_fname))
checkpoint.restore(checkpoint_fname)
关于python - 如何恢复tensorflow2中的特定检查点(以实现提前停止)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62919208/