python - 如何使用 Tensorflow 数据集进行 CNN 模型训练

标签 python tensorflow keras

我想使用 tf.data.Dataset 类提供数据


from tensorflow_core.python.keras.datasets import cifar10
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

  • 我这样做是为了在管道中使用数据集

  • 进一步利用Dataset的其他功能。

我像这样定义我的模型

    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPool2D((2, 2)))
    # more layers

但是当我打电话训练模型时

model.fit(train_dataset, epochs=10, validation_data=test_dataset, callbacks=[ cp_callback])

我收到错误

ValueError: Error when checking input: expected conv2d_input to have 4 dimensions, but got array with shape (32, 32, 3)

  • 到底发生了什么?如何在 Conv2D 层中使用 DataSet 和 input_shape=(32, 32, 3) ?

Tensorflow 教程 ( https://www.tensorflow.org/tutorials/load_data/numpy ) 没有涵盖这种情况,我找不到可以帮助我解决问题的解释。

最佳答案

应将批量生成器添加到具有任意批量大小的数据集中。基于Tensorflow的文档here批处理功能:

Combines consecutive elements of this dataset into batches. The components of the resulting element will have an additional outer dimension, which will be batch_size (or N % batch_size for the last element if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on the batches having the same outer dimension, you should set the drop_remainder argument to True to prevent the smaller batch from being produced.

假设您的批量大小为 16。然后:

my_batch_size =16
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels))
# Shapes of data are (32,32,3) here

train_dataset.batch(my_batch_size)
test_dataset.batch(my_batch_size)
# Shapes of data are (None,32,32,3) or (16,32,32,3) here

然后你就可以训练你的模型了。

关于python - 如何使用 Tensorflow 数据集进行 CNN 模型训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59109662/

相关文章:

python - python中字符串到元组的转换

python - Keras 的 predict_generator 没有返回正确数量的样本

python - 非常基本的 Keras CNN,有 2 个类给出了莫名其妙的答案

python - 在 keras 中使用额外输入加载自定义损失

python - 回调函数在实例中看不到正确的值

Python 多处理 apply_async "assert left > 0"AssertionError

python - 相同的 CSS,浏览器和 bs4 .select() 方法中的不同结果

machine-learning - 实现超分辨率 CNN 时遇到问题

python - 在 TensorFlow 中保存或导出权重和偏差以进行非 Python 复制

python - Keras weighted_metrics 在计算中不包括样本权重