我正在训练一个卷积自动编码器,但我无法减少损失,希望有人能指出一些可能的改进。
我有 1024x1024 的灰度图像(我也用 512x512 尝试过同样的事情),我希望压缩它们以进行无监督聚类。下面是我的完整模型,但它遵循一个非常基本的模式,即几个带最大池化的 Conv2D 层,然后是一个密集层,然后 reshape 和 Conv2D 层恢复到原始图像大小。
到目前为止我尝试过的一些事情:
1) 我发现 mse 作为损失函数比二元交叉熵效果更好,因为像素亮度值远非均匀分布(二元交叉熵被卡住,将所有值都分配给 1,误差很小但没用)。
2) 如果我只是去掉中间的致密层并稍微压缩图像,我可以轻松地实现非常低的错误和近乎完美的(至少在我看来)图像重建。这是相当明显的,但我想这表明我没有犯某种导致无意义输出的错误。
3) 我的损失并没有真正低于 0.02-0.03。尽管如此,在 0.025 左右,图像已经足够重建,很明显输出来自输入,而不是某种随机噪声(比如让每个像素都具有相同的强度或其他东西)。我认为让它低于 0.01 就足以让我聚类。我的最低值(尽管在我的数据的一个稍微简单的子集上)是 0.018,当我在热图中绘制编码值时,我可以看到样本中有明显的聚类。
4) 当我的中间密集层使用 ReLU 激活时,我得到了很多垂死的 ReLU,这使得它对最终的聚类不太有用。我改用 tanh。我还发现“he_normal”作为密集层的初始化效果更好。
5) 在中间添加更密集的层似乎根本没有帮助。
6) 反转编码器的形状(使其从每层更少的内核变为更多内核)也无济于事,即使我知道传统上卷积自动编码器的外观也是如此。
这是完整的模型(model.summary() 的输出
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_4 (InputLayer) (None, 1024, 1024, 1) 0
_________________________________________________________________
conv2d_40 (Conv2D) (None, 1024, 1024, 128) 1280
_________________________________________________________________
max_pooling2d_19 (MaxPooling (None, 512, 512, 128) 0
_________________________________________________________________
batch_normalization_4 (Batch (None, 512, 512, 128) 512
_________________________________________________________________
conv2d_41 (Conv2D) (None, 512, 512, 64) 73792
_________________________________________________________________
max_pooling2d_20 (MaxPooling (None, 256, 256, 64) 0
_________________________________________________________________
conv2d_42 (Conv2D) (None, 256, 256, 32) 18464
_________________________________________________________________
max_pooling2d_21 (MaxPooling (None, 128, 128, 32) 0
_________________________________________________________________
conv2d_43 (Conv2D) (None, 128, 128, 16) 4624
_________________________________________________________________
max_pooling2d_22 (MaxPooling (None, 64, 64, 16) 0
_________________________________________________________________
conv2d_44 (Conv2D) (None, 64, 64, 8) 1160
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 32, 32, 8) 0
_________________________________________________________________
flatten_4 (Flatten) (None, 8192) 0
_________________________________________________________________
dense_5 (Dense) (None, 512) 4194816
_________________________________________________________________
reshape_4 (Reshape) (None, 8, 8, 8) 0
_________________________________________________________________
up_sampling2d_22 (UpSampling (None, 16, 16, 8) 0
_________________________________________________________________
conv2d_45 (Conv2D) (None, 16, 16, 16) 1168
_________________________________________________________________
up_sampling2d_23 (UpSampling (None, 32, 32, 16) 0
_________________________________________________________________
conv2d_46 (Conv2D) (None, 32, 32, 16) 2320
_________________________________________________________________
up_sampling2d_24 (UpSampling (None, 64, 64, 16) 0
_________________________________________________________________
conv2d_47 (Conv2D) (None, 64, 64, 32) 4640
_________________________________________________________________
up_sampling2d_25 (UpSampling (None, 128, 128, 32) 0
_________________________________________________________________
conv2d_48 (Conv2D) (None, 128, 128, 64) 18496
_________________________________________________________________
up_sampling2d_26 (UpSampling (None, 256, 256, 64) 0
_________________________________________________________________
conv2d_49 (Conv2D) (None, 256, 256, 128) 73856
_________________________________________________________________
up_sampling2d_27 (UpSampling (None, 512, 512, 128) 0
_________________________________________________________________
conv2d_50 (Conv2D) (None, 512, 512, 128) 147584
_________________________________________________________________
up_sampling2d_28 (UpSampling (None, 1024, 1024, 128) 0
_________________________________________________________________
conv2d_51 (Conv2D) (None, 1024, 1024, 1) 1153
=================================================================
Total params: 4,543,865
Trainable params: 4,543,609
Non-trainable params: 256
最佳答案
您的损失函数可能是问题所在。在网络的 Logit 输出上使用 BCE。应该可以解决问题。
- 使用:
tf.keras.losses.BinaryCrossentropy(from_logits=True)
- 从编码器和解码器的最后一层中删除激活函数(编码器的最后一个密集层和解码器的最后一个 Conv 层应该没有激活。)
注意:当您从嵌入重构时,向其添加一个 sigmoid 函数。
z = encoder(x)
x_hat_raw = decoder(z)
reconstruction = sigmoid(x_hat_raw)
现在应该好好训练了吧!
关于keras - 卷积自动编码器keras的高损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58088560/