python - 为什么我会出现 Keras 形状不匹配的情况?

标签 python tensorflow keras

我正在遵循面向初学者的 Keras mnist 示例。我尝试更改标签以适合我自己的数据,该数据有 3 个不同的文本分类。我正在使用“to_categorical”来实现这一点。形状对我来说看起来不错,但“fit”出现错误:

train_labels = keras.utils.to_categorical(train_labels, num_classes=3)

print(train_images.shape)
print(train_labels.shape)

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(3, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, epochs=5)

(7074, 28, 28)

(7074, 3)

Blockquote Blockquote Traceback (most recent call last): File "C:/Users/lawrence/PycharmProjects/tester2019/KeraTest.py", line 131, in model.fit(train_images, train_labels, epochs=5) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1536, in fit validation_split=validation_split) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 992, in _standardize_user_data class_weight, batch_size) File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1154, in _standardize_weights exception_prefix='target') File "C:\Users\lawrence\PycharmProjects\tester2019\venv\lib\site-packages\tensorflow\python\keras\engine\training_utils.py", line 332, in standardize_input_data ' but got array with shape ' + str(data_shape)) ValueError: Error when checking target: expected dense_1 to have shape (1,) but got array with shape (3,)

最佳答案

您需要使用 categorical_crossentropy 而不是 sparse_categorical_crossentropy 作为损失,因为您的标签是热编码的。

或者,如果您不对标签进行热编码,则可以使用sparse_categorical_crossentropy。在这种情况下,标签的形状应为 (batch_size, 1)

关于python - 为什么我会出现 Keras 形状不匹配的情况?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54513759/

相关文章:

python - 如何在 Django REST 中序列化层次关系

python - 标签大小与 target_names 不同 : Tensorflow Multi-Input Regression converting to Classification

python - tensorflow - tf.data.Dataset 在批处理之前随机跳过样本以获得不同的批处理

Tensorflow Hub 与 Keras 应用程序 - 性能下降

tensorflow - Keras 前 5 名预测

python - 没有得到弹性的整数类型

python - Pandas groupby 具有模式功能

python - 深度学习: How to deal with missing label values

python - 如何使用列表推导式在 python 中打印长方体所有可能尺寸的列表?

tensorflow - 使用 TF-Slim 的全卷积 ResNet 运行速度非常慢