python - 连接 Keras 模型/替换输入但保留图层

标签 python tensorflow deep-learning keras

这个问题类似于Keras replacing input layer .

我有一个分类器网络和一个自动编码器网络,我想使用自动编码器的输出(即编码 + 解码,作为预处理步骤)作为分类器的输入 - 但是在分类器已经在常规训练之后数据。

分类网络是用这样的函数式 API 构建的(基于 this example ):

clf_input = Input(shape=(28,28,1))
clf_layer = Conv2D(...)(clf_input)
clf_layer = MaxPooling2D(...)(clf_layer)
...
clf_output = Dense(num_classes, activation='softmax')(clf_layer)
model = Model(clf_input, clf_output)
model.compile(...)
model.fit(...)

像这样的自动编码器(基于 this example ):

ae_input = Input(shape=(28,28,1))
x = Conv2D(...)(ae_input)
x = MaxPooling2D(...)(x)
...
encoded = MaxPooling2D(...)(x)
x = Conv2d(...)(encoded)
x = UpSampling2D(...)(x)
...
decoded = Conv2D(...)(x)
autoencoder = Model(ae_input, decoded)
autoencoder.compile(...)
autoencoder.fit(...)

我可以像这样连接两个模型(我仍然需要原始模型,因此需要复制):

model_copy = keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())
# remove original input layer
model_copy.layers.pop(0)
# set the new input
new_clf_output = model_copy(decoded)
# get the stacked model
stacked_model = Model(ae_input, new_clf_output)
stacked_model.compile(...)

当我只想将模型应用于新的测试数据时,这非常有用,但它会给出如下错误:

for layer in stacked_model.layers:
    print layer.get_config()

它到达自动编码器的末尾,但随后在分类器模型获取其输入的位置出现 KeyError 失败。此外,当使用 keras.utils.plot_model 绘制模型时,我得到了这个:

stacked_model

您可以在其中看到自动编码器层,但在最后,一个 block 中只有完整的模型,而不是分类器模型中的各个层。

有没有办法连接两个模型,这样新的堆叠模型实际上由所有单独的层组成?

最佳答案

好吧,我能想到的是真的手动遍历模型的每一层,然后像这样一个接一个地重新连接它们:

l = model.layers[1](decoded)  # layer 0 is the input layer, which we're replacing
for i in range(2, len(model.layers)):
    l = model.layers[i](l)
stacked_model = Model(ae_input, l)
stacked_model.compile(...)

虽然这有效并产生了正确的绘图并且没有错误,但这似乎不是最优雅的解决方案......

(顺便说一句,模型的复制实际上似乎是不必要的,因为我没有重新训练任何东西。)

关于python - 连接 Keras 模型/替换输入但保留图层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50183809/

相关文章:

python - cv2.VideoCapture.read() 在 time.sleep() 之后获取旧帧

python - 如何在 Django 中动态组合 OR 查询过滤器?

python - 在 Tensorflow 2.1 中转换后无法加载 Tensor RT SavedModel

python - 为特定领域微调 Bert(无监督)

python-3.x - Pytorch autograd.grad 如何写多个输出的参数?

python - 给定一个列表,如何计算该列表中的项目?

python - Tensorflow RNN 多对多二进制标签的时间序列

tensorflow - 在反向传递中调试nans

python - Keras Predict_proba 中的神经网络始终返回等于 1 的概率

python - 如何修复 : RuntimeError: size mismatch in pyTorch