有人可以告诉我 forward()
方法中多个参数背后的概念吗?
一般来说,forward()
方法的实现有两个参数
- 自己
- 输入
如果前向方法的参数多于这些参数,PyTorch 如何使用前向方法。
让我们考虑一下这个代码库: https://github.com/bamps53/kaggle-autonomous-driving2019/blob/master/models/centernet.py 网上有 236 位作者使用了带有两个以上参数的前向方法:
- 中心
- return_embeddings
我找不到一篇文章可以回答我对第 254 行(return_embeddings:
)和第 257 行(if center is not None:
)条件的查询将执行。据我所知,该方法由 nn 模块内部调用。有人可以帮我点亮一下吗?
最佳答案
转发功能由您设置。这意味着您可以根据需要添加更多参数。例如,您可以添加如下所示的输入
def forward(self, input1, input2, input3):
x = self.layer1(input1)
y = self.layer2(input2)
z = self.layer3(input3)
net = torch.cat((x,y,z),1)
return net
关键点是您必须在为网络提供数据时控制参数。图层只能输入一个参数。因此,您需要从输入中一一提取特征,并将其与 torch.cat((x,y),1)
(1 表示维度)连接起来。
关于python - pytorch中当输入参数超过两个时如何使用forward()方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60463821/