我想在我的神经网络架构中使用预训练的嵌入。预训练的嵌入由 gensim 训练。我找到了this informative answer这表明我们可以像这样加载 pre_trained 模型:
import gensim
from torch import nn
model = gensim.models.KeyedVectors.load_word2vec_format('path/to/file')
weights = torch.FloatTensor(model.vectors)
emb = nn.Embedding.from_pretrained(torch.FloatTensor(weights.vectors))
这似乎工作正常,也在 1.0.1 上。我的问题是,我不太明白我必须向这样的层提供什么才能使用它。我可以只喂 token (分段句子)吗?我是否需要一个映射,例如 token 到索引?
我发现您可以简单地通过类似的方式访问 token 的向量
print(weights['the'])
# [-1.1206588e+00 1.1578362e+00 2.8765252e-01 -1.1759659e+00 ... ]
这对 RNN 架构意味着什么?我们可以简单地加载批处理序列的标记吗?例如:
for seq_batch, y in batch_loader():
# seq_batch is a batch of sequences (tokenized sentences)
# e.g. [['i', 'like', 'cookies'],['it', 'is', 'raining'],['who', 'are', 'you']]
output, hidden = model(seq_batch, hidden)
这似乎不起作用,所以我假设您需要在最终的 word2vec 模型中将标记转换为其索引。真的吗?我发现你可以使用 word2vec 模型的
vocab
来获取单词的索引。 :weights.vocab['world'].index
# 147
因此,作为嵌入层的输入,我应该提供
int
的张量吗?对于由单词序列组成的句子序列?使用虚拟数据加载器(参见上面的示例)和虚拟 RNN 欢迎使用示例。
最佳答案
This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.
所以如果你想喂一个句子,你给一个
LongTensor of
索引,每个对应于词汇表中的一个单词,其中nn.Embedding
层将映射到 future 的词向量。这是一个插图
test_voc = ["ok", "great", "test"]
# The word vectors for "ok", "great" and "test"
# are at indices, 0, 1 and 2, respectively.
my_embedding = torch.rand(3, 50)
e = nn.Embedding.from_pretrained(my_embedding)
# LongTensor of indicies corresponds to a sentence,
# reshaped to (1, 3) because batch size is 1
my_sentence = torch.tensor([0, 2, 1]).view(1, -1)
res = e(my_sentence)
print(res.shape)
# => torch.Size([1, 3, 50])
# 1 is the batch dimension, and there's three vectors of length 50 each
就 RNN 而言,接下来您可以将该张量输入您的 RNN 模块,例如
lstm = nn.LSTM(input_size=50, hidden_size=5, batch_first=True)
output, h = lstm(res)
print(output.shape)
# => torch.Size([1, 3, 5])
我还建议您查看 torchtext .它可以自动化一些你必须手动完成的事情。
关于vector - 使用来自 gensim 的预训练向量的 torch 嵌入层的预期输入,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54655604/