python - 如何在 Keras 中向 ResNet50 添加顶层密集层?

标签 python deep-learning keras

我在这里阅读了这个非常有用的关于迁移学习的 Keras 教程:

https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html

我认为这可能非常适用于这里的鱼类数据,并开始沿着这条路走下去。我尽量按照教程进行操作。代码一团糟,因为我只是想弄清楚一切是如何工作的,但可以在这里找到它:

https://github.com/MrChristophRivera/ClassifiyingFish/blob/master/notebooks/Anthony/Resnet50%2BTransfer%20Learning%20Attempt.ipynb

为简洁起见,以下是我在此处执行的步骤:

model = ResNet50(top_layer = False, weights="imagenet"
# I would resize the image to that of the standard input size of ResNet50.
datagen=ImageDataGenerator(1./255)
generator = datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=32,
    class_mode=None,
    shuffle=False)
# predict on the training data
bottleneck_features_train = model.predict_generator(generator, 
nb_train_samples)
print(bottleneck_features_train)
file_name = join(save_directory, 'tbottleneck_features_train.npy')
np.save(open(file_name, 'wb'), bottleneck_features_train)
# Then I would use this output to feed my top layer and train it. Let's 
say I defined 
# it like so:
top_model = Sequential()
# Skipping some layers for brevity
top_model.add(Dense(8,  activation='relu')
top_model.fit(train_data, train_labels)
top_model.save_weights(top_model_weights_path).

此时,我已经保存了权重。下一步是将顶层添加到 ResNet50。本教程只是这样做的:

# VGG16 model defined via Sequential is called bottom_model.
bottom_model.add(top_model)

问题是当我尝试这样做时失败了,因为“模型没有添加属性”。我的猜测是 ResNet50 是以不同的方式定义的。无论如何,我的问题是:如何将这个带有加载重量的顶级模型添加到底部模型?任何人都可以提供有用的指示吗?

最佳答案

尝试:

input_to_model = Input(shape=shape_of_your_image)
base_model = model(input_to_model)
top_model = Flatten()(base_model)
top_model = Dense(8,  activation='relu')
...

您的问题来自于 Resn​​et50 是在所谓的 functional API 中定义的。 .我还建议您使用不同的激活函数,因为将 relu 作为输出激活可能会导致问题。此外 - 您的模型未编译。

关于python - 如何在 Keras 中向 ResNet50 添加顶层密集层?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42788342/

相关文章:

python - 人工神经网络 - 编译错误

neural-network - 为什么在 Keras 中 CNN 的训练速度比完全连接的 MLP 慢?

python - 如何使用 python 脚本修改 Vim 缓冲区?

python - 搜索 SQLite 数据库数千次的最快方法?

machine-learning - 在小型图像数据集上训练 GAN

python - 如何批量训练具有多个输入的模型?

python - 如何从嵌入层获取输出

python - 如何使用 `pandas.cut()` 根据被分箱列以外的列分箱数据?

Python正则表达式在字符串中的双引号中查找字符串

tensorflow - 具有多输入 KerasClassifier 的 Sklearn cross_val_score