tensorflow - 温和地批量读取 tfrecords 数据的方法

标签 tensorflow deep-learning

以下所有问题均基于tensorflow 1.0 API

我现在可以在以类名命名的目录下写入图像,这是我生成的 tfrecords 代码:

def _convert_to_example(filename, image_buffer, label, text, height, width):
    colorspace = 'RGB'
    channels = 3
    image_format = 'JPEG'

    example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': _int64_feature(height),
      'image/width': _int64_feature(width),
      'image/colorspace': _bytes_feature(tf.compat.as_bytes(colorspace)),
      'image/channels': _int64_feature(channels),
      'image/class/label': _int64_feature(label),
      'image/class/text': _bytes_feature(tf.compat.as_bytes(text)),
      'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
      'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))),
      'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))}))
    return example

这是主要方法,所以这里我存储了高度、宽度、 channel (这个值无法读出)等。

我可以读取 tfrecords,这是我的代码:

def read_tfrecords():
    print('reading from tfrecords file {}'.format(FLAGS.record_file))
    record_iterator = tf.python_io.tf_record_iterator(path=FLAGS.record_file)

    with tf.Session() as sess:
        for string_record in record_iterator:
            example = tf.train.Example()
            example.ParseFromString(string_record)

            height_ = int(example.features.feature['image/height'].int64_list.value[0])
            width_ = int(example.features.feature['image/width'].int64_list.value[0])
            channels_ = int(example.features.feature['image/channels'].int64_list.value[0])

            image_bytes_ = example.features.feature['image/encoded'].bytes_list.value[0]
            label_ = int(example.features.feature['image/class/label'].int64_list.value[0])
            text_bytes_ = example.features.feature['image/class/text'].bytes_list.value[0]

            # image_array_ = np.fromstring(image_bytes_, dtype=np.uint8).reshape((height_, width_, 3))
            image_ = tf.image.decode_jpeg(image_bytes_)
            image_ = sess.run(image_)
            text_ = text_bytes_.decode('utf-8')

            print('tfrecords height {0}, width {1}, channels {2}: '.format(height_, width_, channels_))
            print('decode image shape: ', image_.shape)
            print('label text: ', text_)
            print('label: ', label_)
            # io.imshow(image_)
            # plt.show()

一切都很公平,但是,当我尝试将 tfrecords 数据批量加载并将其输入网络时,问题发生了

这是我批量加载的所有代码:

tf.app.flags.DEFINE_integer('target_image_height', 150, 'train input image height')
tf.app.flags.DEFINE_integer('target_image_width', 200, 'train input image width')

tf.app.flags.DEFINE_integer('batch_size', 12, 'batch size of training.')
tf.app.flags.DEFINE_integer('num_epochs', 100, 'epochs of training.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate of training.')

FLAGS = tf.app.flags.FLAGS


def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized=serialized_example,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/channels': tf.FixedLenFeature([], tf.int64),
            'image/encoded': tf.FixedLenFeature([], tf.string),
            'image/class/label': tf.FixedLenFeature([], tf.int64),
        })

    image = tf.decode_raw(features['image/encoded'], out_type=tf.uint8)
    height = tf.cast(features['image/height'], dtype=tf.int32)
    width = tf.cast(features['image/width'], dtype=tf.int32)
    channels = tf.cast(features['image/channels'], dtype=tf.int32)
    label = tf.cast(features['image/class/label'], dtype=tf.int32)

    # cast image int64 to float32 [0, 255] -> [-0.5, 0.5]
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    image_shape = tf.stack([height, width, 3])
    image = tf.reshape(image, image_shape)

    return image, label


def inputs(train, batch_size, num_epochs):
    if not num_epochs:
        num_epochs = None
    filenames = ['./data/tiny_5_tfrecords/train-00000-of-00002',
                 './data/tiny_5_tfrecords/train-00001-of-00002']
    print(filenames)
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs)
    print(filename_queue)
    image, label = read_and_decode(filename_queue)
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=2,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=1000)

    return images, sparse_labels


def run_training():
    images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    images = tf.Print(images, [images], message='this is images:')
    images.eval()
    predictions = inference.lenet(images=images, num_classes=5, activation_fn='relu')
    slim.losses.softmax_cross_entropy(predictions, labels)

    total_loss = slim.losses.get_total_loss()
    tf.summary.scalar('loss', total_loss)

    optimizer = tf.train.RMSPropOptimizer(0.001, 0.9)

    train_op = slim.learning.create_train_op(total_loss=total_loss,
                                             optimizer=optimizer,
                                             summarize_gradients=True)
    slim.learning.train(train_op=train_op, save_summaries_secs=20)


def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()

我运行这个程序,得到这个错误:

raceback (most recent call last):
  File "train_tiny5_tensorflow.py", line 111, in <module>
    tf.app.run()
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "train_tiny5_tensorflow.py", line 107, in main
    run_training()
  File "train_tiny5_tensorflow.py", line 88, in run_training
    num_epochs=FLAGS.num_epochs)
  File "train_tiny5_tensorflow.py", line 81, in inputs
    min_after_dequeue=1000)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 1165, in shuffle_batch
    name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/training/input.py", line 724, in _shuffle_batch
    dtypes=types, shapes=shapes, shared_name=shared_name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/data_flow_ops.py", line 624, in __init__
    shapes = _as_shape_list(shapes, dtypes)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/data_flow_ops.py", line 77, in _as_shape_list
    raise ValueError("All shapes must be fully defined: %s" % shapes)
ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(None), Dimension(3)]), TensorShape([])]

显然,程序根本没有获取 tfrecords 文件。

我已经尝试过这个: 1.我觉得可能是filenames不对,我把它改成相对路径和绝对路径,都可以; 2.我把tfrecords文件放在脚本旁边,直接写tfrecords文件名不行。

所以,基本上,我遇到了这个问题:

<强>1。编写尽可能短的程序将 tfrecords 文件批量加载并输入网络的正式且合理的方法是什么

<强>2。顺便说一句,编写 tensorflow 层最简单、最优雅的方法是什么? slim是一个不错的选择,原来的方式又丑又复杂!

最佳答案

对于可能遇到同样问题的人,我在上面的代码中犯了一些错误。 只需不使用 decode_raw,而是使用 tf.image.decode_jpeg 和我的代码函数

def inputs(train, batch_size, num_epochs):
    if not num_epochs:
        num_epochs = None
    filenames = ['./data/tiny_5_tfrecords/train-00000-of-00002',
                 './data/tiny_5_tfrecords/train-00001-of-00002']
    print(filenames)
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs)
    print(filename_queue)
    image, label = read_and_decode(filename_queue)
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=2,
        capacity=1000 + 3 * batch_size,
        min_after_dequeue=1000)

    return images, sparse_labels

最后两行我错过了一个制表符。

关于tensorflow - 温和地批量读取 tfrecords 数据的方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42509811/

相关文章:

python - 在 Keras 中使用 multi_gpu_model 恢复训练

python - 了解 Keras LSTM

python - 如何使用 Jupyter 笔记本解决 Python 中形状为 '(?, 4) 的 Tensor 'Placeholder_1:0' 的 ValueError : Cannot feed value of shape (230, )

python - 将 Universal Sentence Encoder 保存到 Tflite 或将其提供给 tensorflow api

python - 经过卷积步骤后,全连接层中张量的形状应该是什么?

android - tensorflow 模型的图像预处理参数

math - 硬 Sigmoid 是如何定义的

python - 哪个更有效 : tf. where 或 element-wise multiplication?

tensorflow - 尝试在 tensorflow 中训练 mobilenet 时出现 ERROR : Config value cuda is not defined in any . rc 文件

deep-learning - Fast R-CNN 中 ROI 层的目的是什么?