python - 如何在 TF2 中将 ImageDataGenerator 与 TensorFlow 数据集结合起来?

标签 python tensorflow keras tensorflow2.0

我有一个 TF 数据集 classify猫和狗:

import tensorflow_datasets as tfds
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)

在示例中,他们使用 map 函数进行一些图像增强。我想知道是否也可以使用漂亮的 ImageDataGenerator 类来完成,例如描述的 here :

from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_image_generator = ImageDataGenerator(rescale=1./255) # Generator for our training data
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')

我面临的问题是我只能看到 3 ways使用 ImageDataGenerator:pandas 数据帧、numpy 数组和图像目录。 有没有办法也使用 Tensorflow 数据集并结合这些方法?

最佳答案

是的,但是有点棘手。
Keras ImageDataGenerator 适用于 numpy.array 而不是 tf.Tensor,因此我们必须使用 Tensorflow 的 numpy_function 。这将使我们能够对 tf.data.Dataset 内容执行操作,就像 numpy 数组一样。

首先,让我们声明我们将在数据集上 .map 的函数(假设您的数据集由图像、标签对组成):

# We will take 1 original image and create 5 augmented images:
HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create generator and fit it to an image
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
  img_gen.fit(image)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)] 
  label_results = [label]

  # Perform augmentation and keep the labels
  augmented_images = [next(img_gen.flow(image)) for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  return img_results, label_results

现在,为了在 tf.data.Dataset 中使用此函数,我们必须声明一个 numpy_function:

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

py_augment 可以安全地使用,例如:

augmented_dataset_ds = image_label_dataset.map(py_augment)

数据集中的image部分现已成型 (HOW_MANY_TO_AUGMENT、image_height、image_width、 channel )。 要将其转换为简单的(1,image_height,image_width,channels),您可以简单地使用unbatch:

unbatched_augmented_dataset_ds = Augmented_dataset_ds.unbatch()

所以整个部分看起来像这样:

HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create generator and fit it to an image
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
  img_gen.fit(image)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)] 
  label_results = [label]

  # Perform augmentation and keep the labels
  augmented_images = [next(img_gen.flow(image)) for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  return img_results, label_results

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

unbatched_augmented_dataset_ds = augmented_dataset_ds.map(py_augment).unbatch()

# Iterate over the dataset for preview:
for image, label in unbatched_augmented_dataset_ds:
    ...

关于python - 如何在 TF2 中将 ImageDataGenerator 与 TensorFlow 数据集结合起来?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59648804/

相关文章:

Python:将正则表达式应用于以日期时间作为列的数据框

python - pytorch .stack .squeeze后的最终形状

tensorflow - 如何从张量板文件中读取评估损失?

python - Tensorflow:属性错误:模块 'tensorflow.python.ops.nn' 没有属性 'softmax_cross_entropy_with_logits_v2'

python - 如何使用 boto3 在 Python 中获取 S3 目录作为 os.path?

tensorflow - "object has no attribute ' _name_scope '" tensorflow /keras 中的错误

用于从行中删除双重条目的 Python 程序

python - 查看单词是否为等值图的函数

python - Tensorflow 在训练时两次打印相同的信息

python - LSTM RNN 同时预测多个时间步长和多个特征