tensorflow - 用于多个输入和基于图像的目标输出的 Keras ImageDataGenerator

标签 tensorflow keras tensorflow-datasets tf.keras

我有一个模型,它将两个图像作为输入并生成一个图像作为目标输出。

我所有的训练图像数据都在以下子文件夹中:

  • 输入1
  • 输入2
  • 目标

我可以使用 keras 中的 ImageDataGenerator 类和方法,如 flow_from_directorymodel.fit_generator 方法来训练网络吗?

我该怎么做?因为我遇到的大多数例子都处理单一输入和基于标签的目标输出。

在我的例子中,我有一个非分类目标输出数据和多个输入。

请帮忙,因为你的建议真的很有帮助。

最佳答案

一种可能性是将三个 ImageDataGenerator 合并为一个,使用 class_mode=None(因此它们不返回任何目标),并使用 shuffle=False (重要)。确保您对每个使用相同的 batch_size 并确保每个输入都在不同的目录中,并且目标也在不同的目录中,并且每个输入中的图像数量完全相同目录。

idg1 = ImageDataGenerator(...choose params...)
idg2 = ImageDataGenerator(...choose params...)
idg3 = ImageDataGenerator(...choose params...)

gen1 = idg1.flow_from_directory('input1_dir',
                                shuffle=False,
                                class_mode=None)
gen2 = idg2.flow_from_directory('input2_dir',
                                shuffle=False,
                                class_mode=None)
gen3 = idg3.flow_from_directory('target_dir',
                                shuffle=False,
                                class_mode=None)

创建自定义生成器:

class JoinedGen(tf.keras.utils.Sequence):
    def __init__(self, input_gen1, input_gen2, target_gen):
        self.gen1 = input_gen1
        self.gen2 = input_gen2
        self.gen3 = target_gen

        assert len(input_gen1) == len(input_gen2) == len(target_gen)

    def __len__(self):
        return len(self.gen1)

    def __getitem__(self, i):
        x1 = self.gen1[i]
        x2 = self.gen2[i]
        y = self.gen3[i]

        return [x1, x2], y

    def on_epoch_end(self):
        self.gen1.on_epoch_end()
        self.gen2.on_epoch_end()
        self.gen3.on_epoch_end()

用这个生成器训练:

my_gen = JoinedGen(gen1, gen2, gen3)
model.fit_generator(my_gen, ...)

或者创建自定义生成器。它的所有结构如上所示。

关于tensorflow - 用于多个输入和基于图像的目标输出的 Keras ImageDataGenerator,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59492866/

相关文章:

python - FailedPreconditionError(请参阅上面的回溯): GetNext() failed because the iterator has not been initialized

python - 在 Tensorflow 中设置交集

python - 值错误: Cannot create a tensor proto whose content is larger than 2GB

python - 编码目标列以在 Tensorflow 中进行分类

python - Keras 训练 warm_start

python - 恢复使用迭代器的 Tensorflow 模型

python - 导入 tensorflow 抛出导入错误: DLL load failed

python - Keras - 恢复特定时间戳的 LSTM 隐藏状态

python - Keras:binary_crossentropy 和 categorical_crossentropy 混淆

tensorflow - 如何将大 float 保存为 TFRecord 格式? float_list/float32 似乎截断了值