python - 推导 pytorch 网络的结构

标签 python neural-network pytorch tensor

对于我的用例,我需要能够采用 pytorch 模块并解释模块中的层顺序,以便我可以以某种文件格式在层之间创建“连接”。现在假设我有一个简单的模块,如下所示

class mymodel(nn.Module):
    def __init__(self, input_channels):
        super(mymodel, self).__init__()
        self.fc = nn.Linear(input_channels, input_channels)
    def forward(self, x):
        out = self.fc(x)
        out += x
        return out


if __name__ == "__main__":
    net = mymodel(5)

    for mod in net.modules():
        print(mod) 

这里的输出结果:

mymodel(
  (fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)

如您所见,有关加等于运算或加运算的信息未被捕获,因为它不是前向函数中的 nnmodule。我的目标是能够从 pytorch 模块对象创建图形连接,以在 json 中表达如下内容:

layers {
"fc": {
"inputTensor" : "t0",
"outputTensor": "t1"
}
"addOp" : {
"inputTensor" : "t1",
"outputTensor" : "t2"
}
}

输入张量名称是任意的,但它捕获了图的本质以及层之间的连接。

我的问题是,有没有办法从 pytorch 对象中提取信息?我本来想使用 .modules() 但后来意识到手写操作不能以这种方式捕获为模块。我想如果一切都是 nn.module 那么 .modules() 可能会给我网络层安排。在这里寻求一些帮助。我希望能够知道张量之间的连接以创建上述格式。

最佳答案

您要查找的信息并不存储在nn.Module中,而是存储在输出张量的grad_fn属性中:

model = mymodel(channels)
pred = model(torch.rand((1, channels))
pred.grad_fn  # all the information is in the computation graph of the output tensor

提取此信息并非易事。您可能想查看torchvizgrad_fn 信息绘制漂亮图表的软件包。

关于python - 推导 pytorch 网络的结构,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58253003/

相关文章:

python - PyQt 在(Qt?)更新后启动时崩溃

python - Keras:向 LSTM 网络添加 MDN 层

c++ - Libtorch C++ 和 pytorch 的不同输出

python - 使用 Pytorch 实现 FFT

python - 如何按照链接列表从 scrapy 的页面获取数据?

用于分类的 Python 向量化

python - 无法使用 ghostdriver 启动 phantomjs

python - 将自定义阈值与 tf.contrib.learn.DNNClassifier 一起使用?

java - Encog ImageNeuralNetwork 错误不会下降

python - torch.empty 如何计算这些值?