python - 从 PyTorch 中的 BiLSTM (BiGRU) 获取最后一个状态

标签 python lstm pytorch

在阅读了几篇文章后,我仍然对我从 BiLSTM 获取最后隐藏状态的实现的正确性感到困惑。

  • Understanding Bidirectional RNN in PyTorch (TowardsDataScience)
  • PackedSequence for seq2seq model (PyTorch forums)
  • What's the difference between “hidden” and “output” in PyTorch LSTM? (StackOverflow)
  • Select tensor in a batch of sequences (Pytorch formums)

  • 最后一个来源(4)的方法对我来说似乎是最干净的,但我仍然不确定我是否正确理解了该线程。我是否使用了来自 LSTM 和反向 LSTM 的正确最终隐藏状态?这是我的实现
    # pos contains indices of words in embedding matrix
    # seqlengths contains info about sequence lengths
    # so for instance, if batch_size is 2 and pos=[4,6,9,3,1] and 
    # seqlengths contains [3,2], we have batch with samples
    # of variable length [4,6,9] and [3,1]
    
    all_in_embs = self.in_embeddings(pos)
    in_emb_seqs = pack_sequence(torch.split(all_in_embs, seqlengths, dim=0))
    output,lasthidden = self.rnn(in_emb_seqs)
    if not self.data_processor.use_gru:
        lasthidden = lasthidden[0]
    # u_emb_batch has shape batch_size x embedding_dimension
    # sum last state from forward and backward  direction
    u_emb_batch = lasthidden[-1,:,:] + lasthidden[-2,:,:]
    

    这是正确的吗?

    最佳答案

    在一般情况下,如果您想创建自己的 BiLSTM 网络,您需要创建两个常规 LSTM,并使用常规输入序列馈送一个,另一个使用反向输入序列馈送。在完成两个序列的输入后,您只需从两个网络中获取最后一个状态并以某种方式将它们联系在一起(求和或连接)。

    据我了解,您正在使用 this example 中的内置 BiLSTM(在 nn.LSTM 构造函数中设置 bidirectional=True)。然后您在输入批次后获得连接的输出,因为 PyTorch 会为您处理所有麻烦。

    如果是这种情况,并且您想对隐藏状态求和,那么您必须

    u_emb_batch = (lasthidden[0, :, :] + lasthidden[1, :, :])
    

    假设你只有一层。如果你有更多层,你的变体看起来更好。

    这是因为结果是结构化的(参见 documentation ):

    h_n of shape (num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len



    顺便一提,
    u_emb_batch_2 = output[-1, :, :HIDDEN_DIM] + output[-1, :, HIDDEN_DIM:]
    

    应该提供相同的结果。

    关于python - 从 PyTorch 中的 BiLSTM (BiGRU) 获取最后一个状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50856936/

    相关文章:

    pytorch - PyTorch 中复数的矩阵乘法

    pytorch - 带填充掩码的 TransformerEncoder

    python - 我无法导入 CSV 文件?

    python - 如何在没有 Setter 的情况下使用 Getter

    python - 使用 COBYLA 方法进行盆地跳跃似乎可以忽略约束

    python - Tensorflow:尝试使用未初始化的值 beta1_power

    python - 使用 Keras io 进行最简单的 Lstm 训练

    python - 带 x 值标签的 seaborn 条形图(无色调)

    machine-learning - 在 Keras 中,什么时候应该使用 input_shape 而不是 input_dim?

    python - 如何沿轴指数衰减值?