python - 如何从示例队列将数据读入 TensorFlow 批处理?

标签 python numpy classification tensorflow

如何将 TensorFlow 示例队列分成合适的批处理进行训练?

我有一些图片和标签:

IMG_6642.JPG 1
IMG_6643.JPG 2

(请随意建议另一种标签格式;我想我可能需要另一个密集到稀疏的步骤...)

我已经阅读了很多教程,但还没有完全掌握。 这就是我所拥有的,其中的注释指出了 TensorFlow 的 Reading Data 所需的步骤。页面。

  1. 文件名列表 (为简单起见,删除了可选步骤)
  2. 文件名队列
  3. 文件格式的阅读器
  4. 读取器读取记录的解码器
  5. 示例队列

在示例队列之后,我需要将该队列分批进行训练;这就是我被困的地方......

1.文件名列表

files = tf.train.match_filenames_once('*.JPG')

4.文件名队列

filename_queue = tf.train.string_input_producer(files, num_epochs=None, shuffle=True, seed=None, shared_name=None, name=None)

5.一个读者

reader = tf.TextLineReader() key, value = reader.read(filename_queue)

6.解码器

record_defaults = [[""], [1]] col1,col2 = tf.decode_csv(值,record_defaults=record_defaults) (我认为我不需要下面这一步,因为我已经将标签放在张量中,但无论如何我都包含它)

features = tf.pack([col2])

文档页面有一个示例运行一张图片,而不是批量获取图片和标签:

对于我在范围内(1200): # 获取单个实例: 例如,label = sess.run([features, col5])

然后它下面有一个批处理部分:

def read_my_file_format(filename_queue):
  reader = tf.SomeReader()
  key, record_string = reader.read(filename_queue)
  example, label = tf.some_decoder(record_string)
  processed_example = some_processing(example)
  return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
  filenames, num_epochs=num_epochs, shuffle=True)
  example, label = read_my_file_format(filename_queue)
  # min_after_dequeue defines how big a buffer we will randomly sample
  #   from -- bigger means better shuffling but slower start up and more
  #   memory used.
  # capacity must be larger than min_after_dequeue and the amount larger
  #   determines the maximum we will prefetch.  Recommendation:
  #   min_after_dequeue + (num_threads + a small safety margin) *              batch_size
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
  [example, label], batch_size=batch_size, capacity=capacity,
  min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

我的问题是:如何将上面的示例代码与上面的代码一起使用?我需要 批处理 才能使用,并且大多数教程都附带mnist 批处理。

with tf.Session() as sess:
  sess.run(init)

  # Training cycle
for epoch in range(training_epochs):
    total_batch = int(mnist.train.num_examples/batch_size)
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)

最佳答案

如果您想让这个输入管道工作,您需要添加一个异步队列机制来生成批量示例。这是通过创建 tf.RandomShuffleQueuetf.FIFOQueue 并插入已读取、解码和预处理的 JPEG 图像来执行的。

您可以使用方便的构造来生成队列和通过 tf.train.shuffle_batch_jointf.train.batch_join 运行队列的相应线程。这是一个简化的例子。请注意,此代码未经测试:

# Let's assume there is a Queue that maintains a list of all filenames
# called 'filename_queue'
_, file_buffer = reader.read(filename_queue)

# Decode the JPEG images
images = []
image = decode_jpeg(file_buffer)

# Generate batches of images of this size.
batch_size = 32

# Depends on the number of files and the training speed.
min_queue_examples = batch_size * 100
images_batch = tf.train.shuffle_batch_join(
  image,
  batch_size=batch_size,
  capacity=min_queue_examples + 3 * batch_size,
  min_after_dequeue=min_queue_examples)

# Run your network on this batch of images.
predictions = my_inference(images_batch)

根据您需要如何扩展您的工作,您可能需要运行多个独立线程来读取/解码/预处理图像并将它们转储到您的示例队列中。 Inception/ImageNet 模型中提供了此类管道的完整示例。看看 batch_inputs:

https://github.com/tensorflow/models/blob/master/inception/inception/image_processing.py#L407

最后,如果您使用 >O(1000) JPEG 图像,请记住,单独准备 1000 个小文件效率极低。这会大大减慢您的训练速度。

将图像数据集转换为 Example 原型(prototype)的分片 TFRecord 的更强大、更快速的解决方案。这是一个完整的 script用于将 ImageNet 数据集转换为这种格式。这是一组instructions用于在包含 JPEG 图像的任意目录上运行此预处理脚本的通用版本。

关于python - 如何从示例队列将数据读入 TensorFlow 批处理?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37126108/

相关文章:

python : fit a curve to a list of integers

java - 查找哈希集中每个单词在文本文档中出现的次数

python - 如何从具有 2 个输出神经元的 softmax 二元分类器绘制 ROC 曲线?

python - 如果类似内容尚不存在,则插入记录时出现问题

python - Golang 中 UUID4 的整数表示

python - 从 numpy.timedelta64 值中提取天数

python - 在 scikit learn 中使用分类预测变量

python - 创建一个零填充的 Pandas 数据框

python - 尝试使用 python 将数据导入 MySQL 时出现 "Not all parameters were used in the SQL statement"错误

python - 在二维 numpy 数组中查找匹配的行