python - 使用 fit_generator 不匹配的形状时出错(Keras)

标签 python machine-learning keras neural-network

我正在尝试构建一个简单的分类 CNN,它将使用以下代码将一组 1233 张图像分为 4 个类别:

unclassified_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    horizontal_flip=True
)
unclassified_generator = train_datagen.flow_from_directory(
    'data/unclassified',
    target_size=(120, 120),
    batch_size=1233,
    class_mode='input',
    shuffle=False,
)

model_unclassified = keras.Sequential()
model_unclassified.add(layers.Conv2D(1233, (3, 3), input_shape=(120, 120, 3), padding="SAME"))
model_unclassified.add(layers.Dense(64, activation='relu'))
model_unclassified.add(layers.Dense(4, activation='sigmoid'))

model_unclassified.compile(loss='sparse_categorical_crossentropy',
                           optimizer='rmsprop',
                           metrics=['accuracy'])
model_unclassified.fit_generator(unclassified_generator, epochs=1)


但我收到以下错误:ValueError: Error when checking target: expected dense_2 to have shape (120, 120, 1) but got array with shape (120, 120, 3)
我究竟做错了什么?

最佳答案

您应该添加 Flatten层,因为 Conv2D为每个样本返回 3D 数组:

model_unclassified = keras.Sequential()
model_unclassified.add(layers.Conv2D(1233, (3, 3), input_shape=(120, 120, 3), padding="SAME"))
model_unclassified.add(layers.Flatten())
model_unclassified.add(layers.Dense(64, activation='relu'))
model_unclassified.add(layers.Dense(4, activation='sigmoid'))

关于python - 使用 fit_generator 不匹配的形状时出错(Keras),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62253966/

相关文章:

python - 保持 python telegram bot 运行

python - Seaborn 类型错误 : No loop matching the specified signature and casting was found for ufunc add when using hue

python - 将 Gridsearch 中的最佳参数保存在 pandas 数据框中

machine-learning - 使用循环神经网络解决时间序列任务

machine-learning - 如何计算keras中的错误百分比

machine-learning - 使用 Keras 和 sklearn GridSearchCV 交叉验证提前停止

Python 列表顺序

java - 如何从 Java 的标准输入读取 python 二进制字符串

python - pyspark.sql.utils.IllegalArgumentException : 'requirement failed: Invalid initial capacity'

python - 将多个输入传递到 Keras 模型时出错