tensorflow - 微调 BERT 的最后 x 层

标签 tensorflow module google-colaboratory embedding

我正在尝试仅在特定的最后一层(假设最后 3 层)上微调 BERT。我想使用 Google Colab 进行 TPU 训练。我正在使用hub.Module加载 BERT 并对其进行微调,然后将微调后的输出用于我的分类任务。

bert_module = hub.Module(BERT_MODEL_HUB,tags=tags,trainable=True)

hub.Module 可以选择将模型设置为可训练或不可训练,但不能部分训练(仅特定层)

有人知道如何使用 hub.Module 训练 BERT 的最后 1,2 或 3 层吗?

谢谢

最佳答案

您可以在可训练变量列表中手动设置它。以下是我在tensorflow-keras中对Bert层的实现-

class BertLayer(tf.layers.Layer):
 def __init__(
    self,
    n_fine_tune_layers=10,
    pooling="first",
    bert_path="https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
    **kwargs,
):
    self.n_fine_tune_layers = n_fine_tune_layers
    self.trainable = True
    self.output_size = 768
    self.pooling = pooling
    self.bert_path = bert_path
    if self.pooling not in ["first", "mean"]:
        raise NameError(
            f"Undefined pooling type (must be either first or mean, but is {self.pooling}"
        )

    super(BertLayer, self).__init__(**kwargs)

 def build(self, input_shape):
    self.bert = hub.Module(
        self.bert_path, trainable=self.trainable, name=f"{self.name}_module"
    )

    # Remove unused layers
    trainable_vars = self.bert.variables
    if self.pooling == "first":
        trainable_vars = [var for var in trainable_vars if not "/cls/" in var.name]
        trainable_layers = ["pooler/dense"]

    elif self.pooling == "mean":
        trainable_vars = [
            var
            for var in trainable_vars
            if not "/cls/" in var.name and not "/pooler/" in var.name
        ]
        trainable_layers = []
    else:
        raise NameError(
            f"Undefined pooling type (must be either first or mean, but is {self.pooling}"
        )

    # Select how many layers to fine tune
    for i in range(self.n_fine_tune_layers):
        trainable_layers.append(f"encoder/layer_{str(11 - i)}")

    # Update trainable vars to contain only the specified layers
    trainable_vars = [
        var
        for var in trainable_vars
        if any([l in var.name for l in trainable_layers])
    ]

    # Add to trainable weights
    for var in trainable_vars:
        self._trainable_weights.append(var)

    for var in self.bert.variables:
        if var not in self._trainable_weights:
            self._non_trainable_weights.append(var)

    super(BertLayer, self).build(input_shape)

  def call(self, inputs):
    inputs = [K.cast(x, dtype="int32") for x in inputs]
    input_ids, input_mask, segment_ids = inputs
    bert_inputs = dict(
        input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
    )
    if self.pooling == "first":
        pooled = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
            "pooled_output"
        ]
    elif self.pooling == "mean":
        result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
            "sequence_output"
        ]

        mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
        masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)
        input_mask = tf.cast(input_mask, tf.float32)
        pooled = masked_reduce_mean(result, input_mask)
    else:
        raise NameError(f"Undefined pooling type (must be either first or mean, but is {self.pooling}")

    return pooled

def compute_output_shape(self, input_shape):
    return (input_shape[0], self.output_size)

重点关注上面代码中的以下行 -

trainable_layers.append(f"encoder/layer_{str(11 - i)}")

您可以将 n_fine_tune_layers 参数默认设置为 1/2/3,或者在声明图层时传递它 -

def __init__(self, n_fine_tune_layers=2, **kwargs):

关于tensorflow - 微调 BERT 的最后 x 层,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56028464/

相关文章:

python - 关于 tf.nn.softmax_cross_entropy_with_logits_v2

python - 从 python 中删除导入的模块

javascript - node.js 中不同的 module.exports 模式

JavaScript 模块

r - 如何在colab中使用R从谷歌驱动器读取数据?

python - TypeError 使用 GradientTape.gradient 计算梯度

tensorflow - 如何使用多个model_spec创建单个PredictRequest()?

python - 我如何知道哪个预测针对哪个数据?那么,如何评估预测呢?

audio - 如何在google colab中显示多个音频文件?

python - 同时枚举2个Python字典的值?