一年多来我一直在使用自己的 Estimator/Experiment 之类的代码,但我最终想加入 Dataset+Estimator 的行列。
我想做如下的事情:
for _ in range(N):
estimator.train(train_input_fn, steps=1000)
estimator.evaluate(validation_input_fn)
其中 train_input_fn
创建一个 tf.data.Dataset
永远循环训练集,而 validation_input_fn
创建一个 tf. data.Dataset
执行一次验证集。
train()
是否在调用期间保持 train_input_fn
的状态(即如果引用匹配则只调用一次)?这是人们使用 Estimator 进行训练循环的方式吗?
最佳答案
正如我在上面的评论中提到的,看起来它不会在调用 estimator.train()
时保存状态。
我正在使用的一个解决方案(可能也是预期的方法)是将评估监听器传递给 estimator.train()
。例如,
class EvalCheckpointSaverListener(tf.train.CheckpointSaverListener):
def __init__(self, estimator, input_fn):
self.estimator = estimator
self.input_fn = input_fn
def after_save(self, session, global_step):
self.estimator.evaluate(self.input_fn)
estimator.train(
input_fn=lambda:_train_input_fn(...),
max_steps=N,
saving_listeners=[
EvalCheckpointSaverListener(
estimator,
lambda:_eval_input_fn(...),
),
],
)
关于tensorflow - tf.estimator.Estimator.train() 是否保持 input_fn 状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46925196/