python - 我的 PyTorch 转发函数可以执行其他操作吗?

标签 python machine-learning pytorch autograd

通常,forward 函数将一堆层串在一起并返回最后一层的输出。在返回之前,我可以在最后一层之后进行一些额外的处理吗?例如,通过 .view 进行一些标量乘法和 reshape ?

我知道 autograd 会以某种方式计算出渐变。所以我不知道我的额外处理是否会以某种方式搞砸。谢谢。

最佳答案

通过 computational graph 跟踪梯度张量,而不是通过函数。只要你的张量有requires_grad=True属性(property)及其grad不是 None 你可以(几乎)做任何你喜欢的事情,并且仍然能够反向传播。
只要您使用 pytorch 的操作(例如 herehere 中列出的操作)就应该没问题。

有关更多信息,请参阅 this .

例如(取自 torchvision's VGG implementation ):

class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        #  ...

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # <-- what you were asking about
        x = self.classifier(x)
        return x

更复杂的示例可以在 torchvision's implementation of ResNet 中看到。 :

class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        # ...

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:    # <-- conditional execution!
            identity = self.downsample(x)

        out += identity  # <-- inplace operations
        out = self.relu(out)

        return out

关于python - 我的 PyTorch 转发函数可以执行其他操作吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60523638/

相关文章:

python - 如何在 Django 中下载 CSV 文件?

python - 获取 CSV 的长度以显示进度

python - 如何将 PyTorch 张量的每一行中的重复值清零?

python - xgboost.cv 给出 TypeError : 'StratifiedKFold' object is not iterable

python - Pytorch 的数据加载器 shuffle 何时发生?

pytorch - 激活梯度惩罚

python - 比较python中的大量字典列表

python - 如何将 html 表转换为 pandas 数据框

r - r 中的最佳簇数

image-processing - convert_imageset.cpp 指南