python - 如何将 model.state_dict() 存储在临时变量中以供以后使用?

标签 python pytorch

我尝试将模型的状态字典临时存储在一个变量中,并希望稍后将其恢复到我的模型中,但该变量的内容会随着模型更新而自动更改。

有一个最小的例子:

import torch as t
import torch.nn as nn
from torch.optim import Adam


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc(x)


net = Net()
loss_fc = nn.MSELoss()
optimizer = Adam(net.parameters())

weights = net.state_dict()
print(weights)

x = t.rand((5, 3))
y = t.rand((5, 2))
loss = loss_fc(net(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print(weights)

我认为两个输出是相同的,但我得到了(输出可能由于随机初始化而改变)

OrderedDict([('fc.weight', tensor([[-0.5557,  0.0544, -0.2277],
        [-0.0793,  0.4334, -0.1548]])), ('fc.bias', tensor([-0.2204,  0.2846]))])
OrderedDict([('fc.weight', tensor([[-0.5547,  0.0554, -0.2267],
        [-0.0783,  0.4344, -0.1538]])), ('fc.bias', tensor([-0.2194,  0.2856]))])

weights的内容变了,太奇怪了。

我还尝试了 .copy()t.no_grad() 如下,但它们没有帮助。

with t.no_grad():
    weights = net.state_dict().copy()

是的,我知道我可以使用 t.save() 保存状态字典,但我只想弄清楚前面的示例中发生了什么。

我使用的是 Python 3.8.5Pytorch 1.8.1

感谢您的帮助。

最佳答案

这就是OrderedDict 的工作原理。这是一个更简单的示例:

from collections import OrderedDict

# a mutable variable
l = [1,2,3]

# an OrderedDict with an entry pointing to that mutable variable
x = OrderedDict([("a", l)])

# if you change the list
l[1] = 20

# the change is reflected in the OrderedDict
print(x)
# >> OrderedDict([('a', [1, 20, 3])])

如果您想避免这种情况,则必须进行深层复制,而不是浅层复制:

from copy import deepcopy
x2 = deepcopy(x)

print(x2)
# >> OrderedDict([('a', [1, 20, 3])])

# now, if you change the list
l[2] = 30

# you do not change your copy
print(x2)
# >> OrderedDict([('a', [1, 20, 3])])

# but you keep changing the original dict
print(x)
# >> OrderedDict([('a', [1, 20, 30])])

由于 Tensor 也是可变的,因此在您的情况下预计会有相同的行为。因此,您可以使用:

from copy import deepcopy

weights = deepcopy(net.state_dict())

关于python - 如何将 model.state_dict() 存储在临时变量中以供以后使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67161171/

相关文章:

python - 值错误 : cannot copy sequence with size 821 to array axis with dimension 7

python - 我应该如何键入提示一个也可以是无限的整数变量?

python - 占星查询 SIMBAD : Obtaining coordinates for all frames

python - 如何在 PyTorch 中查找和理解 autograd 源代码

python - 在 matplotlib 中显示张量图像

pytorch - PyTorch 中的可微分图像压缩操作

python - 如何向 KivyMD MDDialog 中的按钮添加操作?

Python - 模拟导入字典

pytorch - TorchServe MAR 每个模型有多个 Python 文件

python - 如何从检查点文件加载微调的 pytorch Huggingface bert 模型?