I created the following deep network with dropout layers:
class QNet_dropout(nn.Module):
    """
    A MLP with 2 hidden layer and dropout
    observation_dim (int): number of observation features
    action_dim (int): Dimension of each action
    seed (int): Random seed
    """
    def __init__(self, observation_dim, action_dim, seed):
        super(QNet_dropout, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Dropout(0.5)
        self.fc5 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """
        Forward propagation of neural network
        """
        x = F.relu(self.fc1(observations))
        x = F.linear(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.linear(self.fc4(x))
        x = self.fc5(x)
        return x
However, when I try to run the code, I get the following error:
/home/workspace/QNetworks.py in forward(self, observations)
90
91 x = F.relu(self.fc1(observations))
---> 92 x = F.linear(self.fc2(x))
93 x = F.relu(self.fc3(x))
94 x = F.linear(self.fc4(x))
TypeError: linear() missing 1 required positional argument: 'weight'
It seems I am not using/forwarding the dropout layers correctly. What is the right way to do the forward pass through a dropout layer? Thanks!
Best answer
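F.linear() is being used incorrectly: it is the functional form of a linear layer and needs an explicit weight argument, which is why the TypeError is raised. nn.Dropout has no weights and is simply called on the tensor, so it does not need to be wrapped in anything. Use the nn.Linear layers you already defined instead of torch.nn.functional.linear, and apply each Dropout layer after the ReLU. ReLU itself can be called from torch.nn.functional.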
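To make the difference between the module and the functional form concrete, here is a minimal sketch. The layer sizes and the names `layer`, `out_module`, `out_functional`, `same` and `raised` are chosen only for illustration and are not part of the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# nn.Linear is a module: it owns its weight (3x4) and bias (3).
layer = nn.Linear(4, 3)
x = torch.rand(2, 4)

# Usual way: call the module, which uses its stored parameters.
out_module = layer(x)

# Functional form: weight (and optionally bias) must be passed explicitly.
out_functional = F.linear(x, layer.weight, layer.bias)

# Both paths compute the same affine transformation.
same = torch.allclose(out_module, out_functional)
print("module and functional outputs match:", same)

# Passing only the input, as in the question, fails with a TypeError
# because the required 'weight' argument is missing.
try:
    F.linear(x)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)
```

In a network like the one in the question there is normally no reason to call F.linear directly, since each nn.Linear module already applies its own weights.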
The corrected network, followed by a quick shape check:

import torch
import torch.nn as nn
import torch.nn.functional as F


class QNet_dropout(nn.Module):
    """
    An MLP with 2 hidden layers and dropout
    observation_dim (int): number of observation features
    action_dim (int): Dimension of each action
    seed (int): Random seed
    """
    def __init__(self, observation_dim, action_dim, seed):
        super(QNet_dropout, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.fc1 = nn.Linear(observation_dim, 128)
        self.fc2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Dropout(0.5)
        self.fc5 = nn.Linear(64, action_dim)

    def forward(self, observations):
        """
        Forward propagation of neural network
        """
        # Linear -> ReLU -> Dropout for each hidden layer
        x = self.fc2(F.relu(self.fc1(observations)))
        x = self.fc4(F.relu(self.fc3(x)))
        # Output layer: no activation and no dropout
        x = self.fc5(x)
        return x


observation_dim = 512
model = QNet_dropout(observation_dim, 10, 512)  # action_dim=10, seed=512
batch_size = 8
inpt = torch.rand(batch_size, observation_dim)
output = model(inpt)
print("output shape:", output.shape)
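One related point the answer does not cover: nn.Dropout behaves differently in training and evaluation mode. The sketch below uses a standalone dropout layer rather than the full network, and the names `drop`, `y_train`, `y_eval`, `train_values` and `eval_is_identity` are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: each element is zeroed with probability p,
# and the survivors are scaled by 1 / (1 - p) = 2.
drop.train()
y_train = drop(x)

# Evaluation mode: dropout is disabled and the input passes through unchanged.
drop.eval()
y_eval = drop(x)

train_values = set(y_train.tolist())      # only 0.0 and 2.0 can occur
eval_is_identity = torch.equal(y_eval, x)

print("values seen in training mode:", sorted(train_values))
print("eval mode is identity:", eval_is_identity)
```

For the network above this would mean calling model.train() before training steps and model.eval() before using the network for evaluation or action selection, assuming dropout should only be active during training.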
For more on "python - how to correctly forward a dropout layer", see the similar question on Stack Overflow: https://stackoverflow.com/questions/56401266/