tensorflow - tf.estimator.Estimator.train() 是否保持 input_fn 状态

标签 tensorflow

一年多来我一直在使用自己的 Estimator/Experiment 之类的代码,但我最终想加入 Dataset+Estimator 的行列。

我想做如下的事情:

for _ in range(N):
  estimator.train(train_input_fn, steps=1000)
  estimator.evaluate(validation_input_fn)

其中 train_input_fn 创建一个 tf.data.Dataset 永远循环训练集,而 validation_input_fn 创建一个 tf. data.Dataset 执行一次验证集。

train() 是否在调用期间保持 train_input_fn 的状态(即如果引用匹配则只调用一次)?这是人们使用 Estimator 进行训练循环的方式吗?

最佳答案

正如我在上面的评论中提到的,看起来它不会在调用 estimator.train() 时保存状态。

我正在使用的一个解决方案(可能也是预期的方法)是将评估监听器传递给 estimator.train()。例如,

class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
  def __init__(self, estimator, input_fn):
    self.estimator = estimator
    self.input_fn = input_fn

  def after_save(self, session, global_step):
    self.estimator.evaluate(self.input_fn)

estimator.train(
  input_fn=lambda:_train_input_fn(...),
  max_steps=N,
  saving_listeners=[
    EvalCheckpointSaverListener(
      estimator,
      lambda:_eval_input_fn(...), 
    ),
  ],
)

关于tensorflow - tf.estimator.Estimator.train() 是否保持 input_fn 状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46925196/

相关文章:

python - 使用 RNN 和 Layer 类在 Keras 中实现最小的 LSTMCell

python - 神经网络 - 输入标准化

python - 如何在 Google App Engine 柔性环境中运行 TensorFlow?

python - python tensorflow信号处理MFCC功能

python - 在tensorflow docker-compose中没有贴切?

python - 缺少标签的深度多任务学习

tensorflow - 如何保存在TPU上训练的Keras模型?

python - 保存广度和深度的 tensorflow 模型

python - Keras TimeDistributed 层实际上是做什么的?

python - 修改tensorflow中的预训练模型