python - TensorFlow - 一次读取 TFRecords 中的所有示例?

标签 python tensorflow tfrecord

如何一次读取 TFRecords 中的所有示例?

我一直在使用 tf.parse_single_example 来读取单个示例,使用的代码类似于 example of the fully_connected_reader 中的方法 read_and_decode 中给出的代码。 .但是,我想一次针对我的整个验证数据集运行网络,因此想全部加载它们。

我不完全确定,但是 the documentation似乎建议我可以使用 tf.parse_example 而不是 tf.parse_single_example 一次加载整个 TFRecords 文件。我似乎无法让它工作。我猜这与我如何指定功能有关,但我不确定在功能规范中如何说明有多个示例。

换句话说,我尝试使用类似于:

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_example(serialized_example, features={
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
})

不起作用,我认为这是因为这些功能不会同时出现多个示例(但同样,我不确定)。 [这会导致错误 ValueError: Shape () must have rank 1]

这是一次读取所有记录的正确方法吗?如果是这样,我需要改变什么才能真正阅读记录?非常感谢!

最佳答案

为了清楚起见,我在一个 .tfrecords 文件中有几千张图片,它们是 720 x 720 rgb png 文件。标签是 0,1,2,3 之一。

我也尝试使用 parse_example,但无法使其正常工作,但此解决方案适用于 parse_single_example。

缺点是现在我必须知道每个 .tf 记录中有多少项目,这有点令人沮丧。如果我找到更好的方法,我会更新答案。另外,要小心超出 .tfrecords 文件中记录数的范围,如果您遍历最后一条记录,它将从第一条记录重新开始

诀窍是让队列运行者使用协调器。

我在这里留下了一些代码来保存正在读取的图像,以便您可以验证图像是否正确。

from PIL import Image
import numpy as np
import tensorflow as tf

def read_and_decode(filename_queue):
 reader = tf.TFRecordReader()
 _, serialized_example = reader.read(filename_queue)
 features = tf.parse_single_example(
  serialized_example,
  # Defaults are not specified since both keys are required.
  features={
      'image_raw': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
      'height': tf.FixedLenFeature([], tf.int64),
      'width': tf.FixedLenFeature([], tf.int64),
      'depth': tf.FixedLenFeature([], tf.int64)
  })
 image = tf.decode_raw(features['image_raw'], tf.uint8)
 label = tf.cast(features['label'], tf.int32)
 height = tf.cast(features['height'], tf.int32)
 width = tf.cast(features['width'], tf.int32)
 depth = tf.cast(features['depth'], tf.int32)
 return image, label, height, width, depth


def get_all_records(FILE):
 with tf.Session() as sess:
   filename_queue = tf.train.string_input_producer([ FILE ])
   image, label, height, width, depth = read_and_decode(filename_queue)
   image = tf.reshape(image, tf.pack([height, width, 3]))
   image.set_shape([720,720,3])
   init_op = tf.initialize_all_variables()
   sess.run(init_op)
   coord = tf.train.Coordinator()
   threads = tf.train.start_queue_runners(coord=coord)
   for i in range(2053):
     example, l = sess.run([image, label])
     img = Image.fromarray(example, 'RGB')
     img.save( "output/" + str(i) + '-train.png')

     print (example,l)
   coord.request_stop()
   coord.join(threads)

get_all_records('/path/to/train-0.tfrecords')

关于python - TensorFlow - 一次读取 TFRecords 中的所有示例?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37151895/

相关文章:

python - 在 PyDev 中,如何为方法返回的实例获取自动完成功能?

json - 使用keras在json中转储cnn的权重

tensorflow - 训练模型以实现 DLib 的面部标志,例如手部特征点及其标志

python - 从 Python 将 TFRecord 输出到 Google Cloud Storage

java - 如何优雅地以各种顺序调用方法?

python - 如何访问和可视化预训练的 TensorFlow 2 模型中的权重?

具有混合数据类型的 TensorFlow 数据集生成器

Tensorflow:计算 TFRecord 文件中示例的数量——不使用已弃用的 `tf.python_io.tf_record_iterator`

tensorflow - 使用 keras 在 gcloud ml-engine 上处理 TB 数据的最佳方法

python - 如何在 Celery 任务执行期间强制记录器格式?