python - 我怎样才能拥有 PyTorch 模块的子模块而不是模块的属性

标签 python pytorch

我想要一个 Module 的 PyTorch 子类,它将子模块保存在列表中(因为根据构造函数的参数,子模块的数量可能是可变的)。我通过以下方式设置此列表:

self.hidden_layers = [torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)]

根据thisthis问题是,当将 Module 对象分配给 self 的属性时,子模块仅由 __setattr__ 注册。由于 hidden_​​layers 未分配 Module 类型的对象,因此列表中的子模块不会注册为子模块,因此 self.parameters() 不会迭代子模块的参数。

我想我可以为列表中的每个元素显式调用__subattr__,但这会非常难看。是否有更正确的方法来注册不是 Module 直接属性的子模块?

最佳答案

使用nn.ModuleList

self.hidden_layers = nn.ModuleList([torch.nn.Linear(i, o) for i, o in pairwise(self.layer_sizes)])

关于python - 我怎样才能拥有 PyTorch 模块的子模块而不是模块的属性,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63681985/

相关文章:

python - 如何使用 Moviepy 和 Pygame 播放 mp4 电影

docker - 当我使用 ssh 或 exec/attach 连接到 Docker 容器时,为什么 PATH 不同?

cmake - 链接静态库 pytorch 在构建过程中找不到其内部函数

python - 如何在Python中检查视频是否有声音?

c++ - lxml._ElementTree.getpath(element) 返回 "*"而不是非默认 namespace 中元素的标签名称

yaml - 在 conda YAML 文件中为 pytorch 指定 cpu-only

pytorch - 在 Pytorch 中裁剪一小批图像——每个图像都不同

numpy - 是否有一个 torch 函数可以导出两个张量的并集?

iterator - 使用 python 连接上一句和下一句

python - pandas(水平)堆叠条形,每个条形段排序