python - 如何从 tf.estimator.Estimator 获取最后一个 global_step

标签 python tensorflow tensorflow-estimator

train(...) 完成后,如何从 tf.estimator.Estimator 获取最后一个 global_step ?例如,典型的基于估算器的训练例程可能如下设置: n_epochs = 10 model_dir = '/path/to/model_dir'

def model_fn(features, labels, mode, params):
    # some code to build the model
    pass

def input_fn():
    ds = tf.data.Dataset()  # obviously with specifying a data source
    # manipulate the dataset
    return ds

run_config = tf.estimator.RunConfig(model_dir=model_dir)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

for epoch in range(n_epochs):
    estimator.train(input_fn=input_fn)
    # Now I want to do something which requires to know the last global step, how to get it?
    my_custom_eval_method(global_step)

evaluate()方法返回一个包含 global_step 作为字段的字典。如果由于某种原因我不能或不想使用此方法,如何获取 global_step

最佳答案

只需在训练循环之前创建一个钩子(Hook):

class GlobalStepHook(tf.train.SessionRunHook):
    def __init__(self):
        self._global_step_tensor = None
        self.value = None

    def begin(self):
        self._global_step_tensor = tf.train.get_global_step()

    def after_run(self, run_context, run_values):
        self.value = run_context.session.run(self._global_step_tensor)

    def __str__(self):
        return str(self.value)

global_step = GlobalStepHook()
for epoch in range(n_epochs):
    estimator.train(input_fn=input_fn, hooks=[global_step])
    # Now the global_step hook contains the latest value of global_step
    my_custom_eval_method(global_step.value)

关于python - 如何从 tf.estimator.Estimator 获取最后一个 global_step,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51325660/

相关文章:

python - 从列表中调用一个类

python - Django序列化器不将数据保存到数据库但响应正常

python - 'obj'参数是如何传递给Django Rest Framework中的Permission类函数的?

javascript - TensorflowJS:计算多个张量之间的距离或相似度的最佳方法?

python - Pandas 系列中的特殊字符串格式

tensorflow - 什么是 TensorFlow 中的动态 RNN?

python - 为什么我的示例和标签的顺序错误?

python - 当使用 tf-tutorials 运行时,发生了 :AttributeError: module 'tensorflow.python.estimator.api.estimator' has no attribute 'SessionRunHook'

python - 使用 AdamOptimizer 继续训练自定义 tf.Estimator

python - tensorflow 估计器 : predict without loading from checkpoint everytime