machine-learning - Keras - Unable to get correct class predictions

Tags: machine-learning tensorflow neural-network keras conv-neural-network

I built a CNN with Keras to classify images into 2 different categories. The problem I'm having is that after training I can't seem to get correct predictions.

Some background... the dataset has 78,750 examples (roughly 95% Cat. 1 and 5% Cat. 2), which is probably the culprit, since I assume the model is overfitting to Cat. 1 (I think this is the problem, but changing the dataset size is difficult for many other reasons).

To combat this, I added regularization to every convolutional layer, but it had no effect.

My question is... do I absolutely need to change my class sizes, or is there something else I can do to combat the overfitting to Cat. 1?

Here is the code for the CNN:

# Imports needed to make this snippet self-contained (Keras 2.x API,
# matching the lowercase sgd/random_normal aliases used below).
import keras
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.initializers import random_normal
from keras.regularizers import l2

# input_shape is defined elsewhere in the original script; from the
# model.summary() below it works out to (81, 81, 1).
model = Sequential()
model.add(Conv2D(filters=25,
                 kernel_size=(10, 10),
                 strides=(1, 1),
                 activation='relu',
                 input_shape=input_shape,
                 padding="VALID",
                 kernel_initializer=random_normal(mean=0, stddev=.1),
                 kernel_regularizer=l2(.001)))
model.add(MaxPooling2D(pool_size=(2, 2),
                       strides=(2, 2)))

model.add(Conv2D(filters=25,
                 kernel_size=(7, 7),
                 strides=(1, 1),
                 activation='relu',
                 padding="VALID",
                 kernel_initializer=random_normal(mean=0, stddev=.1),
                 kernel_regularizer=l2(.001)))
model.add(MaxPooling2D(pool_size=(2, 2),
                       strides=(2, 2)))

model.add(Conv2D(filters=25,
                 kernel_size=(5, 5),
                 strides=(2, 2),
                 activation='relu',
                 padding="VALID",
                 kernel_initializer=random_normal(mean=0, stddev=.1),
                 kernel_regularizer=l2(.001)))
model.add(MaxPooling2D(pool_size=(2, 2),
                       strides=(1, 1)))

model.add(Conv2D(filters=25,
                 kernel_size=(5, 5),
                 strides=(2, 2),
                 activation='relu',
                 padding="VALID",
                 kernel_initializer=random_normal(mean=0, stddev=.1),
                 kernel_regularizer=l2(.001)))

model.add(Flatten())
model.add(Dense(2, activation='relu', kernel_initializer=random_normal(mean=0, stddev=.1), kernel_regularizer=l2(.001)))
model.add(Dense(2, activation='softmax'))

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.sgd(lr=.001, momentum=0.9),
              metrics=['accuracy'])
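
A quick diagnostic (not part of the original post): a confusion matrix over held-out data shows what the model actually predicts per class. A minimal sketch, assuming x_val and one-hot y_val arrays matching the categorical_crossentropy setup above:

import numpy as np
from sklearn.metrics import confusion_matrix

# Convert softmax outputs and one-hot labels back to class indices.
y_pred = np.argmax(model.predict(x_val), axis=1)
y_true = np.argmax(y_val, axis=1)

# Rows = true class, columns = predicted class. If the model has
# collapsed onto Cat. 1, the Cat. 2 prediction column will be ~empty.
print(confusion_matrix(y_true, y_pred))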

Edit 1

Here is the output from running 1 epoch of training...

Epoch 1/2
  500/78750 [..............................] - ETA: 664s - loss: 1.3999 - acc: 0.9460
 1000/78750 [..............................] - ETA: 652s - loss: 1.3713 - acc: 0.9500
 1500/78750 [..............................] - ETA: 648s - loss: 1.3897 - acc: 0.9460
 2000/78750 [..............................] - ETA: 648s - loss: 1.3970 - acc: 0.9420
...
20000/78750 [======>.......................] - ETA: 488s - loss: 1.3533 - acc: 0.9501
...
40000/78750 [==============>...............] - ETA: 321s - loss: 1.3448 - acc: 0.9508
...
60000/78750 [=====================>........] - ETA: 155s - loss: 1.3416 - acc: 0.9511
...
78000/78750 [============================>.] - ETA: 6s - loss: 1.3398 - acc: 0.9518
78500/78750 [============================>.] - ETA: 2s - loss: 1.3398 - acc: 0.9518
78750/78750 [==============================] - 855s - loss: 1.3397 - acc: 0.9518 - val_loss: 1.3321 - val_acc: 0.9523

And here is the model.summary()...

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 72, 72, 25)        2525      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 36, 36, 25)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 30, 30, 25)        30650     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 15, 15, 25)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 6, 6, 25)          15650     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 5, 5, 25)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 1, 1, 25)          15650     
_________________________________________________________________
flatten_1 (Flatten)          (None, 25)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 52        
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 6         
=================================================================
Total params: 64,533
Trainable params: 64,533
Non-trainable params: 0
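
As a sanity check on the shapes: conv2d_1 has 25 × (10 × 10 × 1 + 1) = 2,525 parameters, which implies single-channel inputs, and its 72 × 72 output (81 − 10 + 1 = 72) implies 81 × 81 images, i.e. input_shape = (81, 81, 1).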

Best answer

Your dataset is highly imbalanced, so the model treats the second class as noise and classifies everything as Cat. 1. Note that the ~95% training and validation accuracy above is exactly what always predicting Cat. 1 would score on a 95/5 split. The simplest way to balance the dataset is to oversample the examples of the second class, so that the model sees Cat. 2 more often.
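
A minimal sketch of what that oversampling could look like (illustrative only, not from the answer; x_train and integer labels y_train with Cat. 2 encoded as 1 are assumptions):

import numpy as np

def oversample(x, y, minority_label=1, seed=0):
    # Duplicate minority-class examples (sampling with replacement)
    # until the two classes are the same size, then reshuffle.
    rng = np.random.RandomState(seed)
    minority_idx = np.where(y == minority_label)[0]
    majority_idx = np.where(y != minority_label)[0]
    extra = rng.choice(minority_idx, size=len(majority_idx), replace=True)
    idx = np.concatenate([majority_idx, extra])
    rng.shuffle(idx)
    return x[idx], y[idx]

x_bal, y_bal = oversample(x_train, y_train)

An alternative that avoids duplicating data is the class_weight argument of model.fit, e.g. class_weight={0: 1.0, 1: 19.0} for a 95/5 split, which up-weights the loss on the rare class instead.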

This should fix the class-output problem, but such a model may generalize less well. To improve generalization, you can try data augmentation: applying random transformations to the training images.
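
A sketch of data augmentation with Keras's ImageDataGenerator (the transform parameters here are illustrative assumptions, and x_bal/y_bal are the balanced arrays from the sketch above):

import keras
from keras.preprocessing.image import ImageDataGenerator

# Random transformations are applied on the fly to every training batch.
datagen = ImageDataGenerator(rotation_range=15,
                             width_shift_range=0.1,
                             height_shift_range=0.1,
                             horizontal_flip=True)

# Keras 2 generator training; labels are one-hot to match
# categorical_crossentropy.
model.fit_generator(datagen.flow(x_bal, keras.utils.to_categorical(y_bal, 2),
                                 batch_size=32),
                    steps_per_epoch=len(x_bal) // 32,
                    epochs=10)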

Regarding machine-learning - Keras - Unable to get correct class predictions, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/46085413/
