我想要一个 Module
的 PyTorch 子类,它将子模块保存在列表中(因为根据构造函数的参数,子模块的数量可能是可变的)。我通过以下方式设置此列表:
self.hidden_layers = [torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)]
根据this和 this问题是,当将 Module
对象分配给 self
的属性时,子模块仅由 __setattr__
注册。由于 hidden_layers
未分配 Module
类型的对象,因此列表中的子模块不会注册为子模块,因此 self.parameters()
不会迭代子模块的参数。
我想我可以为列表中的每个元素显式调用__subattr__
,但这会非常难看。是否有更正确的方法来注册不是 Module
直接属性的子模块?
最佳答案
使用nn.ModuleList
。
self.hidden_layers = nn.ModuleList([torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)])
关于python - 我怎样才能拥有 PyTorch 模块的子模块而不是模块的属性,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63681985/