python - 无法更改现有 Keras 模型中的激活

标签 python keras keras-layer

我有一个带有 relu 激活的普通 VGG16 模型,即

def VGG_16(weights_path=None):
    model = Sequential()
    model.add(ZeroPadding2D((1, 1),input_shape=(3, 224, 224)))
    model.add(Convolution2D(64, 3, 3, activation='relu'))
    model.add(ZeroPadding2D((1, 1)))
    model.add(Convolution2D(64, 3, 3, activation='relu'))
    model.add(MaxPooling2D((2, 2), strides=(2, 2)))
[...]
    model.add(Flatten())
    model.add(Dense(4096, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(4096, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1000, activation='softmax'))

    if weights_path:
        model.load_weights(weights_path)

    return model

我用现有的权重实例化它,现在想将所有 relu 激活更改为 softmax(我知道没有用)

model = VGG_16('vgg16_weights.h5')
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)

softmax_act = keras.activations.softmax
for (n, layer) in enumerate(model.layers):
    if 'activation' in layer.get_config() and layer.get_config()['activation'] == 'relu':
        print('replacing #{}: {}, {}'.format(n, layer, layer.activation))
        layer.activation = softmax_act
        print('-> {}'.format(layer.activation))

model.compile(optimizer=sgd, loss='categorical_crossentropy')

注意:model.compile 在更改后 被调用,所以我猜模型应该仍然可以修改。

然而,即使调试打印正确地说

replacing #1: <keras.layers.convolutional.Convolution2D object at 0x7f7d7c497f50>, <function relu at 0x7f7dbe699a28>
-> <function softmax at 0x7f7d7c4972d0>
[...]

实际结果与使用 relu 激活的模型相同。
为什么 Keras 不使用更改后的激活函数?

最佳答案

你可能想使用 apply_modifications

idx_of_layer_to_change = -1
model.layers[idx_of_layer_to_change].activation = activations.softmax
model = utils.apply_modifications(model)

关于python - 无法更改现有 Keras 模型中的激活,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43030721/

相关文章:

python - 喀拉斯 2 : Using lambda function in "Merge" layers

python-3.x - 当训练样本增加时,准确率降低

java - 是否有相当于Python的Python的itertools?

python - 即使 Anaconda 说它已安装,导入 Numpy 也会导致错误?

python - Keras - 整个训练过程中损失 Nan 和 0.333 准确度

tensorflow - 为什么我的混淆矩阵 "shifted"在右边?

python - 维数错误 : expected 3, 得到 2 形状 (119, 80)

keras - Keras 中 LSTM 的数学公式?

python - 是否可以组合两个范围来创建字典?

python - 如何过滤数据使用等于或大于 url 中的条件?