pytorch - pytorch nn.EmbeddingBag 中的偏移量是什么意思?

标签 pytorch

我知道偏移量在有两个数字时是什么意思,但是当有两个以上数字时是什么意思,例如:

weight = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
embedding_sum = nn.EmbeddingBag.from_pretrained(weight, mode='sum')
print(list(embedding_sum.parameters()))
input = torch.LongTensor([0,1])
offsets = torch.LongTensor([0,1,2,1])

print(embedding_sum(input, offsets))
结果是:
[Parameter containing:
tensor([[1., 2., 3.],
        [4., 5., 6.]])]
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [0., 0., 0.],
        [0., 0., 0.]])
谁能帮我?

最佳答案

source code所示,

return F.embedding(
    input, self.weight, self.padding_idx, self.max_norm,
    self.norm_type, self.scale_grad_by_freq, self.sparse) 
它使用 functional embedding bag ,这解释了 offsets参数为

offsets (LongTensor, optional) – Only used when input is 1D. offsets determines the starting index position of each bag (sequence) in input.


EmbeddingBag docs :

If input is 1D of shape (N), it will be treated as a concatenation of multiple bags (sequences). offsets is required to be a 1D tensor containing the starting index positions of each bag in input. Therefore, for offsets of shape (B), input will be viewed as having B bags. Empty bags (i.e., having 0-length) will have returned vectors filled by zeros.


最后一条语句(“空袋(即长度为 0)将返回由零填充的向量。”)解释了结果张量中的零向量。

关于pytorch - pytorch nn.EmbeddingBag 中的偏移量是什么意思?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65547335/

相关文章:

python - 如何将 PyTorch 张量的每一行中的重复值清零?

python - 将 PyTorch 的张量形状从 (C, B, H) 修改为 (B, C*H)

python - .backward() 之后 pytorch grad 为 None

python - Pytorch BCELoss 对相同输入使用不同的输出

tensorflow - Pytorch Autograd : what does runtime error "grad can be implicitly created only for scalar outputs" mean

python - 使用 detectorron2 进行语义分割

python - 在 PyTorch 中使用 Conv2D 时,首先发生填充还是膨胀?

pytorch - 当批量大小不是 train_size 的一个因素时,将 loss().item 乘以 batch_size 以获得批量损失是否是个好主意?

python - IndexError : Dimension out of range - PyTorch dimension expected to be in range of [-1, 0],但得到 1

python - 所有 cuda、pytorch、cuda 工具包都匹配,但 `torch.cuda.is_available()` 仍然为 False