tensorflow - "rewind" tensorflow 训练步骤

标签 tensorflow

我偶尔会在 tensorflow 和随机梯度下降训练中遇到问题,我加载了一个小批量,这对我的优化操作造成了严重破坏,将其推向了 Nans。当然,这会在训练过程中引发错误并迫使我重新开始。即使我将优化操作包装在 try 语句中,当引发异常时,损坏已经完成,我需要重新启动。

是否有人有一种好的方法,可以在遇到错误时将优化回退到有效状态?我认为你可以为此使用检查点,但是有关保存/恢复的文档非常参差不齐,我不确定......

最佳答案

正如您所建议的,检查点是实现这一目标的方法。您的案例的关键步骤如下:

定义图表后,首先创建一个保护程序对象:

saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=1)

接下来,在训练期间间歇性地写出检查点:

for step in range(max_steps):

    ... some training steps here

    # Save the model every 100 iterations
    if step % 100 == 0:
        saver.save(sess, checkpoint_dir, global_step=step)

最后,当您捕获错误时,重新加载最后一个好的检查点:

# this next command restores the latest checkpoint or explicitly specify the filename if you want to use some other logic
restore_fn = tf.train.latest_checkpoint(FLAGS.restore_dir)
print('Restoring from %s' % restore_fn)
saver.restore(sess, restore_fn)

关于tensorflow - "rewind" tensorflow 训练步骤,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40289739/

相关文章:

python - 交互式 IPython 终端中连续运行的 Tensorflow Op 命名

c++ - 自写的 tensorflow c++ 代码的 Bazel 测试通过了所有测试,无论预期值如何

python - Adam在keras中是如何实现学习率衰减的

tensorflow - Ubuntu 16.04 上 TensorFlow 的 NVIDIA cuDNN 版本类型

tensorflow - tensorflow 中优化器的 `apply_gradients`和 `minimize`之间的区别

TensorFlow:Dst 张量未初始化

python - 在 tensorflow 中导入图形时使用新操作

python - Keras 符号输入/输出未实现 `__len__` 错误

c++ - Tensorflow C++ YOU_MADE_A_PROGRAMMING_MISTAKE

python - 如何在 Estimator 训练期间动态加载数据集的新部分?