对于我的用例,我需要能够采用 pytorch 模块并解释模块中的层顺序,以便我可以以某种文件格式在层之间创建“连接”。现在假设我有一个简单的模块,如下所示
class mymodel(nn.Module):
def __init__(self, input_channels):
super(mymodel, self).__init__()
self.fc = nn.Linear(input_channels, input_channels)
def forward(self, x):
out = self.fc(x)
out += x
return out
if __name__ == "__main__":
net = mymodel(5)
for mod in net.modules():
print(mod)
这里的输出结果:
mymodel(
(fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)
如您所见,有关加等于运算或加运算的信息未被捕获,因为它不是前向函数中的 nnmodule。我的目标是能够从 pytorch 模块对象创建图形连接,以在 json 中表达如下内容:
layers {
"fc": {
"inputTensor" : "t0",
"outputTensor": "t1"
}
"addOp" : {
"inputTensor" : "t1",
"outputTensor" : "t2"
}
}
输入张量名称是任意的,但它捕获了图的本质以及层之间的连接。
我的问题是,有没有办法从 pytorch 对象中提取信息?我本来想使用 .modules() 但后来意识到手写操作不能以这种方式捕获为模块。我想如果一切都是 nn.module 那么 .modules() 可能会给我网络层安排。在这里寻求一些帮助。我希望能够知道张量之间的连接以创建上述格式。
最佳答案
您要查找的信息并不存储在nn.Module
中,而是存储在输出张量的grad_fn
属性中:
model = mymodel(channels)
pred = model(torch.rand((1, channels))
pred.grad_fn # all the information is in the computation graph of the output tensor
提取此信息并非易事。您可能想查看torchviz从 grad_fn
信息绘制漂亮图表的软件包。
关于python - 推导 pytorch 网络的结构,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58253003/