python - 使用 steps_per_epoch 参数时,Keras 的训练速度极慢

标签 python tensorflow keras

当我在 model.fit(..) 方法中指定 steps_per_epoch 参数时,我注意到训练模型速度大幅下降。当我将 steps_per_epoch 指定为 None(或不使用它)时,epoch 的 ETA 是连续 2 秒:

9120/60000 [===>..........................] - ETA: 2s - loss: 0.7055 - acc: 0.7535

当我添加 steps_per_epoch 参数时,预计到达时间会增加到 5 小时,训练速度变得非常慢:

5/60000 [..............................] - ETA: 5:50:00 - loss: 1.9749 - acc: 0.3437

这是可重现的脚本:

import tensorflow as tf
from tensorflow import keras
import time

print(tf.__version__)


def get_model():
    model = keras.Sequential([
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


(train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0

model = get_model()

# Very quick - 2 seconds
start = time.time()
model.fit(train_images, train_labels, epochs=1)
end = time.time()
print("{} seconds", end - start)

model = get_model()

# Very slow - 5 hours
start = time.time()
model.fit(train_images, train_labels, epochs=1, steps_per_epoch=len(train_images))
end = time.time()
print("{} seconds", end - start)

我也尝试过使用纯 Keras,但问题仍然存在。我使用 1.12.0 版本的 Tensorflow、python 3 和 Ubuntu 18.04.1 LTS。

为什么 steps_per_epoch 参数会导致如此显着的速度下降,我该如何避免这种情况?

谢谢!

最佳答案

请注意,您正在对一组数据使用 fit。您没有使用 fit_generator 或使用任何生成器。

除非您有非常规的想法,否则使用 steps_per_epoch 是没有意义的。

fit 中的默认批量大小为 32,这意味着您正在使用 60000//32 = 1875 每个时期的步数进行训练。

如果您使用这个数字 1875,您将训练与默认 None 相同数量的批处理。如果您使用 60000 步,您将一个纪元乘以 32。(由于速度的巨大差异,我想说在这种情况下默认批处理大小也发生了变化)


没有步骤的拟合输出中显示的总数是图像总数。请注意已完成项目的数量如何以 32 的倍数增长。

使用steps时显示的总数就是步数。请注意已完成的步数如何以 1 1 的速度增长。

关于python - 使用 steps_per_epoch 参数时,Keras 的训练速度极慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54634884/

相关文章:

Python将纯文本显示为html

tensorflow - 如何强制自动编码器中的瓶颈产生二进制值?

tensorflow - 合并稀疏张量中的重复索引

python - 尽管使用了allow_growth=True,为什么keras model.fit 使用了这么多内存?

python - 如何在 Tensorflow 中从 tf.keras 导入 keras?

python - 为什么即使我设置了随机种子也无法在 Keras 中获得可重现的结果?

python - 如何将 numpy 数组附加到不同大小的 numpy 数组?

Python 相当于 Perl 的 'w' 打包格式

python - Python-版本列表而不是不可变列表?

keras - 如何使用 Keras 后端收集张量?