在Tensorflow中,我们可以使用Between-graph Replication
构建和创建多个Tensorflow session 以进行分布式培训。 MonitoredTrainingSession()
协调多个Tensorflow session ,并且checkpoint_dir
有一个参数MonitoredTrainingSession()
用于恢复Tensorflow session /图形。现在我有以下问题:
tf.train.Saver()
对象通过saver.restore(...)
恢复Tensorflow图。但是,如何使用MonitoredTrainingSession()
还原它们呢? MonitoredTrainingSession()
如何在测试(或预测)模式下工作? 我阅读了Tensorflow Doc,但没有找到这2个问题的答案。如果有人有解决方案,我将不胜感激。谢谢!
最佳答案
简短答案:
长答案:
我将更新答案,因为我自己将更好地了解tf.train.MonitoredSession(tf.train.MonitoredTrainingSession可以做什么,只是创建了tf.train.MonitoredSession的专用版本,如source code所示。 )。
以下是示例代码,显示了如何每5秒将检查点保存到'./ckpt_dir'。中断后,它将在最后保存的检查点处重新启动:
def train(inputs, labels_onehot, global_step):
out = tf.contrib.layers.fully_connected(
inputs,
num_outputs=10,
activation_fn=tf.nn.sigmoid)
loss = tf.reduce_mean(
tf.reduce_sum(
tf.nn.sigmoid_cross_entropy_with_logits(
logits=out,
labels=labels_onehot), axis=1))
train_op = opt.minimize(loss, global_step=global_step)
return train_op
with tf.Graph().as_default():
global_step = tf.train.get_or_create_global_step()
inputs = ...
labels_onehot = ...
train_op = train(inputs, labels_onehot, global_step)
with tf.train.MonitoredTrainingSession(
checkpoint_dir='./ckpt_dir',
save_checkpoint_secs=5,
hooks=[ ... ] # Choose your hooks
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
为了实现此目的,在MonitoredTrainingSession中发生的事情实际上是三件事:
为了使其工作,必须将tf.train.CheckpointSaverHook和tf.train.ChiefSessionCreator传递给检查点目录和脚手架相同的引用。如果上面示例中的tf.train.MonitoredTrainingSession及其参数要通过上面的3个组件来实现,则它看起来像这样:
checkpoint_dir = './ckpt_dir'
scaffold = tf.train.Scaffold()
saverhook = tf.train.CheckpointSaverHook(
checkpoint_dir=checkpoint_dir,
save_secs=5
scaffold=scaffold
)
session_creator = tf.train.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=checkpoint_dir
)
with tf.train.MonitoredSession(
session_creator=session_creator,
hooks=[saverhook]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
为了进行训练+交叉验证 session ,您可以将tf.train.MonitoredSession.run_step_fn()与partial一起使用,这样可以在不调用任何钩子(Hook)的情况下运行 session 调用。它的外观是先训练模型n次迭代,然后运行测试集,重新初始化迭代器,然后再训练模型,等等。当然,您必须将变量设置为复用= tf.AUTO_REUSE在执行此操作时。在代码中执行此操作的方法如下所示:
from functools import partial
# Build model
...
with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
...
...
def step_fn(fetches, feed_dict, step_context):
return step_context.session.run(fetches=fetches, feed_dict=feed_dict)
with tf.train.MonitoredTrainingSession(
checkpoint_dir=...,
save_checkpoint_steps=...,
hooks=[...],
...
) as mon_sess:
# Initialize iterators (assuming tf.Databases are used)
mon_sess.run_step_fn(
partial(
step_fn,
[train_it.initializer,
test_it.initializer,
...
],
{}
)
)
while not mon_sess.should_stop():
# Train session
for i in range(n):
try:
train_results = mon_sess.run(<train_fetches>)
except Exception as e:
break
# Test session
while True:
try:
test_results = mon_sess.run(<test_fetches>)
except Exception as e:
break
# Reinitialize parameters
mon_sess.run_step_fn(
partial(
step_fn,
[train_it.initializer,
test_it.initializer,
...
],
{}
)
)
局部函数只是对mon_sess.run_step_fn()中使用的step_fn执行currying(函数编程中的经典函数)。上面的整个代码尚未经过测试,您可能必须在开始测试 session 之前重新初始化train_it,但希望现在可以清楚地知道如何在同一运行中同时运行训练集和验证集。如果要在同一图中绘制训练曲线和测试曲线,则可以将其与张量板的custom_scalar tool一起使用。
最后,这是我已经能够实现的最佳功能,我个人希望tensorflow将来使此功能的实现更加容易,因为它非常繁琐且可能没有那么高效。我知道有诸如Estimator之类的工具可以运行train_and_evaluate函数,但是由于这会在每次训练和交叉验证运行之间重建图表,因此如果仅在一台计算机上运行,效率会非常低。我在某处读到Keras + tf具有此功能,但是由于我不使用Keras + tf,因此这不是一个选择。无论如何,我希望这可以帮助其他苦苦挣扎的人!
关于python - `MonitoredTrainingSession()`如何与 "restore"和 "testing mode"一起使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43104992/