python - AttributeError: 'tuple' object has no attribute 'dim' when feeding input to a PyTorch LSTM network

Tags: python tuples lstm pytorch torch

I am trying to run the following code:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out


state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]

device = torch.device("cuda")
net = LSTM(5, 3).to(device)

state_v = torch.FloatTensor(state).to(device)

q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

It fails with this error:

Traceback (most recent call last):
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 26, in <module>
    q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 15, in forward
    out = self.hidden2tag(out)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1022, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

Does anyone know how to fix this? (That is, how do I stop the value from being a tuple so that it can be fed through the LSTM network?)

Accepted answer

A PyTorch LSTM returns a tuple, not a single tensor.
You get this error because your linear layer, self.hidden2tag, receives that tuple and cannot handle it.

So change:

out = self.lstm(x)

to:

out, states = self.lstm(x)

This fixes the error by unpacking the tuple, so that out is just the output tensor.
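For reference, here is a sketch of the question's model with that one-line change applied. The fallback to the CPU when CUDA is unavailable is an assumption added so the snippet runs anywhere; the question itself hard-codes "cuda". The last line of the question is left out because it raises a separate error.

```python
import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        # nn.LSTM returns (output, (h_n, c_n)); only the output tensor
        # is passed on to the linear layer.
        out, states = self.lstm(x)
        return self.hidden2tag(out)


state = [(1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7),
         (4, 5, 6, 7, 8), (5, 6, 7, 8, 9), (6, 7, 8, 9, 0)]

# Assumption: fall back to the CPU if no GPU is present.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = LSTM(5, 3).to(device)

state_v = torch.FloatTensor(state).to(device)
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
print(q_vals_v.shape)  # torch.Size([1, 6, 3])
```

The linear layer is applied to the last dimension only, so the (1, 6, 12) LSTM output becomes a (1, 6, 3) tensor.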

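out then holds the hidden states for every time step (the outputs of the last LSTM layer), while states is another tuple containing the final hidden state and the final cell state.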
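A small sketch of what the returned tuple looks like may help. It assumes the six rows are treated as six time steps of a single sequence, i.e. an input of shape (seq_len=6, batch=1, input_size=5); the question's own view() call instead produces a batch of six one-step sequences. The names h_n and c_n follow the PyTorch documentation.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=5, hidden_size=12)

# Assumed layout: (seq_len, batch, input_size), the default when
# batch_first=False.
x = torch.randn(6, 1, 5)

result = lstm(x)
out, (h_n, c_n) = result

print(type(result))  # <class 'tuple'>
print(out.shape)     # torch.Size([6, 1, 12]): one hidden state per time step
print(h_n.shape)     # torch.Size([1, 1, 12]): final hidden state
print(c_n.shape)     # torch.Size([1, 1, 12]): final cell state

# For a single-layer, unidirectional LSTM the last time step of out
# is the same as the final hidden state.
print(torch.allclose(out[-1], h_n[-1]))  # True
```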

You can also take a look at the documentation here:
https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

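You will get another error on the last line, because torch.max() with a dim argument also returns a tuple (the maximum values and their indices). That should be easy to fix, though, and it is a different error :)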
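One possible way to handle that second error is sketched below. The question's line calls .item() on the pair returned by torch.max() instead of unpacking it first. The Q-values here are hypothetical and have shape (batch=1, n_actions=3); the network in the question produces a 3-D tensor, so you would first need to select the time step or batch entry you care about.

```python
import torch

# Hypothetical Q-values for one state and three actions.
q_vals = torch.tensor([[0.1, 0.7, 0.2]])

# With a dim argument, torch.max returns (values, indices).
max_vals, max_idx = torch.max(q_vals, dim=1)

# Unpack first, then convert the single index to a Python int.
action = int(max_idx.item())
print(action)  # 1
```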

This question and answer come from Stack Overflow: https://stackoverflow.com/questions/53032586/
