python-3.x - 如何使用 sklearn.datasets.load_files 加载数据百分比

标签 python-3.x numpy scikit-learn deep-learning keras

我正在加载 8000 张图像 sklearn.datasets.load_files并通过来自 keras 的 resnet获得瓶颈特征。然而,这个任务在 GPU 上需要几个小时,所以我想知道是否有办法告诉 load_files 加载一定百分比的数据,比如 20%。

我这样做是为了训练我自己的顶层(最后一个密集层)并将其附加到 resnet。

def load_dataset(path):
    data = load_files(path)
    files = np.array(data['filenames'])
    targets = np_utils.to_categorical(np.array(data['target']), 100)
    return files, targets

train_files, train_targets = load_dataset('images/train')

最佳答案

这听起来更适合 Keras ImageDataGenerator 类并使用 ImageDataGenerator.flow_from_directory 方法。您不必对其使用数据扩充(这会进一步减慢速度),但您可以选择从目录中提取的批处理大小,而不是全部加载它们。

复制自https://keras.io/preprocessing/image/并用注释略作修改。

train_datagen = ImageDataGenerator(  # <- customize your transformations
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,  # <- control how many images are loaded each batch
        class_mode='binary')

validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,  # <- reduce here to lower the overall images used
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)

编辑

根据您在下面的问题... steps_per_epoch 确定每个时期加载多少批处理。

例如:

  • steps_per_epoch = 50
  • 批量大小 = 32
  • 时代 = 1

将为您提供该时期的 1,600 张图像。这恰好是您 8,000 张图片的 20%。 请注意,如果您遇到批处理大小为 32 的内存问题,您可能需要减少它并增加您的 steps_per_epoch。需要一些修补才能使其正确。

关于python-3.x - 如何使用 sklearn.datasets.load_files 加载数据百分比,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49783945/

相关文章:

python-3.x - 将 Pandas 数据框转换为包含 ID 和权重的元组列表

python-3.x - 使用 python 3.x 如何将 Tree 对象从 ete3 传递到 DendroPy 而不写入文件

python - numpy.delete 不从数组中删除列

python - 汇总分组 Pandas 数据框中的行并返回 NaN

python - LeaveOneOut 确定 knn 中的 k

Python奇数运算?

python - 使用类对象从外部文件创建 sqlite 数据库

python - 如何找到最接近网格值的点

python - 如何从 gridSearchCV 的输出中获取特征名称

python - 如何使用 pandas/sklearn 删除停止短语/停止 ngram(多词字符串)?