python - 使用 tf.train.MonitoredTrainingSession 时如何获取全局步骤

标签 python tensorflow machine-learning save

当我们在Saver.save中指定global_step时,它会将global_step存储为checkpoint后缀。

# save the checkpoint
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)

我们可以像这样恢复检查点并获取存储在检查点中的最后一个全局步骤:

# restore the checkpoint and obtain the global step
saver.restore(session, ckpt.model_checkpoint_path)
...
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)

如果我们使用tf.train.MonitoredTrainingSession,将全局步骤保存到检查点并获取gstep的等效方法是什么?

编辑1

按照Maxim的建议,我在tf.train.MonitoredTrainingSession之前创建了global_step变量,并添加了一个CheckpointSaverHook,如下所示:

global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoints_abs_path,
                                                    save_steps=5,
                                                    checkpoint_basename=(checkpoints_prefix + ".ckpt"))

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=is_chief,                     
                                       hooks=[sync_replicas_hook, save_checkpoint_hook],
                                       config=config) as session:

    _, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
    print("current global step=" + str(gstep))

我可以看到它生成的检查点文件类似于 Saver.saver 所做的。但是,它无法从检查点检索全局步骤。请告知我该如何解决这个问题?

最佳答案

可以通过tf.train.get_global_step()获取当前全局步长或通过 tf.train.get_or_create_global_step()功能。后者应在训练开始前调用。

对于受监控的 session ,添加 tf.train.CheckpointSaverHookhooks,它在内部使用定义的全局步张量在每 N 步之后保存模型。

关于python - 使用 tf.train.MonitoredTrainingSession 时如何获取全局步骤,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48338492/

相关文章:

python - matplotlib 弹出窗口中出现错误(AttributeError : 'NoneType' object has no attribute 'set_canvas' )

python - 如何将curl命令行转换为pycurl代码

python - 使用 App Engine Google Cloud 部署 TensorFlow 应用程序时出错

python - Keras 总是预测相同的输出

python - 从 L1 正则化逻辑回归中恢复命名特征

python - Django 在 url 中传递一个字符串

python - Python 中的 scipy kmeans 和 kmeans2 聚类问题

tensorflow - 从 tensorflow.js 神经网络获取权重

python - 如何在 tensorflow 中找到张量和/或运算的输出? (或 "tensorflow op.outputs only pointing to itself")

python - 如何在机器学习中对数值和分类特征使用统一管道?