python - 谁使用 tf.estimator.train_and_evaluate 提前停止评估损失?

标签 python tensorflow deep-learning tensorflow-estimator

我正在使用 Tensorflow 估计器并明确使用方法 tf.estimator.train_and_evaluate() .
训练有一个早期停止钩子(Hook)是 tf.contrib.estimator.stop_if_no_decrease_hook ,但我确实有一个问题,即训练损失过于紧张,无法用于提前停止。
有谁知道如何使用 tf.estimator 提前停止基于评估损失?

最佳答案

您可以使用 tf.contrib.estimator.stop_if_no_decrease_hook 如下所示:

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))

但如果它对你不起作用,最好使用 tf.estimator.experimental.stop_if_no_decrease_hook 反而。

例如:
estimator = ...
# Hook to stop training if loss does not decrease in over 100000 steps.
hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 100000)
train_spec = tf.estimator.TrainSpec(..., hooks=[hook])
tf.estimator.train_and_evaluate(estimator, train_spec, ...)

早期停止钩子(Hook)使用评估结果来决定何时停止训练,但您需要传入要监控的训练步骤数,并记住在该数量的步骤中将发生多少次评估。
如果您将钩子(Hook)设置为 hook = early_stopping.stop_if_no_decrease_hook(estimator, "loss", 10000) hook 将考虑在 10k 步范围内发生的评估。

在此处阅读有关文档的更多信息:https://www.tensorflow.org/api_docs/python/tf/estimator/experimental/stop_if_no_decrease_hook以及所有你可以使用的提前停止功能,你可以引用这个https://github.com/tensorflow/estimator/blob/master/tensorflow_estimator/python/estimator/early_stopping.py

关于python - 谁使用 tf.estimator.train_and_evaluate 提前停止评估损失?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60678769/

相关文章:

image-processing - 我想知道MS COCO数据集中有没有服装对象类?

python - 如何重复输入层直到超过一定数量的神经元?

python - 将双端队列保存在文本文件中

python - 有没有办法可以用 bash 而不是 cmd 进行编译?

python - numpy.unique 生成的列表在哪些方面是唯一的?

javascript - 在 tensorflow.js 中加载保存模型后使用自定义模型预测错误

python - 在 ubuntu 14.04 中运行 keras 进行深度学习时出错

python - 创建全卷积网络

python - 自然场景数字识别的深度学习解决方案

python - 集合比较是自反的,但不会短路。为什么?