python-2.7 - 在 tensorflow 中保存检查点并恢复训练

标签 python-2.7 tensorflow

我正在玩保存检查点和从保存的检查点恢复训练。我正在按照 - https://www.tensorflow.org/versions/r0.8/api_docs/python/train.html#import_meta_graph 中给出的示例进行操作 为了简单起见,我没有使用任何“真实”网络训练。我只是执行了一个简单的减法操作,每个检查点一次又一次地在相同的张量上保存相同的操作。 以下 ipython 笔记本的形式提供了一个最小示例 - https://gist.github.com/dasabir/29b8f84c6e5e817a72ce06584e988f10

在第一阶段,我运行循环 100 次(通过在代码中设置变量“endIter = 100”的值)并每 10 次迭代保存一次检查点。因此,保存的检查点编号为 - 9、19、...、99。现在,当我将“enditer”值更改为 200 并恢复训练时,检查点再次从 9、19 开始保存, ...(不是 109, 119, 129, ...)。有什么技巧是我想念的吗?

最佳答案

你能打印出'latest_ckpt',看看它是否指向最新的ckpt文件吗?此外,您需要使用 tf.variable 维护 global_step:

global_step = tf.Variable(0, name='global_step', trainable=False)
...
ckpt = tf.train.get_checkpoint_state(ckpt_dir)
if ckpt and ckpt.model_checkpoint_path:
    print ckpt.model_checkpoint_path
    saver.restore(sess, ckpt.model_checkpoint_path) # restore all variables
start = global_step.eval() # get last global_step
print "Start from:", start

for i in range(start, 100):
...
    global_step.assign(i).eval() # set and update(eval) global_step with index, i
    saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)

你可以看看完整的例子:

https://github.com/nlintz/TensorFlow-Tutorials/pull/32/files

关于python-2.7 - 在 tensorflow 中保存检查点并恢复训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36731093/

相关文章:

python - HTTP错误 : HTTP Error 403: Forbidden

python - 在 App Engine 上轻松地从 Python 2.5 迁移到 2.7

tensorflow - 如何在jupyter笔记本中多次拟合/运行神经网络?

google-app-engine - 谷歌云平台和谷歌机器学习

使用transform_graph量化后Tensorflow SSD-Mobilenet模型精度下降

python - Ubuntu服务器上的scrapy

python - 如何在 Digital Ocean 上安装 python 2.7 mysql

javascript - 如何使用 tensorflowjs_converter 转换 TensorFlow 图形 .pb 文件?

python - 如何在 tf.keras 训练过程中获取当前 epoch 的进度?

python - Scikit Learn RandomForest 内存错误