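I am training a neural network for regression, but at test time it predicts a constant value. That's why I want to visualize how the network's weights change during training and watch their dynamics in a Jupyter Notebook.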
Currently, my model looks like this:
```python
import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.inp = nn.Linear(2, 40)
        self.act1 = nn.Tanh()
        self.h1 = nn.Linear(40, 40)
        self.act2 = nn.Tanh()
        self.h2 = nn.Linear(40, 2)
        self.act3 = nn.Tanh()
        #self.h3 = nn.Linear(20, 20)
        #self.act4=nn.Tanh()
        self.h4 = nn.Linear(2, 1)

    def forward_one_pt(self, x):
        out = self.inp(x)
        out = self.act1(out)
        out = self.h1(out)
        out = self.act2(out)
        out = self.h2(out)
        out = self.act3(out)
        #out = self.h3(out)
        #out = self.act4(out)
        out = self.h4(out)
        return out

    def forward(self, config):
        E = torch.zeros([config.shape[0], 1])
        for i in range(config.shape[0]):
            E[i] = self.forward_one_pt(config[i])
            # print("config[",i,"] = ",config[i],"E[",i,"] = ",E[i])
        return torch.sum(E, 0)
```
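Before plotting anything, it helps to list exactly which tensors the optimizer updates. The sketch below repeats only the `__init__` of the model above (the commented-out layers are dropped) and prints each trainable parameter's name and shape; `describe_parameters` is a hypothetical helper name, not part of PyTorch.

```python
import torch
from torch import nn


class Net(nn.Module):
    # Only __init__ is repeated: named_parameters() depends on the registered
    # submodules, not on forward()
    def __init__(self):
        super().__init__()
        self.inp = nn.Linear(2, 40)
        self.act1 = nn.Tanh()
        self.h1 = nn.Linear(40, 40)
        self.act2 = nn.Tanh()
        self.h2 = nn.Linear(40, 2)
        self.act3 = nn.Tanh()
        self.h4 = nn.Linear(2, 1)


def describe_parameters(model):
    """Return (name, shape, number of elements) for each trainable tensor."""
    return [(name, tuple(p.shape), p.numel())
            for name, p in model.named_parameters() if p.requires_grad]


for name, shape, count in describe_parameters(Net()):
    print(f"{name:<10} {str(shape):<10} {count}")
print("total:", sum(c for _, _, c in describe_parameters(Net())))  # total: 1845
```

The `nn.Tanh` layers have no parameters, so the eight tensors listed (a weight and a bias for each of `inp`, `h1`, `h2`, `h4`) are the ones worth tracking over time.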
My main function looks like this:
```python
def main():
    # load_data, get_symm, train, test and the other free names (save,
    # cuda_flag, save_model, ...) are defined elsewhere in the script (not shown)
    learning_rate = 0.5
    n_pts = 1000
    t_pts = 100
    epochs = 15
    coords, E = load_data(n_pts, t_pts)
    # generate the input data for the NN
    G = get_symm(coords, save, load_symmetry, symmtery_pickle_file,
                 eeta1, eeta2, Rs, ex, lambdaa, zeta, boxl, Rc, pi, E, scale)
    net = Net()
    if cuda_flag:
        net.cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    net_trained = train(save, text_output, epochs, n_pts, G, E, net, t_pts,
                        optimizer, criterion, out, cuda_flag)
    test(save, n_pts, t_pts, G, E, net_trained, out, criterion, cuda_flag)
    torch.save(net, save_model)
```
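Since `load_data`, `get_symm` and `train` are not shown, here is only a minimal sketch of how weight dynamics could be recorded and plotted in a notebook, assuming one full-batch Adam step per epoch on synthetic stand-in data. `train_and_record` and `plot_history` are hypothetical helper names, and the plot assumes matplotlib is installed; matplotlib is imported inside `plot_history` so the training part runs without it.

```python
import torch
from torch import nn


def train_and_record(model, inputs, targets, epochs=15, lr=1e-3):
    """Train with Adam (one full-batch step per epoch) and record, per epoch,
    the loss, the L2 norm of each parameter and the size of its update."""
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    norms = {name: [] for name, _ in model.named_parameters()}
    steps = {name: [] for name, _ in model.named_parameters()}
    losses = []
    for _ in range(epochs):
        # Snapshot the weights so the update size can be measured afterwards
        before = {n: p.detach().clone() for n, p in model.named_parameters()}
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        for n, p in model.named_parameters():
            norms[n].append(p.detach().norm().item())
            steps[n].append((p.detach() - before[n]).norm().item())
    return losses, norms, steps


def plot_history(losses, norms, steps):
    """Plot loss, parameter norms and per-epoch update sizes side by side."""
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    axes[0].plot(losses)
    axes[0].set_title("training loss")
    for name, values in norms.items():
        axes[1].plot(values, label=name)
        axes[2].plot(steps[name], label=name)
    axes[1].set_title("parameter norm")
    axes[2].set_title("update size per epoch")
    axes[2].legend(fontsize="small")
    for ax in axes:
        ax.set_xlabel("epoch")
    plt.show()


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical stand-in model and data; swap in Net(), G and E from the question
    model = nn.Sequential(nn.Linear(2, 40), nn.Tanh(), nn.Linear(40, 1))
    x = torch.randn(256, 2)
    y = x[:, :1] ** 2 + x[:, 1:]
    losses, norms, steps = train_and_record(model, x, y, epochs=50)
    print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
    # In a notebook cell, follow with: plot_history(losses, norms, steps)
```

Update sizes that collapse toward zero, or a loss that stops moving after a few epochs, could point to the same problem as the constant predictions; rerunning with the question's `lr=0.5` and with a smaller value is a quick comparison. If you prefer TensorBoard, `torch.utils.tensorboard.SummaryWriter.add_histogram(name, param, epoch)` can log the same tensors as histograms instead.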
Any tutorials or answers would be helpful.
Best answer
You can snapshot `model.state_dict()` before and after a training step to check whether your weights are actually updated across epochs:
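```python
# Snapshot the weights; clone() is needed because state_dict() returns
# references to the live tensors
old_state_dict = {key: value.clone() for key, value in model.state_dict().items()}

# A forward pass alone never changes the weights: run a full optimization step
# (criterion, optimizer and target as in your training code)
output = model(input)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

new_state_dict = {key: value.clone() for key, value in model.state_dict().items()}

for key in old_state_dict:
    if not torch.equal(old_state_dict[key], new_state_dict[key]):
        print('Diff in {}'.format(key))
    else:
        print('NO Diff in {}'.format(key))
```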
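The same check can be made self-contained to see what it reports. A minimal sketch on a hypothetical `nn.Linear(3, 1)` toy model with random data (`changed_keys`, `forward_only` and `train_step` are illustrative names): a forward pass leaves every tensor untouched, and only the optimizer step changes them.

```python
import torch
from torch import nn


def changed_keys(model, step):
    """Run step() and return the state_dict keys whose tensors changed."""
    # clone() is essential: state_dict() returns references to the live tensors
    before = {k: v.clone() for k, v in model.state_dict().items()}
    step()
    after = model.state_dict()
    return [k for k in before if not torch.equal(before[k], after[k])]


torch.manual_seed(0)
model = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(16, 3), torch.randn(16, 1)


def forward_only():
    with torch.no_grad():
        model(inputs)


def train_step():
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()


print(changed_keys(model, forward_only))  # []
print(changed_keys(model, train_step))    # ['weight', 'bias']
```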
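By the way, you can vectorize the forward function instead of looping over the points. Because `nn.Linear` and `nn.Tanh` operate on the last dimension, `forward_one_pt` can process the whole `(N, 2)` batch in one call. The following does the same job as your original `forward`, but faster: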
```python
def forward(self, config):
    out = self.forward_one_pt(config)
    return torch.sum(out, 0)
```
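To confirm that the vectorized `forward` really matches the loop, here is a sketch that keeps both versions on one copy of the question's model (the loop version is renamed `forward_loop`, an illustrative name) and compares their outputs on random points:

```python
import torch
from torch import nn


class Net(nn.Module):
    # Same layers as the question's model (commented-out layers omitted)
    def __init__(self):
        super().__init__()
        self.inp = nn.Linear(2, 40)
        self.act1 = nn.Tanh()
        self.h1 = nn.Linear(40, 40)
        self.act2 = nn.Tanh()
        self.h2 = nn.Linear(40, 2)
        self.act3 = nn.Tanh()
        self.h4 = nn.Linear(2, 1)

    def forward_one_pt(self, x):
        out = self.act1(self.inp(x))
        out = self.act2(self.h1(out))
        out = self.act3(self.h2(out))
        return self.h4(out)

    def forward_loop(self, config):
        # Original version: one point at a time into a preallocated tensor
        E = torch.zeros([config.shape[0], 1])
        for i in range(config.shape[0]):
            E[i] = self.forward_one_pt(config[i])
        return torch.sum(E, 0)

    def forward(self, config):
        # Vectorized version: the whole (N, 2) batch in a single call
        return torch.sum(self.forward_one_pt(config), 0)


torch.manual_seed(0)
net = Net()
config = torch.randn(100, 2)
with torch.no_grad():
    looped = net.forward_loop(config)
    vectorized = net(config)
print(looped.shape, vectorized.shape)  # torch.Size([1]) torch.Size([1])
print(torch.allclose(looped, vectorized, atol=1e-5))  # True
```

Both return a tensor of shape `[1]` (the sum of the per-point outputs), so the vectorized version can replace the loop without changing the loss computation.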
Related Stack Overflow question, "python - Pytorch: visualize model while training": https://stackoverflow.com/questions/57494217/