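I have pretrained weights stored as an np.array of shape (3, 3, 3, 64). I want to use set_weights() to initialize the TensorFlow CNN shown below with these weights.
However, when I try to do so, the following error is raised: ValueError: You called `set_weights(weights)` on layer "conv2d_3" with a weight list of length 3, but the layer was expecting 2 weights. Provided weights: [[[[-0.15836713 -0.178757 0.16782044 ...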
import numpy as np
from tensorflow.keras import layers, models, optimizers

model = models.Sequential()
model.add(layers.Conv2D(64, (3, 3), activation='relu', input_shape=(224, 224, 3)))
model.layers[0].set_weights(weights)
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(256, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(256, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(512, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(512, (3, 3), activation='relu'))
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(4, activation='softmax'))
print(model.summary())
adam = optimizers.Adam(learning_rate=0.0001, amsgrad=False)
model.compile(loss='categorical_crossentropy',
              optimizer=adam,
              metrics=['accuracy'])
history = model.fit_generator(
    train_generator,
    steps_per_epoch=np.ceil(nb_train_samples / batch_size),
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=np.ceil(nb_validation_samples / batch_size),
    class_weight=class_weight
)
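My question is: how can I pass these (3, 3, 3, 64)-shaped weights to initialize the CNN? I have checked the weight shape each layer requires against the shape I am trying to pass, and they match.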
Best answer
You can use the kernel_initializer and bias_initializer arguments, like this:
import numpy as np
import tensorflow as tf

# init_kernel and init_bias stand in for the initial weights that you already have
init_kernel = np.random.normal(0, 1, (3, 3, 3, 64))
init_bias = np.zeros((64,))
# Wrap the arrays in constant initializers
kernel_initializer = tf.keras.initializers.Constant(init_kernel)
bias_initializer = tf.keras.initializers.Constant(init_bias)
conv_layer = tf.keras.layers.Conv2D(64, (3, 3),
                                    activation='relu',
                                    input_shape=(224, 224, 3),
                                    kernel_initializer=kernel_initializer,
                                    bias_initializer=bias_initializer)
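Note the kernel and bias shapes I chose. The values used to initialize the layer must have exactly the same shapes as the layer's kernel (3, 3, 3, 64) and bias (64,).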
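As for the error itself: set_weights() expects a list containing one array per layer weight, and a Conv2D layer with a bias has two of them, the kernel and the bias. A bare array of shape (3, 3, 3, 64) appears to be read as a sequence along its first axis, which would explain the reported "length 3". If you would rather keep using set_weights(), a minimal sketch might look like the following. The names pretrained_kernel and pretrained_bias are placeholders, the kernel values are random stand-ins for your real weights, and the zero bias is an assumption for the case where no pretrained bias is available.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the pretrained kernel from the question (random placeholder values).
pretrained_kernel = np.random.normal(0, 1, (3, 3, 3, 64)).astype("float32")
# Assumption: no pretrained bias exists, so the bias starts at zero.
pretrained_bias = np.zeros((64,), dtype="float32")

# A bare array counts as a sequence along its first axis.
print(len(pretrained_kernel))  # 3

# Keep a reference to the layer so it does not have to be looked up by index.
conv = tf.keras.layers.Conv2D(64, (3, 3), activation="relu")
model = tf.keras.Sequential([tf.keras.Input(shape=(224, 224, 3)), conv])

# Pass a list with one array per weight, in the order [kernel, bias].
conv.set_weights([pretrained_kernel, pretrained_bias])

# Read the weights back to confirm they were stored.
kernel, bias = conv.get_weights()
print(kernel.shape, bias.shape)
```

The layer has to be built before set_weights() is called, which is why the sketch declares the input shape first. The remaining layers can be added to the model afterwards as in the question.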
Regarding python - set_weights() in a TensorFlow model, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/61373229/