python-3.x - 向 Keras 中的输出层添加新节点

标签 python-3.x neural-network deep-learning keras keras-layer

我想向输出层添加新节点以稍后对其进行训练,我正在做:

def add_outputs(self, n_new_outputs):
    out = self.model.get_layer('fc8').output
    last_layer = self.model.get_layer('fc7').output
    out2 = Dense(n_new_outputs, activation='softmax', name='fc9')(last_layer)
    output = merge([out, out2], mode='concat')
    self.model = Model(input=self.model.input, output=output)

哪里'fc7'是输出层之前的全连接层'fc8' .我希望只有最后一层 out = self.model.get_layer('fc8').output但输出是所有模型。
有没有办法只从网络中取出一层?
也许还有其他更简单的方法来做到这一点......

谢谢!!!!

最佳答案

最后我找到了一个解决方案:

1) 获取最后一层的权重

2)向权重添加零并随机初始化它的连接

3)弹出输出层并新建一个

4)为新层设置新的权重

这里的代码:

 def add_outputs(self, n_new_outputs):
        #Increment the number of outputs
        self.n_outputs += n_new_outputs
        weights = self.model.get_layer('fc8').get_weights()
        #Adding new weights, weights will be 0 and the connections random
        shape = weights[0].shape[0]
        weights[1] = np.concatenate((weights[1], np.zeros(n_new_outputs)), axis=0)
        weights[0] = np.concatenate((weights[0], -0.0001 * np.random.random_sample((shape, n_new_outputs)) + 0.0001), axis=1)
        #Deleting the old output layer
        self.model.layers.pop()
        last_layer = self.model.get_layer('batchnormalization_1').output
        #New output layer
        out = Dense(self.n_outputs, activation='softmax', name='fc8')(last_layer)
        self.model = Model(input=self.model.input, output=out)
        #set weights to the layer
        self.model.get_layer('fc8').set_weights(weights)
        print(weights[0])

关于python-3.x - 向 Keras 中的输出层添加新节点,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43852058/

相关文章:

c++ - 将 OpenCV 灰度 Mat 转换为 Caffe blob

python - 机器学习: How to regularize output and force them to be away from 0?

tensorflow - 尝试在 tensorflow 中训练 mobilenet 时出现 ERROR : Config value cuda is not defined in any . rc 文件

Python 无法从文件创建变量

keras - 用于视频输入的 LSTM

Python装饰器: TypeError: function takes 1 positional argument but 2 were given

python - Keras CNN : validation accuracy stuck at 70%, 训练准确率达到 100%

python - model.fit 给出 InvalidArgumentError : Graph execution error:

python - 将枚举成员序列化为 JSON

python - chr() 和 bytes.decode 之间的区别