我目前正在尝试实现一个变压器,但无法理解其损耗计算。
我的编码器输入查找batch_size=1和max_sentence_length=8,如下所示:
[[Das, Wetter, ist, gut, <blank>, <blank>, <blank>, <blank>]]
我的解码器输入看起来像(德语到英语):
[[<start>, The, weather, is, good, <end>, <blank>, <blank>]]
假设我的转换器预测了这些类别概率(仅显示类别概率最高的类别的单词):
[[The, good, is, weather, <end>, <blank>, <blank>, <blank>]]
现在我使用以下方法计算损失:
loss = categorical_crossentropy(
[[The, good, is, weather, <end>, <blank>, <blank>, <blank>]],
[[The, weather, is, good, <end>, <blank>, <blank>, <blank>]]
)
这是计算损失的正确方法吗?我的变压器总是预测下一个单词的空白标记,我认为这是因为我在损失计算中出现了错误,并且必须在计算损失之前对空白标记进行一些处理。
最佳答案
您需要遮盖填充。 (您所说的 <blank>
通常称为 <pad>
。)
创建一个掩码,说明有效 token 的位置(伪代码:
mask = target != '<pad>')
计算分类交叉熵时,不要自动减少损失并保留该值。
将损失值与掩码相乘,即对应于
<blank>
的位置代币归零,并将有效头寸的损失相加。 (伪代码:loss_sum = (loss * mask).sum()
)除
loss_sum
通过有效位置的数量,即掩码之和(伪代码:loss = loss_sum / mask.sum()
)
关于machine-learning - 如何计算空白 token 预测的变压器损耗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66518375/