python - 如何将word2vec导入TensorFlow Seq2Seq模型?

标签 python tensorflow

我正在使用 Tensorflow 序列到序列翻译模型。我想知道我是否可以将自己的 word2vec 导入到这个模型中?而不是使用教程中提到的原始“密集表示”。

从我的角度来看,TensorFlow 似乎正在使用 One-Hot 表示来表示 seq2seq 模型。首先,对于函数 tf.nn.seq2seq.embedding_attention_seq2seq ,编码器的输入是标记化符号,例如“a”将是“4”,“dog”将是“15715”等,并且需要 num_encoder_symbols。所以我认为它让我提供单词的位置和单词的总数,然后该函数可以以 One-Hot 表示形式表示该单词。我还在学习源代码,但是很难理解。

有人可以给我关于上述问题的想法吗?

最佳答案

seq2seq embedding_* 函数确实创建了与 word2vec 非常相似的嵌入矩阵。它们是一个名为 sth 的变量,如下所示:

EMBEDDING_KEY = "embedding_attention_seq2seq/RNN/EmbeddingWrapper/embedding"

知道了这一点,你就可以修改这个变量了。我的意思是——以某种格式获取你的 word2vec 向量,比如文本文件。假设您的词汇表位于 model.vocab 中,您可以按照下面的代码片段所示的方式分配阅读向量(这只是一个代码片段,您必须更改它才能使其正常工作,但我希望它能显示出这个想法) .

   vectors_variable = [v for v in tf.trainable_variables()
                        if EMBEDDING_KEY in v.name]
    if len(vectors_variable) != 1:
      print("Word vector variable not found or too many.")
      sys.exit(1)
    vectors_variable = vectors_variable[0]
    vectors = vectors_variable.eval()
    print("Setting word vectors from %s" % FLAGS.word_vector_file)
    with gfile.GFile(FLAGS.word_vector_file, mode="r") as f:
      # Lines have format: dog 0.045123 -0.61323 0.413667 ...
      for line in f:
        line_parts = line.split()
        # The first part is the word.
        word = line_parts[0]
        if word in model.vocab:
          # Remaining parts are components of the vector.
          word_vector = np.array(map(float, line_parts[1:]))
          if len(word_vector) != vec_size:
            print("Warn: Word '%s', Expecting vector size %d, found %d"
                     % (word, vec_size, len(word_vector)))
          else:
            vectors[model.vocab[word]] = word_vector
    # Assign the modified vectors to the vectors_variable in the graph.
    session.run([vectors_variable.initializer],
                {vectors_variable.initializer.inputs[1]: vectors})

关于python - 如何将word2vec导入TensorFlow Seq2Seq模型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36072672/

相关文章:

python - Django Rest 框架中的 token 身份验证实现

python - 在 matplotlib 中仅连接一段数字数组

python - 在嵌套类中使用对象继承对象来运行迭代

Python 使用递归反转字符串

python - tensorflow 中的二进制阈值激活函数

python - 如何修复 MatMul Op 的 float64 类型与 float32 TypeError 类型不匹配?

machine-learning - Dropout 应该插入到哪里?全连接层。?卷积层。?或两者。?

python - 从 __pycache__ 恢复

windows - 在 docker 镜像中安装 Tensorflow 时出错

tensorflow - 如何在 TensorFlow 中解决 'ran out of gpu memory'