tensorflow - 使用 ImageDataGenerator 进行无限循环

标签 tensorflow for-loop machine-learning keras infinite-loop

我正在处理更大的数据集,因此必须将数据批量加载到我的 RAM 中,以便在不耗尽资源的情况下更快地运行。我正在将图像数据生成器与 .flow

一起使用

使用 for 循环会导致无限循环,在循环重新开始之前不断生成相同批量大小的图像。准备器代码如下所示:

train_dataset=tf.keras.preprocessing.image.ImageDataGenerator(featurewise_center=False, samplewise_center=False,
    featurewise_std_normalization=False, samplewise_std_normalization=False,
    zca_whitening=False, rotation_range=0, width_shift_range=0.0,
    height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0,
    channel_shift_range=0.0, cval=0.0, horizontal_flip=False,
    vertical_flip=False, preprocessing_function=None,
    data_format=None, validation_split=0.0, dtype=None)
train_dataset.fit(X)

随后尝试循环,如下所示:

for images, y_batch in train_dataset.flow(X, y, batch_size=batch_size):
          print(np.shape(images))

代码只是不断返回维度的数组:

(batch_size,img_size,img_size,3) (我需要这些图像将数据带入我的 RAM 来执行后向支撑)。请注意,我没有使用诸如 model.fit 之类的东西,并且需要通过我的正确代码运行这些数组。

不太确定如何添加停止条件

最佳答案

这就是重点;永远继续迭代。 Keras 的 model.fit_gerentaor() 或 tf.keras 的 model.fit() 处理根据 epochs 终止训练循环steps_per_epoch 参数。

如果您想使用 ImageDataGenerator()手动训练模型,您大致可以执行以下操作:

epochs = 10
steps_per_epoch = len(x) // batch_size + 1  # we usually consider 1 epoch to be
                                            # the point where the model has seen
                                            # all the training samples at least once

generator = train_dataset.flow(X, y, batch_size=batch_size)

for e in range(epochs):
    for i, (images, y_batch) in enumerate(generator):
       model.train_on_batch(images, y_batch)  # train model for a single iteration
       if i >= steps_per_epoch:  # manually detect the end of the epoch
           break  
    generator.on_epoch_end()  # this shuffles the data at the end of each epoch

关于tensorflow - 使用 ImageDataGenerator 进行无限循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62836066/

相关文章:

image-processing - Tensorflow 等效于 Keras 函数 : UpSampling2D

python - 将列表列表传递到 tensorflow

machine-learning - Keras 模型为 make_moons 数据创建线性分类

python - Tensorflow:__new__() 在对象检测 API 中得到了一个意外的关键字参数 'serialized_options'

python - 使用对象循环时获取 ID 字段

javascript - 在 onchange 和通过文本框将数字存储在数组中时遇到问题

c++ append 到字符串

machine-learning - softmax 的输出不应该有零,对吗?

python - 如何保存和恢复我的sklearn模型?

machine-learning - CART算法使用的离散化方法是什么?