python - keras CNN 模型预测良好,但只有一个标签无法预测

标签 python machine-learning keras conv-neural-network

在我训练模型以使用每个类别的 2800 个图像来预测 24 个类别的标签并获取 5000 个图像进行验证后,我运行了一些测试来查看标签预测的质量,我设计了一个程序来获取所有图像在文件夹测试和预测标签中,除了第 19 类之外,所有类都很好,在 1000 个用于测试的图像中,没有预测为 19

有人有解决办法吗?

这是模型架构:

model = Sequential()

model.add(Conv2D(filters=32, kernel_size=2,padding='same',activation='relu',input_shape=(32,32,1)))

model.add(MaxPooling2D(pool_size=2))

model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))

model.add(MaxPooling2D(pool_size=2))

model.add(Flatten())

model.add(Dense(1024, activation='relu'))

model.add(Dropout(0.2))

model.add(Dense(24, activation='softmax'))

model.summary()

这是优化器和训练器:

optimizer = rmsprop(learning_rate=0.0001)

model.compile(loss='categorical_crossentropy',optimizer= optimizer,metrics=['accuracy'])

checkpointer = ModelCheckpoint(filepath='CNN_newData.hdf5',verbose=1,
                               save_best_only=True)
hist = model.fit(x_train,y_train,batch_size=128,epochs=100,
                 validation_data=(x_valid,y_valid),callbacks=[checkpointer],
                 verbose=2,shuffle=True)

这是为预测准备图像的方式:

  for img in images:

            read_img = cv2.imread('test-images/' + file + '/' + img)
            read_img = cv2.cvtColor(read_img,cv2.COLOR_RGB2GRAY)
            read_img = read_img.reshape( -1,32, 32, 1)
            read_img = read_img.astype('float32')/255
            maxind = model.predict_classes(read_img)

最佳答案

  • 我认为您的数据集是均衡的?
  • 您能否上传您的损失/准确率曲线?
  • 您是否尝试过其他优化器?您的 RMSprop learning_rate 低于默认值,并且网络相当浅。
  • 您可以分享这些数据吗?您至少确定那里没有矛盾的知识吗?

read_img = cv2.imread('test-images/' + file + '/' + img)

请不要自行连接路径。一旦您将此模型推送到某个基于 Linux 的云,您就会遇到麻烦。检查pathlib .

  • 尝试使用默认参数运行 adam
  • model.add(Dense(1024,activation='relu')) - 它相当大,后面只有 24 个标签。尝试更小的值,例如 240。
  • kernel_size 看起来很奇怪,转换内核不应该很奇怪吗?尝试kernel_size=3
  • 尝试添加一些正则化

关于python - keras CNN 模型预测良好,但只有一个标签无法预测,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58910748/

相关文章:

python - 转换列表中的二进制文件(python)

python - 访问和修改 Maya 撤消队列

apache-spark - 在 pyspark 中使用 tfidf 向量的所有对相似度

machine-learning - weka 机器学习教程

caching - 不同 Vowpal Wabbit 版本缓存的数据集

python - 如何在 Keras 中从 HDF5 文件加载模型?

python - "TypeError: Using a ` tf.Tensor ` as a Python ` bool ` is not allowed."在数据集上调用map函数时

python - jEdit 上下划线字符消失

python - 重新排序 Pandas 数据框中的公共(public)列而不是其他列

neural-network - keras 损失在新纪元开始时随机跳到零