在 PyTorch 中,我们可以通过多种方式定义架构。在这里,我想使用 Sequential
创建一个简单的 LSTM 网络。模块。
在 Lua 的火炬中,我通常会选择:
model = nn.Sequential()
model:add(nn.SplitTable(1,2))
model:add(nn.Sequencer(nn.LSTM(inputSize, hiddenSize)))
model:add(nn.SelectTable(-1)) -- last step of output sequence
model:add(nn.Linear(hiddenSize, classes_n))
但是,在 PyTorch 中,我找不到
SelectTable
的等价物。获得最后的输出。nn.Sequential(
nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
# what to put here to retrieve last output of LSTM ?,
nn.Linear(hiddenSize, classe_n))
最佳答案
首先,我让 i 类提取最后一个单元格输出,如下所示
class extractlastcell(nn.Module):
def forward(self,x):
out , _ = x
return out[:, -1, :]
当我想在你的例子中使用它时,它会是这样的nn.Sequential(
nn.LSTM(inputSize, hiddenSize, 1, batch_first=True),
extractlastcell(),
nn.Linear(hiddenSize, classe_n))
关于deep-learning - PyTorch 中带有 Sequential 模块的简单 LSTM,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44130851/