python - 为什么我的 pytorch NN 返回一个 nan 的张量?

标签 python deep-learning neural-network pytorch

我有一个非常简单的神经网络,它采用扁平化的 6x6 网格作为输入,并应输出要对该网格采取的四个 Action 的值,即 1x4 的值张量。

有时由于某种原因运行几次后我得到一个 1x4 的 nan 张量

tensor([[nan, nan, nan, nan]], grad_fn=<ReluBackward0>)

我的模型看起来像这样,输入暗淡为 36,输出暗淡为 4:

class Model(nn.Module):
    def __init__(self, input_dim, output_dim):
        # super relates to nn.Module so this initializes nn.Module
        super(Model, self).__init__()
        # Gridsize as input,
        # last layer needs 4 outputs because of 4 possible actions: left, right, up, down
        # output values are Q Values need activation function for those like argmax
        self.lin1 = nn.Linear(input_dim, 24)
        self.lin2 = nn.Linear(24, 24)
        self.lin3 = nn.Linear(24, output_dim)

    # function to feed the input through the net
    def forward(self, x):
        # rectified linear as activation function for the first two layers
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float)

        activation1 = F.relu(self.lin1(x))
        activation2 = F.relu(self.lin2(activation1))
        output = F.relu(self.lin3(activation2))

        return output

输入是:

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6667]])

nan 输出的可能原因是什么?我该如何解决这些问题?

最佳答案

作为输出的 nan 值仅意味着训练不稳定,这可能有各种可能的原因,包括代码中的各种错误。如果您认为您的代码是正确的,您可以尝试通过降低学习率或使用 gradient clipping 来解决不稳定性问题。 .

关于python - 为什么我的 pytorch NN 返回一个 nan 的张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66625645/

相关文章:

java - 使用 Neuroph 神经网络进行图像识别?

Python-如何从对象列表中获取日期字段具有最新值的对象?

python - 无法在openERP中导入自定义模块

python - 有效地重新分配列表

python - 如何解压pkl文件?

deep-learning - 在对自定义数据集进行 yolo 训练期间重写框是什么意思?

python - 如何在 pytorch 中实现 Conv2d 的棋盘步幅?

python - 使用 python 的多处理池和映射函数测量进度

machine-learning - 检查目标 : expected dense_8 to have shape (2, 时出错,但得到形状为 (1,) 的数组

machine-learning - 使用学习到的参数来学习其他参数