python - Keras提示错误: (Error when checking input: expected conv2d_4_input to have 4 dimensions)

标签 python keras deep-learning conv-neural-network

这是我在 MNIST 数据集上使用卷积神经网络的代码。不幸的是,Keras 在通过网络时提示错误。感谢您的帮助。我想知道出现此类错误的原因。

这是错误:检查输入时出错:预期 conv2d_4_input 有 4 个维度,但得到形状为 (45000, 28, 28) 的数组

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28,28, 1), padding= 'same'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu', padding= 'same'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu', padding= 'same'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.4))
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
k = 4
num_val_samples = len(train_images) // k
num_epochs = 20
all_scores = []
for i in range(k):
    print('processing fold #', i)
    valid_data = train_images[i * num_val_samples: (i + 1) *
                          num_val_samples] 
    valid_labels = train_labels[i * num_val_samples: (i + 1) *
                                num_val_samples]
partial_train_images = np.concatenate(
    [train_images[:i * num_val_samples], train_images[(i + 1) * num_val_samples:]], axis=0)
partial_train_labels = np.concatenate([train_labels[:i * num_val_samples], train_labels[(i + 1) * num_val_samples:]],axis=0)

model.fit(partial_train_images, partial_train_labels,epochs=20, 
batch_size=1, verbose=0)
val_mse, val_mae = model.evaluate(val_data, val_targets, verbose=0)
all_scores.append(val_mae)

我看过其他页面,但那里的解决方案都没有帮助。

最佳答案

您没有在数组中包含 channel 维度,对于灰度图像,它应该是具有一个元素的维度,因此每个样本都是(28, 28, 1):

partial_train_images = partial_train_images.reshape((-1, 28, 28, 1))
val_data = val_data.reshape((-1, 28, 28, 1))

关于python - Keras提示错误: (Error when checking input: expected conv2d_4_input to have 4 dimensions),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56588281/

相关文章:

python - a[0],a[1] = a[1],a[0]的过程中发生了什么?

用于处理复杂对象的方法的 Python 单元测试

python - 如何为keras提供多个数据特征的输入?

python - 将 Minibatch 标准偏差应用于 Keras GAN 层的正确方法

深度学习环境下Python包安装错误

python - 比较 Python 中的两个列表

python - 如何在 Windows cmd 上从 pip 安装 Pandas ?

python - 导入错误:无法导入名称 normalize_data_format

python - 由于保存模型导致训练崩溃 : "tensorflow.GraphDef was modified concurrently during serialization"

python - Keras 函数式 API 中的 ResNet50 网络(python)