我偶尔会在 tensorflow 和随机梯度下降训练中遇到问题,我加载了一个小批量,这对我的优化操作造成了严重破坏,将其推向了 Nans。当然,这会在训练过程中引发错误并迫使我重新开始。即使我将优化操作包装在 try 语句中,当引发异常时,损坏已经完成,我需要重新启动。
是否有人有一种好的方法,可以在遇到错误时将优化回退到有效状态?我认为你可以为此使用检查点,但是有关保存/恢复的文档非常参差不齐,我不确定......
最佳答案
正如您所建议的,检查点是实现这一目标的方法。您的案例的关键步骤如下:
定义图表后,首先创建一个保护程序对象:
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)
接下来,在训练期间间歇性地写出检查点:
for step in range(max_steps):
... some training steps here
# Save the model every 100 iterations
if step % 100 == 0:
saver.save(sess, checkpoint_dir, global_step=step)
最后,当您捕获错误时,重新加载最后一个好的检查点:
# this next command restores the latest checkpoint or explicitly specify the filename if you want to use some other logic
restore_fn = tf.train.latest_checkpoint(FLAGS.restore_dir)
print('Restoring from %s' % restore_fn)
saver.restore(sess, restore_fn)
关于tensorflow - "rewind" tensorflow 训练步骤,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40289739/