validation - 使用估算器时,将验证监视器替换为 tf.train.SessionRunHook

标签 validation tensorflow

我正在运行 DNNClassifier,我在训练时监控其准确性。来自 contrib/learn 的 drivers.ValidationMonitor 一直工作得很好,在我的实现中我定义了它:

validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)

然后使用以下调用:

clf.fit(input_fn=lambda: input_fn(A, Cl2),
            steps=1000, monitors=[validation_monitor])

哪里:

clf = tensorflow.contrib.learn.DNNClassifier(...

这很好用。也就是说,验证监视器似乎已被弃用,并且类似的功能将被 tf.train.SessionRunHook 取代。

我是 TensorFlow 的新手,对我来说,这样的替换实现看起来并不简单。任何建议都将受到高度赞赏。同样,我需要在特定数量的步骤后验证训练。 预先非常感谢。

最佳答案

有一个未公开的实用程序,名为 monitors.replace_monitors_with_hooks()它将监视器转换为钩子(Hook)。该方法接受 (i) 可能包含监视器和钩子(Hook)的列表,以及 (ii) 将使用钩子(Hook)的 Estimator,然后通过在每个监视器周围包装 SessionRunHook 来返回钩子(Hook)列表。

from tensorflow.contrib.learn.python.learn import monitors as monitor_lib

clf = tf.estimator.Estimator(...)

list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)

这并不是完全替换 ValidationMonitor 问题的真正解决方案 - 我们只是用一个未弃用的函数将其包装起来。不过,我可以说到目前为止这对我来说很有效,因为它保留了我需要的 ValidationMonitor 的所有功能(即评估每个 n 个步骤、提前停止使用指标等)

还有一件事 - 要使用此 Hook ,您需要从 tf.contrib.learn.Estimator (仅接受监视器)更新到更成熟和官方的 tf.estimator.Estimator (只接受钩子(Hook))。因此,您应该将分类器实例化为 tf.estimator.DNNClassifier,并使用其方法 train() 进行训练(这只是 的重命名)适合()):

clf = tf.estimator.Estimator(...)

...

clf.train(
    input_fn=...
    ...
    hooks=hooks)

关于validation - 使用估算器时,将验证监视器替换为 tf.train.SessionRunHook,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44793703/

相关文章:

ios - swift 3 : best way to validate the text entered by the user in a UITextField

tensorflow - 什么是 tf.bfloat16 "truncated 16-bit floating point"?

javascript - 使用 Tensorflow.js 和 tf.Tensor 处理大数据的最佳方式是什么?

PHP/MySQL - 验证用户名(附加字符除外)

javascript - 注册或编辑前验证用户密码

java - 输入验证检查输入是否为 double

javascript - AJAX/Javascript 复选框验证

Tensorflow 摘要标量未显示在 tensorboard 中

python - 在 TensorFlow 中分配 op : what is the return value?

tensorflow - 如何在 Fortran 中使用神经网络(基于 Pytorch 或 Tensorflow)?