python - 有哪些方法可以使用不同维度的图像集作为预训练模型的输入?

标签 python tensorflow keras conv-neural-network

我正在处理手写数字识别问题,使用 OpenCV 进行预处理,使用 Keras/Tensorflow 进行推理。我在 MNIST 手写数字数据集上训练了一个模型,其中每个图像都是 28x28 像素。现在我正在处理一组新的数字,我计划使用原始模型架构进行进一步训练,并通过权重初始化进行迁移学习。

所以这是我的问题 :当我缩小到 28x28 像素时,我遇到了丢失某些功能的问题。这是一个例子

enter image description here

这意味着是 2,顶部循环中的微小间隙对于帮助区分 9 或 8 很重要。但我的预处理版本丢失了间隙,因此循环看起来是封闭的。

我已发布 another question关于如何在不丢失功能的情况下进行缩小。另一方面,也许我想缩小到 更大的尺寸,如 56x56 像素 我不太可能失去这些功能。 我该如何设置,使这个新大小与模型融合在一起,而不会使预先训练的权重变得无用?

这是预训练模型的定义:

def define_model(learning_rate, momentum):
    model = Sequential()
    model.add(Conv2D(32, (3,3), activation = 'relu', kernel_initializer = 'he_uniform', input_shape=(28,28,1)))
    model.add(MaxPooling2D((2,2)))
    model.add(Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_uniform'))
    model.add(Conv2D(64, (3,3), activation = 'relu', kernel_initializer = 'he_uniform'))
    model.add(MaxPooling2D((2,2)))
    model.add(Flatten())
    model.add(Dense(100, activation='relu', kernel_initializer='he_uniform'))
    model.add(Dense(10, activation='softmax'))
    opt = SGD(lr=learning_rate, momentum=momentum)
    model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
    return model

这是我的一个想法:在第一层之后增加最大池内核的大小,使该层的输出具有与使用 28x28 像素图像相同的形状。 (但这不会导致我失去该功能吗?)

最佳答案

为什么不升级 MNST 进行培训?你的问题是关于图像的分辨率,MNST 数据集是在很久以前 GPU 内存还很小的时候创建的。最近的模型都具有大于 200 * 200 的图像尺寸例如,resnet 使用 224*224作为输入形状。由于您的图像从一开始就已经是低分辨率的,因此您缩小比例会使模型难以相互区分。由于您的模型相当简单,我建议升级训练数据集。

是的,如果你使用你提到的池化,你可能也会丢失信息。

希望这可以帮助。

关于python - 有哪些方法可以使用不同维度的图像集作为预训练模型的输入?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59841602/

相关文章:

Python3 : How to use set_description with tqdm. contrib.concurrent process_map

python3 tkinter 网格和打包,内联打包语法和优雅

Tensorflow 输入管道

python - Keras:微调 Inception 时精度下降

python - 关于Keras中基于官方文档的Embeddings的输入维度的问题

python - 使用嵌套列表索引列表

python - 如何通过进程 ID 获取进程的标准输入?

python - Tensorflow Keras 输入形状为 : [? ,1,1,32 的 'average_pooling2d' 从 1 中减去 2 导致的负维度大小]

tensorflow 2 TextVectorization过程张量和数据集错误

python - 将 4 channel RGB-D 图像输入 LSTM