python - 使用队列进行训练与测试

标签 python machine-learning tensorflow

我正在使用描述的设置 here批量加载一些训练图像,基本上是这样的:

def read_my_file_format(filename_queue):
  # ... use a reader + a decoder

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(...)
  example, label = read_my_file_format(filename_queue)
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, ...)
  return example_batch, label_batch

def build_net():
    batch, label = input_pipeline(...)
    y = encoder(batch)  # <- build network using the batch

def train():
  with tf.Session() as sess:
    # ... init vars

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
      while not coord.should_stop():
        # ... training step

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()

这对训练很有好处 - 但是,我不知道如何测试生成的网络!让我困惑的是:

  • input_pipeline 返回的张量是网络的一部分。为了测试,我必须更换它吗?
  • 我想我可以创建另一个 input_pipeline 进行测试,即使用不同的文件名队列。然后我可以使用 tf.cond 在不同的输入批处理之间切换,但是接下来:如何确保一次只有一个队列被耗尽。我不知道如何访问不同的队列以及如何指定它们的卸载方式。
<小时/>

基本上,这个问题可以归结为:测试使用 tf.train.shuffle_batch 方法构建的网络的规范方法是什么。

最佳答案

您为数据集评估创建额外输入管道的想法绝对是正确的。使用multiple input pipelines是推荐的方法之一,它由两个过程组成——一个过程是培训,另一个过程是评估。在训练过程中将使用检查点,然后每千步,代码可以尝试 eval针对训练和测试数据集的模型。

引自文档:

  • The training process reads training input data and periodically writes checkpoint files with all the trained variables.
  • The evaluation process restores the checkpoint files into an inference model that reads validation input data.

即使在训练完成/退出后也可以进行评估。 ( see this example )

另一个考虑因素是 sharing variables train 和 eval 可以在同一个进程的同一个图中运行,同时共享它们训练过的变量!

关于您所关心的队列耗尽问题,如果您使用 tf.train.shuffle_batch* 设置 num_threads 大于 1,它会同时从单个文件读取(+ 比使用 1 个线程更快) ,而不是一次 N 个文件(请参阅 batching 部分)。

关于python - 使用队列进行训练与测试,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/40802457/

相关文章:

tensorflow - 如何恢复只有一个扩展名为 ".model"的文件的 tensorflow 模型

python - 仅从 Dash Python 中的表中提取过滤后的数据

python - 如何在pyqt4中将qtableview行对齐到右侧?

machine-learning - 人工神经网络相对于支持向量机有哪些优势?

python - 使用 sklearn 和 Spark 时的轮廓分数不同

python - 在某些情况下,tf.matmul 是否等同于 Dense 层在 tensorflow 中进行的操作?

python - 在 Python 中使用文件属性时在控制台上显示打印

python - 使用 XLRD 包识别 Excel 工作表单元格颜色代码

machine-learning - 无监督学习中的训练/测试分割是否必要/有用?

tensorflow - 如何停止 tensorflow 中张量某些条目的梯度