python - InvalidArgumentError : Key: label. 无法解析序列化示例 : How can I find a way to parse the one-hot encoded labels from TFRecords?

标签 python tensorflow keras transfer-learning tfrecord

我有 12 个包含图像的文件夹(它们是我的数据类别)。此代码将图像及其相应标签转换为 tfrecord 数据并有效压缩:

import tensorflow as tf
from pathlib import Path
from tensorflow.keras.utils import to_categorical
import cv2
from tqdm import tqdm
from os import listdir
import numpy as np
import matplotlib.image as mpimg
from tqdm import tqdm

labels = {v:k for k, v in enumerate(listdir('train/'))}
labels

class GenerateTFRecord:
    def __init__(self, path):
        self.path = Path(path)
        self.labels = {v:k for k, v in enumerate(listdir(path))}

    def convert_image_folder(self, tfrecord_file_name):
        # Get all file names of images present in folder
        img_paths = list(self.path.rglob('*.jpg'))

        with tf.io.TFRecordWriter(tfrecord_file_name) as writer:
            for img_path in tqdm(img_paths, desc='images converted'):
                example = self._convert_image(img_path)
                writer.write(example.SerializeToString())

    def _convert_image(self, img_path):
        label = self.labels[img_path.parent.stem]
        img_shape = mpimg.imread(img_path).shape

        # Read image data in terms of bytes
        with tf.io.gfile.GFile(img_path, 'rb') as fid:
            image_data = fid.read()

        example = tf.train.Example(features = tf.train.Features(feature = {
            'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
            'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
            'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [3])),
            'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),
            'label': tf.train.Feature(int64_list = tf.train.Int64List(value = tf.one_hot(label, depth=len(labels), on_value=1, off_value=0))),
        }))
        return example

t = GenerateTFRecord(path='train/')
t.convert_image_folder('data.tfrecord')

然后我在这里使用这段代码读取 tfrecord 数据并创建我的 tf.data.Dataset:

def _parse_function(tfrecord):
    # Extract features using the keys set during creation
    features = {
        'rows': tf.io.FixedLenFeature([], tf.int64),
        'cols': tf.io.FixedLenFeature([], tf.int64),
        'channels': tf.io.FixedLenFeature([], tf.int64),
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64)
    }

    # Extract the data record
    sample = tf.io.parse_single_example(tfrecord, features)

    image = tf.image.decode_image(sample['image'])
    label = sample['label']
    # label = tf.one_hot(label, depth=len(labels), on_value=1, off_value=0)
    return image, label

def configure_for_performance(ds, buffer_size, batch_size):
    ds = ds.cache()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=buffer_size)
    return ds


def generator(tfrecord_file, batch_size, n_data, validation_ratio, reshuffle_each_iteration=False):
    reader = tf.data.TFRecordDataset(filenames=[tfrecord_file])
    reader.shuffle(n_data, reshuffle_each_iteration=reshuffle_each_iteration)
    AUTOTUNE = tf.data.experimental.AUTOTUNE

    val_size = int(n_data * validation_ratio)
    train_ds = reader.skip(val_size)
    val_ds = reader.take(val_size)

    train_ds = train_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
    train_ds = configure_for_performance(train_ds, AUTOTUNE, batch_size)

    val_ds = val_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
    val_ds = configure_for_performance(val_ds, AUTOTUNE, batch_size)
    return train_ds, val_ds

在这里我创建了我的模型:

from os.path import isdir, dirname, abspath, join
from os import makedirs

from tensorflow.keras import Sequential
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import SGD, Adam


def create_model(optimizer, freeze_layer=False):
  densenet = DenseNet121(weights='imagenet', 
                        include_top=False)

  if freeze_layer:
    for layer in densenet_model.layers:
      if 'conv5' in layer.name:
        layer.trainable = True
      else:
        layer.trainable = False

  model = Sequential()
  model.add(densenet)
  model.add(GlobalAveragePooling2D())
  model.add(Dense(12, activation='softmax'))

  model.compile(loss='categorical_crossentropy',
                optimizer=optimizer,
                metrics=['accuracy'])

  return model

if __name__ == '__main__':
    optimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.99, epsilon=1e-6)
    densenet_model = create_model(optimizer)

    tfrecord_file = 'data.tfrecord'
    n_data = len(list(Path('train').rglob('*.jpg')))
    train, val = generator(tfrecord_file, 2, n_data, validation_ratio, True)

    validation_ratio = 0.2
    val_size = int(n_data * validation_ratio)
    train_size = n_data - val_size
    batch_size = 32
    n_epochs = 300
    n_workers = 5

    filename = '/content/drive/MyDrive/data.tfrecord'


    train_ds, val_ds = generator(filename,
                            batch_size=batch_size,
                            n_data=n_data,
                            validation_ratio=validation_ratio,
                            reshuffle_each_iteration=True)


   hist = densenet_model.fit(train_ds,
                      validation_data=val_ds,
                      epochs=n_epochs,
                      workers=n_workers,
                      steps_per_epoch=train_size//batch_size,
                      validation_steps=val_size)

这是我每次得到的错误:

InvalidArgumentError:键:标签。无法解析序列化示例。 [[{{node ParseSingleExample/ParseExample/ParseExampleV2}}]] [[IteratorGetNext]] [Op:__inference_train_function_343514]

显然我的 tfrecord 数据中的 label 有问题。

我真的需要知道,根据我的模型输出形状 (12,),我如何才能安全地将一个热编码标签存储在我的 tfrecord 中并在 tf.data.Dataset 中进行解析?

谢谢大家

最佳答案

如答案所示here数据数组应该是固定大小的,所以我认为它可以解决你的问题。

关于python - InvalidArgumentError : Key: label. 无法解析序列化示例 : How can I find a way to parse the one-hot encoded labels from TFRecords?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65620220/

相关文章:

python - 如何在 Python 中将段落添加到列表中

python - 如何在 Django 中管理多对一关系

python - 报告 pyTest 中的断言数

tensorflow - Keras 中的动态激活函数

keras - keras 的奇怪训练问题 - FC 层中的零损失突然大幅下降

machine-learning - 我可以删除预训练 Keras 模型中的层吗?

machine-learning - 二元掩模分类的最佳输出激活函数

Python:比较两个 Pandas DataFrame 并获取差异索引

python - 使用具有 LSTM 和动态 RNN 的可训练词嵌入层 : AdamOptimizer expected float_ref instead of float

python - 在 Tensorflow 2 中的每个纪元之后计算每个类的召回率