pytorch - 如何正确提示Transformer模型的解码器?

标签 pytorch nlp artificial-intelligence huggingface-transformers summarization

我正在使用拥抱脸部变形金刚。我有一个预训练的编码器 + 解码器模型 (Pegasus),并且想要按照 this article 中所述对其进行微调。 .

具体来说,他们使用以下过程:

Summary generation using entity prompts

换句话说,他们在生成模型本身之前添加了手动提示。

我的问题与解码器输入有关。具体来说,我想微调模型,以便它接受提示(实体链),并从该点开始生成摘要。

例如:

<s> [ENTITYCHAIN] Frozen | Disney [SUMMARY] $tok_1 $tok_2 $tok_3 ...
=========================================== ^^^^^^ ^^^^^^ ^^^^^^
This is not generated                       Generate from here

但是,正如您所期望的,该模型正在为实体链中的每个 token 生成预测,这是我不需要的。但最重要的是,损失的计算还考虑了与实体链相关的预测。这显然破坏了训练的目的,因为它混淆了模型,因为它应该学习仅生成摘要,而不是实体链(已经作为提示给出)。

正如我所说,我想要的是给解码器一个提示(实体链),并使其生成摘要,同时能够关注提示中的额外信息。当然,损失应该只在生成的 token 中计算,不包括提示 token 。

通过调查model documentation ,我似乎没有找到执行此操作的选项。有任何想法吗? :)

最佳答案

pytorch 损失函数使用的约定是,如果在训练期间将标签设置为 -100,损失函数将忽略该标记。请参阅Documentation为了心情轻松。

这是一个最小的代码示例:

# Libraries
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from copy import deepcopy

# Get the tokenizer and the model
checkpoint = 't5-small'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Sample text
inp = 'Here is my input'
outp = 'Here is my output'

# Get token IDs
inp_ids = tokenizer(inp, return_tensors = 'pt').input_ids
outp_ids = tokenizer(outp, return_tensors = 'pt').input_ids

# Calculate loss
loss = model(input_ids = inp_ids, labels = outp_ids).loss.item()

print(loss)

# Let's set the first token to -100 and recalculate loss
modified_outp_ids = deepcopy(outp_ids)
modified_outp_ids[0][0] = -100 # the first [0] is because we only have one sequence in our batch

model_output = model(input_ids = inp_ids, labels = modified_outp_ids)

print(model_output.loss.item())

关于pytorch - 如何正确提示Transformer模型的解码器?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/76299091/

相关文章:

linux - 通过 SSH 连接到 Docker?还是 SSH 上的 docker?我需要指挥

python - 如何在训练多类 CNN 模型时存储和加载包含 5000 万个 25x25 numpy 数组的训练数据?

machine-learning - 检测 Apple 硅 GPU 核心数

点和框游戏的 Java minimax

numpy - Pytorch 内存模型 : how does "torch.from_numpy()" work?

linguistics - 理解 semcor 语料库结构 h

java - 自然语言处理——将文本特征转化为特征向量

VB.NET 以最少的步骤自定义复杂排序

artificial-intelligence - Wolfram Alpha 是如何工作的?

为给定文本获取合适图片的算法