我是 PyTorch 的初学者,正在尝试基于自定义神经网络类训练 MNIST 模型。我的学习率调度器、损失函数和优化器是:
optimizer = optim.Adam(model.parameters(), lr=0.003)
loss_fn = nn.CrossEntropyLoss()
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
我还为此目的使用学习率调度程序。最初,我的训练循环是这样的:
# this training gives high loss and it doesn't varies that much
def training(epochs):
model.train()
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs = imgs.to(device=device)
labels = labels.to(device=device)
optimizer.zero_grad()
outputs = model(imgs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
exp_lr_scheduler.step() # inside the loop and after the optimizer
if (batch_idx + 1)% 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, (batch_idx + 1) * len(imgs), len(train_loader.dataset),
100. * (batch_idx + 1) / len(train_loader), loss.data))
但是这种训练效率并不高,我的损失在每个 epoch 几乎都是一样的。
然后,我最终将训练函数改为:
# this training works perfectly
def training(epochs):
model.train()
exp_lr_scheduler.step() # out of the loop but before optimizer step
for batch_idx, (imgs, labels) in enumerate(train_loader):
imgs = imgs.to(device=device)
labels = labels.to(device=device)
optimizer.zero_grad()
outputs = model(imgs)
loss = loss_fn(outputs, labels)
loss.backward()
optimizer.step()
if (batch_idx + 1)% 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, (batch_idx + 1) * len(imgs), len(train_loader.dataset),
100. * (batch_idx + 1) / len(train_loader), loss.data))
现在,它可以正常工作了。我只是不明白其中的原因。 我有两个疑问:
exp_lr_scheduler.step()
不应该在 for 循环中,这样它也会随着每个时期的更新而更新吗? ;和- PyTorch 最新版本表示将
exp_lr_scheduler.step()
保留在optimizer.step()
之后,但在我的训练函数中这样做会给我带来更严重的损失。
这可能是什么原因,还是我做错了?
最佳答案
StepLR 在每step_size之后更新学习率gamma,这意味着如果step_size是7,那么学习率将在每7之后更新通过将当前学习率乘以 gamma 来计算纪元。这意味着在您的代码片段中,每 7 个时期学习率就会变小 10 倍。
您是否尝试过提高起始学习率?我会尝试0.1或0.01。我认为问题可能出在起始学习率的大小上,因为起始点已经很小了。这会导致梯度下降算法(或其衍生物,如 Adam)无法向最小值移动,因为步长太小并且结果保持相同(在函数的同一点处最小化)。
希望有帮助。
关于python - PyTorch 学习调度程序顺序极大地改变了损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72935355/