python - 累积梯度

标签 python pytorch gradient-descent

我想在向后传递之前累积梯度。所以想知道正确的做法是什么。根据 this article
它的:

model.zero_grad()                                   # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()

而我预计它是:
model.zero_grad()                                   # Reset gradients tensors
loss = 0
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss += loss_function(predictions, labels)       # Compute loss function                              
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        loss = loss / accumulation_steps            # Normalize our loss (if averaged)
        loss.backward()                             # Backward pass
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()     
        loss = 0  

我累积损失,然后除以累积步骤以求平均。

第二个问题,如果我是对的,考虑到我只在每个累积步骤中进行反向传递,您是否希望我的方法更快?

最佳答案

所以根据答案here ,第一种方法是内存高效的。两种方法所需的工作量或多或少相同。

第二种方法不断累积图形,因此需要 accumulation_steps多倍的内存。第一种方法直接计算梯度(并简单地添加梯度),因此需要较少的内存。

关于python - 累积梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53331540/

相关文章:

Pytorch,无法获得 <class 'torch.Tensor' > 的 repr

machine-learning - 随机梯度下降增加成本函数

python - Kivy:将数据传递给另一个类

python - 按正确的顺序将字典写入 csv 文件

python - pytorch自定义数据集: DataLoader returns a list of tensors rather than tensor of a list

使用 .detach() 的 Pytorch DQN、DDQN 导致非常大的损失(呈指数增长)并且根本不学习

machine-learning - 为什么用caffe训练时 `Train net output`损失和 `iteration loss`是相同的?

algorithm - 直观理解 Adam 优化器

python - 使用 bs4 进行网页抓取验证

python - Matplotlib,避免 plot_trisurf() 中不需要的三角形