python - 默认情况下,Keras 自定义层参数是不可训练的吗?

标签 python tensorflow keras

我在 Keras 中构建了一个简单的自定义层,惊讶地发现参数默认情况下并未设置为可训练。我可以通过显式设置可训练属性来让它工作。我无法通过查看文档或代码来解释为什么会这样。这是应该的样子还是我做错了什么导致默认情况下参数不可训练? 代码:

import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(MyDense, self).__init__(kwargs)
        self.dense = tf.keras.layers.Dense(2, tf.keras.activations.relu)

    def call(self, inputs, training=None):
        return self.dense(inputs)


inputs = tf.keras.Input(shape=10)
outputs = MyDense()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='test')
model.compile(loss=tf.keras.losses.MeanSquaredError())
model.summary()

输出:

Model: "test"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
my_dense (MyDense)           (None, 2)                 22        
=================================================================
Total params: 22
Trainable params: 0
Non-trainable params: 22
_________________________________________________________________

如果我像这样更改自定义图层创建:

outputs = MyDense(trainable=True)(inputs)

输出是我所期望的(所有参数都是可训练的):

=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________

然后它按预期工作并使所有参数都可训练。我不明白为什么需要这样做。

最佳答案

毫无疑问,这是一个有趣的怪癖。

当制作自定义层时,tf.Variable 将自动包含在 trainable_variable 列表中。您没有使用 tf.Variable,而是使用 tf.keras.layers.Dense 对象,它不会被视为 tf.Variable,并且默认不设置 trainable=True。但是,您使用的 Dense 对象将被设置为可训练的。见:

MyDense().dense.trainable
True

如果您使用了 tf.Variable(它应该如此),默认情况下它将是可训练的。

import tensorflow as tf


class MyDense(tf.keras.layers.Layer):
    def __init__(self, units=2, input_dim=10):
        super(MyDense, self).__init__()
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs, **kwargs):
        return tf.matmul(inputs, self.w) + self.b


inputs = tf.keras.Input(shape=10)
outputs = MyDense()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs, name='test')
model.compile(loss=tf.keras.losses.MeanSquaredError())
model.summary()
Model: "test"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_11 (InputLayer)        [(None, 10)]              0         
_________________________________________________________________
my_dense_18 (MyDense)        (None, 2)                 22        
=================================================================
Total params: 22
Trainable params: 22
Non-trainable params: 0
_________________________________________________________________

关于python - 默认情况下,Keras 自定义层参数是不可训练的吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65475110/

相关文章:

python - 如何使用 beautifulsoup 获取所有页面?

python - 通过从数据帧中的多列进行条件选择进行向量算术

python - cv2.remap 或 scipy.interpolate.map_coordinates 在 Tensorflow 中的等效/实现?

python - 在 Ubuntu 16.04 上使用 bazel 从源构建 tensorflow 。错误是 ---> 链接规则 '//tensorflow/contrib/lite/toco:toco' 失败(退出 1)

Python 使用 selenium 切换浏览器焦点

python - Matplotlib 在方法和日期时间之间填充

flutter - 如何在Flutter上集成自己的tflite模型?

Python|Keras : ValueError: Error when checking target: expected conv2d_3 to have 4 dimensions, 但得到了形状为 (1006, 5) 的数组

python-3.x - Tensorflow 和 Keras 中的相同(?)神经网络架构在相同数据上产生不同的准确性

c++ - tensorflow C++ 批量推理