我有一个 embedded_chars
数组,使用以下代码创建:
self.input_x = tf.placeholder(tf.int32, [None, sequence_length], name="input_x")
W = tf.Variable(
tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
name="W"
)
self.embedded_chars = tf.nn.embedding_lookup(W, self.input_x)
如果我只有 embedded_chars
和 W
,我想获取 input_x
数组。
我怎样才能得到它?
谢谢!
最佳答案
您可以使用 W
和 embedded_chars
中的嵌入向量之间的余弦距离:
# assume embedded_chars.shape == (batch_size, embedding_size)
emb_distances = tf.matmul( # shape == (vocab_size, batch_size)
tf.nn.l2_normalize(W, dim=1),
tf.nn.l2_normalize(embedded_chars, dim=1),
transpose_b=True)
token_ids = tf.argmax(emb_distances, axis=0) # shape == (batch_size)
这里的 emb_distances
是 L2 归一化矩阵 W
和 transpose(embedded_chars)
的点积,与余弦距离相同W
中的所有向量与 embedded_chars
中的所有向量之间。 argmax 简单地为我们提供了 emb_distances
每一列中最大值的索引。
@Yao Zhang:如果 W
中的所有嵌入都不同,因为它们应该不同,那么这将为您提供正确的结果:余弦距离始终在 [-1, 1] 和 cos( vector_a, vector_a) == 1.
请注意,通常您不需要进行这种从嵌入到标记索引的转换:通常您可以直接将张量的值作为第二个参数传递给 tf.nn.embedding_embedding_lookup
,这是 token 索引的张量。
关于python - 如何对tf.nn.embedding_lookup进行逆向操作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43103265/