python - 合并多个 CNN

标签 python machine-learning neural-network keras conv-neural-network

我正在尝试对模型中的多个输入执行 Conv1D。所以我有 15 个输入,每个输入大小为 1x1500,其中每个输入都是一系列层的输入。所以我有 15 个卷积模型,我想在全连接层之前合并它们。我已经在函数中定义了卷积模型,但我不明白如何调用该函数然后合并它们。

def defineModel(nkernels, nstrides, dropout, input_shape):
    model = Sequential()
    model.add(Conv1D(nkernels, nstrides, activation='relu', input_shape=input_shape))
    model.add(Conv1D(nkernels*2, nstrides, activation='relu'))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(nstrides))
    model.add(Dropout(dropout))
    return model


models = {}
for i in range(15):
    models[i] = defineModel(64,2,0.75,(64,1))

我已成功连接 4 个模型,如下所示:

merged = Concatenate()([ model1.output, model2.output, model3.output, model4.output])

merged = Dense(512, activation='relu')(merged)
merged = Dropout(0.75)(merged)
merged = Dense(1024, activation='relu')(merged)
merged = Dropout(0.75)(merged)
merged = Dense(40, activation='softmax')(merged)
model = Model(inputs=[model1.input, model2.input, model3.input, model4.input], outputs=merged)

由于单独编写 15 层效率不高,因此如何在 for 循环中实现 15 层?

最佳答案

当然,正如 @GabrielM 所建议的,使用函数式 API 是最好的方法,但是如果您不想修改 define_model 函数,您也可以这样做:

models = []
inputs = []
outputs = []
for i in range(15):
    model = defineModel(64,2,0.75,(64,1))
    models.append(model)
    inputs.append(model.input)
    outputs.append(model.output)


merged = Concatenate()(outputs) # this should be output tensors and not models

# the rest is the same ...

model = Model(inputs=inputs, outputs=merged)

关于python - 合并多个 CNN,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52848427/

相关文章:

python - 在 Python Anywhere 托管服务器中触发 OpenCV 脚本时找不到相机

machine-learning - 使用带有概率的 Vowpal wabbit 作为标签来预测概率

c++ - 关于multi-probe Local Sensitive Hashing的问题

Python scikit学习MLPClassifier "hidden_layer_sizes"

python - 为什么 [False] 中的 False==False 返回 True?

Python 字典 : Needed a better output

c++ - windows平台下visual studio C++环境下如何使用xgboost?

python - MLP分类拟合

neural-network - 什么是 NEAT(增强拓扑的神经进化)?

Python-BeautifulSoup : find td width