python - 如何使用 flow_from_directory 仅选择特定文件格式?

标签 python tensorflow machine-learning keras deep-learning

我正在使用 Keras 进行一些深度学习实验。训练我的模型后,我想使用以下代码测试我的模型:

test_datagen = ImageDataGenerator(rescale=1 / 255.)

test_generator = test_datagen.flow_from_directory(directory='test/',
                                                  color_mode='grayscale',
                                                  # don't shuffle
                                                  shuffle=False,
                                                  # use same size as in training
                                                  target_size=(256, 256),
                                                  batch_size=1,
                                                  class_mode=None
                                                  )

preds = model.predict_generator(test_generator, steps=12)

问题是测试文件夹还包含另一个子目录中的子目录。 (例如 test/test2/test3/test4 ...),我也想访问 test4 文件夹内的图像,但我得到了 IsADirectoryError: [Errno 21] Is a directory: 'test/test2/test3' 错误。

我的第一个问题是:是否有可能搜索和使用而不是将所有图像复制并粘贴到一个文件夹中?

第二:我只想使用 .png 格式的图像。我可以做这样的事情吗? from_directory(directory='test/*.png') 仅适用于 .png 文件?

提前谢谢您。 更新时间:2020年2月24日

最佳答案

对于多个文件夹问题,这里有 solution :

idg1 = ImageDataGenerator(**idg1_configs)
idg2 = ImageDataGenerator(**idg2_configs)

g1 = idg1.flow_from_directory('idg1_dir',...)
g2 = idg2.flow_from_directory('idg2_dir',...)

def combine_gen(*gens):
    while True:
        for g in gens:
            yield next(g)

# ...
model.fit_generator(combine_gen(g1, g2), steps_per_epoch=len(g1)+len(g2), ...)

关于python - 如何使用 flow_from_directory 仅选择特定文件格式?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60337512/

相关文章:

python - 计算 SGDClassifier 数据损失函数值的方法?

Python - 如何获取有关 SyntaxError 的更多信息?

python - 为什么要在 Python 中隐式检查是否为空?

python - 正则表达式单词匹配

python - 属性错误: 'Node' object has no attribute 'output_masks' in Keras

Python 无法创建 NewWriteableFile(tensorflow.python.framework.errors_impl.NotFoundError : Failed to create a NewWriteableFile: )

python - TensorFlow 索引无效(越界)

python-3.x - 发现输入变量样本数量不一致: [2, 8382]

machine-learning - 如何训练 OpenNLP 模型提取多组词

python - 使用 numpy 和 scipy 在 python 中进行最小二乘估计