我尝试在pytorch中使用view(),但我无法输入3个参数。我不知道为什么它一直给出这个错误?谁能帮我这个?
def forward(self, input):
lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))
最佳答案
看起来您的输入
是一个numpy数组,而不是torch张量。您需要先对其进行转换,例如 input = torch.Tensor(input)
。
关于lstm - 类型错误 : view() takes at most 2 arguments (3 given),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55805242/