pytorch - 如何理解mbart中的decoder_start_token_id和forced_bos_token_id?

标签 pytorch multilingual huggingface-transformers

当我想使用huggingface的预训练模型(例如mbart)进行多语言实验时,参数decoder_start_token_idforced_bos_token_id的含义让我感到困惑。我发现这样的代码:

# While generating the target text set the decoder_start_token_id to the target language id. 
# The following example shows how to translate English to Romanian 
# using the facebook/mbart-large-en-ro model.
from transformers import MBartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-en-ro", src_lang="en_XX")
article = "UN Chief Says There Is No Military Solution in Syria"
inputs = tokenizer(article, return_tensors="pt")
translated_tokens = model.generate(**inputs, decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"])
tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

和:

# To generate using the mBART-50 multilingual translation models, 
# eos_token_id is used as the decoder_start_token_id and the target language id is forced as the first generated token. 
# To force the target language id as the first generated token, 
# pass the forced_bos_token_id parameter to the generate method. 
# The following example shows how to translate between Hindi to French and Arabic to English 
# using the facebook/mbart-50-large-many-to-many checkpoint.
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
article_ar = "الأمين العام للأمم المتحدة يقول إنه لا يوجد حل عسكري في سوريا."

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

# translate Hindi to French
tokenizer.src_lang = "hi_IN"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria."

# translate Arabic to English
tokenizer.src_lang = "ar_AR"
encoded_ar = tokenizer(article_ar, return_tensors="pt")
generated_tokens = model.generate(**encoded_ar, forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# => "The Secretary-General of the United Nations says there is no military solution in Syria."

这两个参数的注释是:

decoder_start_token_id (:obj:`int`, `optional`): 
If an encoder-decoder model starts decoding with a different token than `bos`, 
the id of that token.

forced_bos_token_id (:obj:`int`, `optional`): 
The id of the token to force as the first generated token after the :obj:`decoder_start_token_id`.
Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where 
the first generated token needs to be the target language token.

对于 mbart 的不同变体,例如 facebook/mbart-large-cc25facebook/mbart-large-50,我们应该指定哪一个来生成响应特定语言?

最佳答案

在标准的序列到序列模型中,解码从向解码器提供[bos]符号开始,它生成单词w1,它被提供作为下一步解码器的输入。解码器生成单词w2。这一直持续到生成 [eos](句子结束)标记。

[bos] w_1  w_2  w_3
  ↓    ↓    ↓    ↓
┌──────────────────┐
│     DECODER      │
└──────────────────┘
  ↓    ↓    ↓    ↓
 w_1  w_2  w_2 [eos]

对于 mBART,这更加棘手,因为您需要告诉它目标语言和源语言是什么。对于编码器和训练数据,分词器负责处理这一问题,并在源句子的末尾和目标句子的开头添加特定于语言的标签。然后,句子的格式如下(假设源有 4 个单词,目标有 3 个单词):

  • 来源:v1 v2 v3 v4 [src_lng]
  • 目标:[tgt_lng] w1 w2 w3 [eos]

与训练不同,在推理时,目标句子是未知的,而你想要生成它。但您仍然需要告诉解码器应该使用什么而不是通用的 [bos] token 。这就是 forced_bos_token_id 发挥作用的地方。仍然是标记器知道特定标记的 ID。不同的 mBART 有不同的分词器,您应该始终使用与模型匹配的分词器的语言 ID。

您提到的属性似乎做同样的事情,但我会坚持 mBART documentation 中提到的 forced_bos_token_id 。 HuggingFace Transformers 中的方法 API 非常宽松,某些属性仅适用于某些模型,而会被其他模型忽略。我会避免使用特定模型的文档中未明确提及的内容。

关于pytorch - 如何理解mbart中的decoder_start_token_id和forced_bos_token_id?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68313263/

相关文章:

ios - 核心数据的多语言方法

php - 用 PHP 构建的多语言(多语言)网站的自定义 404 页面

python - 来自 "ImportError: cannot import name ' 的变压器 : Error importing packages. 'torch.optim.lr_scheduler' SAVE_STATE_WARNING'

python - pytorch反向传播中的两种操作有什么区别?

python - Pytorch 在 __init__() 中定义层和直接在 forward() 中使用有什么区别?

python - 二元交叉熵与 2 个类别的分类交叉熵

android - 在不重新发布应用程序的情况下在 android 中添加新语言

bert-language-model - Huggingface BERT Tokenizer 添加新 token

python - 使用没有 IPyWidgets 的拥抱脸转换器

google-cloud-platform - 将 TPU 与 PyTorch 结合使用