我正在使用 Pytorch 的 nn.TransformerEncoder
模块。我得到了具有(正常)形状(batch-size、seq-len、emb-dim
)的输入样本。一批中的所有样本均已用零填充至该批处理中最大样本的大小。因此我希望忽略全零值的注意力。
文档说,将参数src_key_padding_mask
添加到nn.TransformerEncoder
模块的forward
函数中。该掩码应该是一个形状为 (batch-size, seq-len
) 的张量,并且每个索引的填充零为 True
或 False
其他任何内容。
我通过这样做实现了这一点:
. . .
def forward(self, x):
# x.size -> i.e.: (200, 28, 200)
mask = (x == 0).cuda().reshape(x.shape[0], x.shape[1])
# mask.size -> i.e.: (200, 20)
x = self.embed(x.type(torch.LongTensor).to(device=device))
x = self.pe(x)
x = self.transformer_encoder(x, src_key_padding_mask=mask)
. . .
当我不设置src_key_padding_mask
时,一切正常。但我这样做时得到的错误如下:
File "/home/me/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py", line 4282, in multi_head_attention_forward
assert key_padding_mask.size(0) == bsz
AssertionError
似乎它正在将掩码的第一个维度(即批量大小)与 bsz 进行比较,后者可能代表批量大小。但为什么会失败呢?非常感谢帮助!
最佳答案
我遇到了同样的问题,但这不是错误:pytorch's Transformer implementation要求输入x
为(seq-len x batch-size x emb-dim)
,而你的似乎是(batch-size x seq-len x emb-dim)
。
关于python - Pytorch 的 nn.TransformerEncoder "src_key_padding_mask"未按预期运行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65424676/