python - PyTorch 梯度下降

标签 python pytorch

我正在尝试在 PyTorch 中手动实现梯度下降作为学习练习。我有以下内容来创建我的合成数据集:

import torch
torch.manual_seed(0)
N = 100
x = torch.rand(N,1)*5
# Let the following command be the true function
y = 2.3 + 5.1*x
# Get some noisy observations
y_obs = y + 2*torch.randn(N,1)

然后我创建我的预测函数( y_pred ),如下所示。
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
y_pred = w*x+b
mse = torch.mean((y_pred-y_obs)**2)

它使用 MSE 来推断权重 w,b .我使用下面的块根据梯度更新值。
gamma = 1e-2
for i in range(100):
  w = w - gamma *w.grad
  b = b - gamma *b.grad
  mse.backward()

但是,循环仅在第一次迭代中起作用。 第二次迭代以后 w.grad设置为 None . 我很确定发生这种情况的原因是因为我将 w 设置为它自己的函数(我可能是错的)。

问题是如何使用梯度信息正确更新权重?

最佳答案

  • 在应用梯度下降之前,您应该调用后向方法。
  • 您需要使用新权重来计算每次迭代的损失。
  • 每次迭代都创建没有梯度带的新张量。

  • 以下代码在我的计算机上运行良好,经过 500 次迭代训练后给出 w=5.1 & b=2.2。

    代码:
    import torch
    torch.manual_seed(0)
    N = 100
    x = torch.rand(N,1)*5
    # Let the following command be the true function
    y = 2.3 + 5.1*x
    # Get some noisy observations
    y_obs = y + 0.2*torch.randn(N,1)
    
    w = torch.randn(1, requires_grad=True)
    b = torch.randn(1, requires_grad=True)
    
    
    gamma = 0.01
    for i in range(500):
        print(i)
        # use new weight to calculate loss
        y_pred = w * x + b
        mse = torch.mean((y_pred - y_obs) ** 2)
    
        # backward
        mse.backward()
        print('w:', w)
        print('b:', b)
        print('w.grad:', w.grad)
        print('b.grad:', b.grad)
    
        # gradient descent, don't track
        with torch.no_grad():
            w = w - gamma * w.grad
            b = b - gamma * b.grad
        w.requires_grad = True
        b.requires_grad = True
    

    输出:
    499
    w: tensor([5.1095], requires_grad=True)
    b: tensor([2.2474], requires_grad=True)
    w.grad: tensor([0.0179])
    b.grad: tensor([-0.0576])
    

    关于python - PyTorch 梯度下降,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52213282/

    相关文章:

    python - 很难找到熟悉的 C for 循环的 Python 3.x 实现

    python - 从命令行运行 PyCharm 项目

    Python 插值错误

    python - Pytorch:更新 numpy 数组而不更新相应的张量

    python - 将 PyTorch 张量转换为 python 列表

    Python Dataset Class + PyTorch Dataloader : Stuck at __getitem__, 测试时如何获取索引、标签等?

    python - 使用 paramiko 防止 SFTP/SSH session 超时

    python - 在 Spark 中广播用户定义的类

    python - 模块未找到错误: No module named ‘tools.nnwrap’ (windows)

    python - 类型错误 : __array__() takes 1 positional argument but 2 were given