我正在尝试使用 Tensorflow 微调来自 Huggingface 的预训练 BERT 模型。一切运行顺利,模型构建和训练没有错误。但是,当我尝试保存模型时,它会因错误“IndexError:列表索引超出范围”而停止。我正在使用带有 TPU 的 Google Colab。
如有任何帮助,我们将不胜感激!
代码:
import tensorflow as tf
from tensorflow.keras import activations, optimizers, losses
from transformers import TFBertModel
def create_model(max_sequence, model_name, num_labels):
bert_model = TFBertModel.from_pretrained(model_name)
input_ids = tf.keras.layers.Input(shape=(max_sequence,), dtype=tf.int32, name='input_ids')
attention_mask = tf.keras.layers.Input((max_sequence,), dtype=tf.int32, name='attention_mask')
output = bert_model([input_ids, attention_mask])[0]
output = output[:, 0, :]
output = tf.keras.layers.Dense(num_labels, activation='sigmoid')(output)
model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=output)
return model
with strategy.scope():
model = create_model(20, 'bert-base-uncased', 1)
opt = optimizers.Adam(learning_rate=3e-5)
loss = 'binary_crossentropy'
model.compile(optimizer=opt, loss=loss, metrics=['accuracy'])
model.fit(tfdataset_train, batch_size=32, epochs=2)
SAVE_PATH = 'path/to/save/location'
model.save(SAVE_PATH)
错误:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-22-255116b49022> in <module>()
1 SAVE_PATH = 'path/to/save/location'
----> 2 model.save(SAVE_PATH,save_format='tf')
50 frames
/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py in input_processing(func, config, input_ids, **kwargs)
372 output[tensor_name] = input
373 else:
--> 374 output[parameter_names[i]] = input
375 elif isinstance(input, allowed_types) or input is None:
376 output[parameter_names[i]] = input
IndexError: list index out of range
用形状绘制的模型: Tensorflow Model
最佳答案
解决办法是改变:
output = bert_model([input_ids, attention_mask])[0]
到
output = bert_model.bert([input_ids, attention_mask])[0]
引用:https://github.com/huggingface/transformers/issues/3627
我对您发布的解决方案投了赞成票,但后来我在训练时发现模型存在问题。它收敛不好。
关于python - 保存微调的 Tensorflow 模型时列出超出范围的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66555455/