python - tensorflow 2.0,模型.fit(): Your input ran out of data

标签 python tensorflow keras neural-network

我对 TensorFlow 和 Keras 完全陌生,我正在尝试尝试一些我在网上找到的代码。

特别是我正在使用 Fashion-MNIST - 由 60000 个示例和 10000 个示例的测试集组成。每个都是 28x28 灰度图像。

我正在按照本教程“https://towardsdatascience.com/building-your-first-neural-network-in-tensorflow-2-tensorflow-for-hackers-part-i-e1e2f1dfe7a0”进行操作,直到定义为止都没有问题

history = model.fit(
train_dataset.repeat(), 
epochs=10, 
steps_per_epoch=500,
validation_data=val_dataset.repeat(), 
validation_steps=2)

只要我理解,我就需要使用train_dataset.repeat()作为输入数据集,否则我将没有足够的训练示例使用这些超参数值(epochs、steps_per_epochs)。

我的问题是:如何避免使用.repeat()? 我需要如何更改超参数?

为了简单起见,我在这里处理代码:

def preprocess(x,y):

    x = tf.cast(x,tf.float32) / 255.0
    y = tf.cast(y, tf.float32)

    return x,y 

def create_dataset(xs, ys, n_classes=10):

    ys = tf.one_hot(ys, depth=n_classes)

    return tf.data.Dataset.from_tensor_slices((xs, ys)).map(preprocess).shuffle(len(ys)).batch(128)


model.compile(optimizer = 'adam', loss =tf.losses.CategoricalCrossentropy(from_logits= True), metrics =['accuracy'])

history1 = model.fit(train_dataset.repeat(), 
                    epochs=10, 
                    steps_per_epoch=500,
                    validation_data=val_dataset.repeat(), 
                    validation_steps=2)

谢谢!

最佳答案

如果您不想使用 .repeat(),则需要让模型在每个时期仅传递一次整个数据。

为了做到这一点,您需要计算模型遍历整个数据集需要多少步,计算很简单:

steps_per_epoch = len(train_dataset) // batch_size

因此,如果 train_dataset 包含 60 000 个样本,batch_size 为 128,则每个周期需要 468 个步骤。

通过这样设置此参数,您可以确保不超过数据集的大小。

关于python - tensorflow 2.0,模型.fit(): Your input ran out of data,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60228989/

相关文章:

python - 从python中的函数返回不同的数据类型

tensorflow - 使用保存的模型在 tensorflow 中进行预测

python - 应该定义密集层输入的最后一个维度。没有发现。收到完整的输入形状 : <unknown>

python - Python中的 double 浮点值?

python - GSpread 导入错误 : No module named oauth2client. service_account

python - pyparsing 优先级分割

python - Tensorflow:使用不同的 "call"函数创建 LSTM 单元的自定义子类

python - tensorflow 中的线性模型

python-3.x - keras LSTM模型输入和输出维度不匹配

python - CNN with keras,精度保持不变,没有提高