tensorflow - 如何在tensorflow keras中访问自定义层的递归层

标签 tensorflow keras keras-layer

来自 tensorflow keras 示例。我可以递归地创建一个包含线性层的自定义层

class MLPBlock(layers.Layer):

  def __init__(self):
    super(MLPBlock, self).__init__()
    self.linear_1 = Linear(32)
    self.linear_2 = Linear(32)
    self.linear_3 = Linear(1)

  def call(self, inputs):
    x = self.linear_1(inputs)
    x = tf.nn.relu(x)
    x = self.linear_2(x)
    x = tf.nn.relu(x)
    return self.linear_3(x)

如何访问自定义层的所有组件层,我想访问所有组件层的权重和偏差。

例如:

MLPBlock(Parent Layer):
    linear_1
    linear_2
    linear_3

我查看了 tensorflow keras api 版本 r 1.14 https://www.tensorflow.org/guide/keras 但找不到任何方法来做到这一点。

最佳答案

我假设您正在关注 this tutorial .基于此,您可以通过以下方式访问权重:

class MLPBlock(tf.keras.Model):

    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear_1 = tf.keras.layers.Dense(32)
        self.linear_2 = tf.keras.layers.Dense(32)
        self.linear_3 = tf.keras.layers.Dense(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        return self.linear_3(x)

mlp_block = MLPBlock()
y = mlp_block(tf.ones(shape=(3, 64)))
for layer in mlp_block.layers:
    weights, biases = layer.get_weights()

请注意,我稍微修改了示例,以便您可以访问图层的权重和偏差。也就是说,我所做的不是使用 tf.keras.layers.Layer 对类进行子类化,而是使用 tf.keras.Model 进行子类化,这样层的堆栈就可以被视为模型,然后您可以访问该模型的层。然后,为了简单起见,我没有使用自定义 Linear 层,而是使用了 tf.keras.layers.Dense,但是,使用自定义层应该没有什么不同。

关于tensorflow - 如何在tensorflow keras中访问自定义层的递归层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57479680/

相关文章:

python - 在 Keras 中,如何使用无维度的 Reshape 层?

tensorflow - Tensorflow 中是否有与 PyTorch 的 RandomResizedCrop 等效的功能?

python - 在 TensorFlow 中平铺变量张量是否会创建新变量?

python - 训练新的 Yolo 模型是否需要调整图像大小?

python - Keras 模型预测同一类别

tensorflow - 如何为 keras 层编写 lambda 函数,如下所示:layer1 * Layer2 = Product(layer1 * Layer2)

python - Keras LSTM TensorFlow错误: 'Shapes must be equal rank, but are 1 and 0'

python - 从 Tensorflow PrefetchDataset 中提取目标

python - 并行 LSTM 分别处理输入的不同部分

tensorflow - Keras 中的动态激活函数