python - Keras top_k_categorical_accuracy 指标与准确率相比极低

标签 python tensorflow keras

我使用 Keras 的 cifar100 数据集创建了一个 CNN 模型。添加 top_k_categorical_accuracy 指标时,我应该看到前 5 个预测类别之一是正确类别时的准确性。然而,在训练时,top_k_categorical_accuracy 仍然非常小,约为 4-5%,而准确度和验证准确度则一路增加到 40-50%。 Top 5 的准确度应该比正常准确度高得多,但它却给出了非常奇怪的结果。我使用不同的 k 值编写了自己的指标,但仍然存在相同的问题。即使当我使用 k=1(应该给出相同的准确度值)时,也会出现相同的问题。

型号代码:

cnn = Sequential()
cnn.add(Conv2D(filters=200, kernel_size=2, padding='same', activation='relu', input_shape=(train_images.shape[1:])))
cnn.add(Conv2D(filters=200, kernel_size=2, padding='same', activation='relu'))
cnn.add(Conv2D(filters=200, kernel_size=2, padding='same', activation='relu'))
cnn.add(MaxPooling2D(pool_size=2, padding='same'))
cnn.add(Dropout(0.4))

cnn.add(Conv2D(filters=200, kernel_size=2, padding='same', activation='relu'))
cnn.add(Conv2D(filters=200, kernel_size=2, padding='same', activation='relu'))
cnn.add(Conv2D(filters=200, kernel_size=2, padding='same', activation='relu'))
cnn.add(Conv2D(filters=200, kernel_size=2, padding='same', activation='relu'))
cnn.add(Dropout(0.4))
cnn.add(MaxPooling2D(pool_size=2, padding='same'))
cnn.add(Dropout(0.5))

cnn.add(Flatten())
cnn.add(Dense(550, activation='relu'))
cnn.add(Dropout(0.4))
cnn.add(Dense(100, activation='softmax'))

编译代码:

cnn.compile(loss='sparse_categorical_crossentropy', optimizer=opt.Adam(lr=learn_rate), metrics=['accuracy', 'top_k_categorical_accuracy'])

最佳答案

事实证明,由于我使用的是sparse_categorical_crossentropy损失函数,所以我需要使用sparse_top_k_categorical_accuracy函数。该指标还要求您的标签被展平为一维。之后,指标正确并且模型正在训练!

关于python - Keras top_k_categorical_accuracy 指标与准确率相比极低,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52507306/

相关文章:

python - 在条件 Pandas 的新列中分配值

python - Tensorflow:如何对矩阵的每个元素执行操作

python - 了解 tf.layers.conv2d 的输入/输出张量

python - 使用 Keras 和 Hyperas 进行参数调整

python - 配置 CoreOS 时出现 Ansible pip 错误

python - 以 Pythonic 方式将 1 与 0 和 0 与 1 交换

python - 如何在 TensorFlow v2.0 中正确应用梯度

python - 无法将 TensorFlow Keras LSTM 模型保存为 SavedModel 格式

python - 如何使用 Keras 功能 API 模型的输出作为另一个模型的输入

python - 压缩嵌套列表