python - Pytorch:了解 nn.Module 类如何在内部工作

标签 python deep-learning pytorch

一般情况下,nn.Module可以由子类继承,如下所示。

def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)  # 

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)
        self.apply(init_weights)

    def forward(self, x):
        x = self.fc1(x)
        return x

我的第一个问题是,为什么我甚至可以简单地运行下面的代码 __init__ training_signals 没有任何定位参数它看起来像 training_signals传递给 forward()方法。它是如何工作的?
model = LinearRegression()
training_signals = torch.rand(1000,20)
model(training_signals)

第二个问题是self.apply(init_weights)内部工作?是否在调用 forward 之前执行方法?

最佳答案

Q1: Why I can simply run the code below even my __init__ doesn't have any positional arguments for training_signals and it looks like that training_signals is passed to forward() method. How does it work?



一、__init__运行此行时调用:

model = LinearRegression()

如您所见,您不传递任何参数,也不应该传递。您的签名__init__与基类之一相同(运行时调用 super(LinearRegression, self).__init__() )。如您所见here , nn.Module的初始化签名只是 def __init__(self) (就像你的一样)。

二、model现在是一个对象。当您运行以下行时:

model(training_signals)

您实际上是在调用 __call__方法和传递training_signals作为位置参数。如您所见here ,除此之外,__call__方法调用 forward方法:

result = self.forward(*input, **kwargs)

传递 __call__ 的所有参数(位置和命名)到forward .

Q2: How does self.apply(init_weights) internally work? Is it executed before calling forward method?



PyTorch 是开源的,所以你可以简单地去源代码并检查它。如您所见here ,实现非常简单:

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

引用该函数的文档:它“将 fn 递归地应用于每个子模块(由 .children() 返回)以及 self ”。基于实现,还可以理解需求:
  • fn必须是可调用的;
  • fn仅接收 Module 作为输入目的;
  • 关于python - Pytorch:了解 nn.Module 类如何在内部工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58795601/

    相关文章:

    python - 什么是适用于 App Engine 上 python 的良好 SOAP 客户端库?

    python - 计算日期列表中每月错误 bool 值的数量?

    python - DataFrame 创建 - 重新索引

    python - 如何使用 tf.train.Checkpoint 保存大量变量

    tensorflow - 为什么要在 Keras 上使用纯 TensorFlow?

    python - 无法使用 2 层多层感知器 (MLP) 学习 XOR 表示

    python - 生成数字表

    tensorflow - 第一个训练时期很慢

    python - Torch 数据集循环太远

    theano - pytorch 中 theano.tensor.clip 的等价物是什么?