python - 为什么我的简单前馈神经网络发散(pytorch)?

标签 python machine-learning neural-network pytorch

我正在使用 pytorch 试验一个简单的 2 层神经网络,仅输入三个大小为 10 的输入,并以单个值作为输出。我对输入进行了标准化并降低了学习率。据我了解,两层全连接神经网络应该能够简单地适应这些数据

Features:

0.8138  1.2342  0.4419  0.8273  0.0728  2.4576  0.3800  0.0512  0.6872  0.5201
1.5666  1.3955  1.0436  0.1602  0.1688  0.2074  0.8810  0.9155  0.9641  1.3668
1.7091  0.9091  0.5058  0.6149  0.3669  0.1365  0.3442  0.9482  1.2550  1.6950
[torch.FloatTensor of size 3x10]


Targets
[124, 125, 122]
[torch.FloatTensor of size 3]

代码改编自一个简单示例,我使用 MSELoss 作为损失函数。几次迭代后,损失就会发散到无穷大:

features = torch.from_numpy(np.array(features))

x_data = Variable(torch.Tensor(features))
y_data = Variable(torch.Tensor(targets))

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(10,5)
        self.linear2 = torch.nn.Linear(5,1)

    def forward(self, x):
        l_out1 = self.linear(x)
        y_pred = self.linear2(l_out1)
        return y_pred

model = Model()

criterion = torch.nn.MSELoss(size_average = False)
optim = torch.optim.SGD(model.parameters(), lr = 0.001)

def main():
    for iteration in range(1000):
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)

        print(iteration, loss.data[0])
        optim.zero_grad()

        loss.backward()
        optim.step()

如有任何帮助,我们将不胜感激。谢谢

编辑:

事实上,这似乎只是因为学习率太高了。设置为 0.00001 可修复收敛问题,尽管收敛速度非常慢。

最佳答案

这是因为您没有在层之间使用非线性,并且您的网络仍然是线性的。

您可以使用 Relu 使其非线性。您可以像这样更改转发方法:

...
y_pred = torch.nn.functional.F.relu(self.linear2(l_out1))
...

关于python - 为什么我的简单前馈神经网络发散(pytorch)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47500817/

相关文章:

python - 改善简单的1层神经网络

python - Django/Python : generate pdf with the proper language

machine-learning - 感知器为什么将它用于线性可分的事物

python-3.x - Tensorflow 和 Keras 中的相同(?)神经网络架构在相同数据上产生不同的准确性

machine-learning - 使用最少的图像数据设计分类器

tensorflow - 让 LSTM 从 3 个变量的相关性中学习

neural-network - 为什么缩放数据在神经网络(LSTM)中非常重要

python - 使用 Python SDK 创建 Azure 容器时出现“HTTP header 格式不正确”错误

python - 如何在Python中检查key是否存在于values中以及values是否存在于key中

python - 从自定义包中导入大量模块