python - `MonitoredTrainingSession()`如何与 "restore"和 "testing mode"一起使用?

标签 python session tensorflow distributed restore

在Tensorflow中,我们可以使用Between-graph Replication构建和创建多个Tensorflow session 以进行分布式培训。 MonitoredTrainingSession()协调多个Tensorflow session ,并且checkpoint_dir有一个参数MonitoredTrainingSession()用于恢复Tensorflow session /图形。现在我有以下问题:

  • 我们通常使用tf.train.Saver()对象通过saver.restore(...)恢复Tensorflow图。但是,如何使用MonitoredTrainingSession()还原它们呢?
  • 因为我们运行多个流程,并且每个流程都会构建并创建一个Tensorflow session 进行培训,所以我想知道在培训之后我们是否还必须运行多个流程进行测试(或预测)。换句话说,MonitoredTrainingSession()如何在测试(或预测)模式下工作?

  • 我阅读了Tensorflow Doc,但没有找到这2个问题的答案。如果有人有解决方案,我将不胜感激。谢谢!

    最佳答案

    简短答案:

  • 您需要将全局步骤传递给传递给mon_sess.run的优化器。这样就可以保存和检索保存的检查点。
  • 可以通过一个MonitoredTrainingSession同时运行培训和交叉验证 session 。首先,您需要通过训练图和交叉验证图通过同一张图的不同流(我建议您查询this guide以获取有关如何执行此操作的信息)。其次,您必须-对mon_sess.run()-传递针对训练流的优化器,以及传递交叉验证流的损失(/您要跟踪的参数)的参数。如果要与培训分开运行测试 session ,只需在图中仅运行测试集,并在图中仅运行test_loss(/要跟踪的其他参数)即可。有关如何完成此操作的更多详细信息,请参见下面。

  • 长答案:

    我将更新答案,因为我自己将更好地了解tf.train.MonitoredSession(tf.train.MonitoredTrainingSession可以做什么,只是创建了tf.train.MonitoredSession的专用版本,如source code所示。 )。

    以下是示例代码,显示了如何每5秒将检查点保存到'./ckpt_dir'。中断后,它将在最后保存的检查点处重新启动:
    def train(inputs, labels_onehot, global_step):
        out = tf.contrib.layers.fully_connected(
                                inputs,
                                num_outputs=10,
                                activation_fn=tf.nn.sigmoid)
        loss = tf.reduce_mean(
                 tf.reduce_sum(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                                logits=out,
                                labels=labels_onehot), axis=1))
        train_op = opt.minimize(loss, global_step=global_step)
        return train_op
    
    with tf.Graph().as_default():
        global_step = tf.train.get_or_create_global_step()
        inputs = ...
        labels_onehot = ...
        train_op = train(inputs, labels_onehot, global_step)
    
        with tf.train.MonitoredTrainingSession(
            checkpoint_dir='./ckpt_dir',
            save_checkpoint_secs=5,
            hooks=[ ... ] # Choose your hooks
        ) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
    

    为了实现此目的,在MonitoredTrainingSession中发生的事情实际上是三件事:
  • tf.train.MonitoredTrainingSession创建一个tf.train.Scaffold对象,该对象的作用类似于网络中的蜘蛛。它收集了您需要训练,保存和加载模型的零件。
  • 它创建一个tf.train.ChiefSessionCreator对象。我对此的知识是有限的,但据我了解,它用于将tf算法分布在多个服务器上的情况。我的理解是,它告诉运行该文件的计算机是主计算机,并且应该在此处保存检查点目录,并且记录程序应该在此处记录其数据,等等。
  • 它创建一个tf.train.CheckpointSaverHook,用于保存检查点。

  • 为了使其工作,必须将tf.train.CheckpointSaverHook和tf.train.ChiefSessionCreator传递给检查点目录和脚手架相同的引用。如果上面示例中的tf.train.MonitoredTrainingSession及其参数要通过上面的3个组件来实现,则它看起来像这样:
    checkpoint_dir = './ckpt_dir'
    
    scaffold = tf.train.Scaffold()
    saverhook = tf.train.CheckpointSaverHook(
        checkpoint_dir=checkpoint_dir,
        save_secs=5
        scaffold=scaffold
    )
    session_creator = tf.train.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=checkpoint_dir
    )
    
    with tf.train.MonitoredSession(
        session_creator=session_creator,
        hooks=[saverhook]) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
    

    为了进行训练+交叉验证 session ,您可以将tf.train.MonitoredSession.run_step_fn()与partial一起使用,这样可以在不调用任何钩子(Hook)的情况下运行 session 调用。它的外观是先训练模型n次迭代,然后运行测试集,重新初始化迭代器,然后再训练模型,等等。当然,您必须将变量设置为复用= tf.AUTO_REUSE在执行此操作时。在代码中执行此操作的方法如下所示:
    from functools import partial
    
    # Build model
    ...
    
    with tf.variable_scope(..., reuse=tf.AUTO_REUSE):
        ...
    
    ...
    
    def step_fn(fetches, feed_dict, step_context):
        return step_context.session.run(fetches=fetches, feed_dict=feed_dict)
    
    with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=...,
                    save_checkpoint_steps=...,
                    hooks=[...],
                    ...
                    ) as mon_sess:
    
                    # Initialize iterators (assuming tf.Databases are used)
                    mon_sess.run_step_fn(
                               partial(
                                   step_fn, 
                                   [train_it.initializer, 
                                    test_it.initializer, 
                                    ...
                                   ], 
                                   {}
                                )
                    )
    
                    while not mon_sess.should_stop():
                        # Train session
                        for i in range(n):
                            try:
                                train_results = mon_sess.run(<train_fetches>)
                            except Exception as e:
                                break
    
                        # Test session
                        while True:
                            try:
                                test_results = mon_sess.run(<test_fetches>)
                            except Exception as e:
                                break
    
                        # Reinitialize parameters
                        mon_sess.run_step_fn(
                                   partial(
                                      step_fn, 
                                      [train_it.initializer, 
                                       test_it.initializer, 
                                       ...
                                      ], 
                                      {}
                                   )
                        )
    

    局部函数只是对mon_sess.run_step_fn()中使用的step_fn执行currying(函数编程中的经典函数)。上面的整个代码尚未经过测试,您可能必须在开始测试 session 之前重新初始化train_it,但希望现在可以清楚地知道如何在同一运行中同时运行训练集和验证集。如果要在同一图中绘制训练曲线和测试曲线,则可以将其与张量板的custom_scalar tool一起使用。

    最后,这是我已经能够实现的最佳功能,我个人希望tensorflow将来使此功能的实现更加容易,因为它非常繁琐且可能没有那么高效。我知道有诸如Estimator之类的工具可以运行train_and_evaluate函数,但是由于这会在每次训练和交叉验证运行之间重建图表,因此如果仅在一台计算机上运行,​​效率会非常低。我在某处读到Keras + tf具有此功能,但是由于我不使用Keras + tf,因此这不是一个选择。无论如何,我希望这可以帮助其他苦苦挣扎的人!

    关于python - `MonitoredTrainingSession()`如何与 "restore"和 "testing mode"一起使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43104992/

    相关文章:

    python - 声音不会一直播放! - Pygame,Python

    python - 通过 SQLAlchemy 使用 Postgresql 执行多个语句不会保留更改

    python - 在 docker 和 Jenkins 中处理大型二进制文件 (3 GB)

    python - 在 Matplotlib 中加载文本文件时出错

    java - 如何使用NiFi进程 session 迁移功能?

    tensorflow - Inception-v3 使用 RMSProp epsilon=1

    php - 比较数据库中的日期是否大于当前日期

    session - 如何在另一个类中重用在 AsyncTask 中创建的 SSH (Jsch) session

    python - 当你添加一个带有 numpy 数组的张量时会发生什么?

    python - 在每次迭代时为动态大小的张量创建占位符