python - 如何设置 tf.keras.layers.RandomFlip 的可能性?

标签 python tensorflow keras data-augmentation

使用 tf.keras.layers.RandomFlip 进行随机翻转操作时是否可以设置可能性?

例如:

def augmentation():
        data_augmentation = keras.Sequential([
            keras.layers.RandomFlip("horizontal", p=0.5),
            keras.layers.RandomRotation(0.2, p=0.5)
        ])
   return data_augmentation 

最佳答案

尝试创建一个简单的 Lambda 层并在单独的函数中定义概率:

import random

def random_flip_on_probability(image, probability= 0.5):
    if random.random() < probability:
      return tf.image.random_flip_left_right(image)
    return image

def augmentation():
        data_augmentation = keras.Sequential([
            keras.layers.Lambda(random_flip_on_probability),
            keras.layers.RandomRotation(0.2, p=0.5)
        ])
   return data_augmentation 

如果您需要在训练或推理期间使用数据增强,则必须定义自己的自定义层。尝试这样的事情:

import tensorflow as tf
import pathlib

class RandomFlipOnProbability(tf.keras.layers.Layer):
  def __init__(self, probability):
    super(RandomFlipOnProbability, self).__init__()
    self.probability = probability

  def call(self, images):
    return tf.cond(tf.random.uniform(()) < self.probability, lambda: tf.image.flip_left_right(images), lambda: images)

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)


random_layer = RandomFlipOnProbability(probability = 0.9)
normalization_layer = tf.keras.layers.Rescaling(1./255)

images, _ = next(iter(train_ds.take(1)))
images = normalization_layer(random_layer(images))
image = images[0]

plt.imshow(image.numpy())

enter image description here

关于python - 如何设置 tf.keras.layers.RandomFlip 的可能性?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69689159/

相关文章:

python - Celery:组中的一个子任务总是超时

python - 从 vim 中测试单个函数

machine-learning - 简单来说什么是损失函数?

python - 二维输入的 Keras 模型

python - 模型精度停留在 0.2505。我的代码有什么问题?

python - 在 Python 中解析复杂且不断变化的 JSON 数据,深度多个级别

python - 查找数组是否为排列的最佳方法

tensorflow - 如何将 tfjs 的 body-pix 模型转换为 keras h5 或 tensorflow 卡住图

python - 无法使用 Keras fit_generator 重现结果

python - 从 Keras 中的输出层创建一个 "unpooling"掩码