python - 有没有办法检索在随机 Torchvision 变换中使用的特定参数?

标签 python pytorch affinetransform data-augmentation torchvision

我可以在训练期间通过应用随机变换(旋转/平移/重新缩放)来增加我的数据,但我不知道选择的值。
我需要知道应用了哪些值。我可以手动设置这些值,但是我失去了 Torch Vision 转换提供的很多好处。
是否有一种简单的方法可以让这些值以一种明智的方式在培训期间应用?
这是一个例子。我希望能够打印出旋转角度,在每个图像上应用平移/重新缩放:

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms


RandAffine = transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2))

rotate = transforms.RandomRotation(degrees=45)
shift = RandAffine
composed = transforms.Compose([rotate,
                               shift])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = np.zeros((28,28))
sample[5:15,7:20] = 255
sample = transforms.ToPILImage()(sample.astype(np.uint8))
title = ['None', 'Rot','Aff','Comp']
for i, tsfrm in enumerate([None,rotate, shift, composed]):
    if tsfrm:
        t_sample = tsfrm(sample)
    else:
        t_sample = sample
    ax = plt.subplot(1, 5, i + 2)
    plt.tight_layout()
    ax.set_title(title[i])
    ax.imshow(np.reshape(np.array(list(t_sample.getdata())), (-1,28)), cmap='gray')    

plt.show()

最佳答案

恐怕没有简单的方法可以解决这个问题:Torchvision 的随机变换实用程序的构建方式是在调用时对变换参数进行采样。它们是唯一的随机变换,因为 (1) 用户无法访问使用的参数和 (2) 相同的随机变换是 不是 可重复。
从 Torchvision 0.8.0 开始,随机变换通常使用两个主要功能构建:

  • get_params :这将基于变换的超参数进行采样(您在初始化变换运算符时提供的内容,即参数的值范围)
  • forward :应用转换时执行的函数。重要的部分是它从 get_params 获取参数。然后使用关联的确定性函数将其应用于输入。对于 RandomRotation , F.rotate 会被调用。同样, RandomAffine 将使用 F.affine .

  • 您的问题的一种解决方案是从 get_params 中采样参数。自己并调用功能性 - 确定性 - API。所以你不会使用 RandomRotation , RandomAffine ,也没有任何其他 Random*转型。

    例如,让我们看看 T.RandomRotation (为了简洁起见,我删除了评论)。
    class RandomRotation(torch.nn.Module):
        def __init__(
            self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, 
            center=None, fill=None, resample=None):
            # ...
    
        @staticmethod
        def get_params(degrees: List[float]) -> float:
            angle = float(torch.empty(1).uniform_(float(degrees[0]), \
                float(degrees[1])).item())
            return angle
    
        def forward(self, img):
            fill = self.fill
            if isinstance(img, Tensor):
                if isinstance(fill, (int, float)):
                    fill = [float(fill)] * F._get_image_num_channels(img)
                else:
                    fill = [float(f) for f in fill]
            angle = self.get_params(self.degrees)
    
            return F.rotate(img, angle, self.resample, self.expand, self.center, fill)
    
        def __repr__(self):
            # ...
    
    考虑到这一点,这里有一个可能的覆盖来修改 T.RandomRotation :
    class RandomRotation(T.RandomRotation):
        def __init__(*args, **kwargs):
            super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work
    
            self.angle = self.get_params(self.degrees) # initialize your random parameters
    
        def forward(self): # override T.RandomRotation's forward
            fill = self.fill
            if isinstance(img, Tensor):
                if isinstance(fill, (int, float)):
                    fill = [float(fill)] * F._get_image_num_channels(img)
                else:
                    fill = [float(f) for f in fill]
    
            return F.rotate(img, self.angle, self.resample, self.expand, self.center, fill)
    
    我基本上复制了 T.RandomRotationforward函数,唯一不同的是参数在__init__中采样(即一次)而不是在 forward 内(即每次通话)。 Torchvision 的实现涵盖了所有情况,您通常不需要复制完整的 forward .在某些情况下,您可以直接调用功能版本。例如,如果您不需要设置 fill参数,您可以丢弃该部分并仅使用:
    class RandomRotation(T.RandomRotation):
        def __init__(*args, **kwargs):
            super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work
    
            self.angle = self.get_params(self.degrees) # initialize your random parameters
    
        def forward(self): # override T.RandomRotation's forward
            return F.rotate(img, self.angle, self.resample, self.expand, self.center)
    

    如果您想覆盖其他随机变换,您可以查看 the source code . API 是不言自明的,您应该不会在为每个转换实现覆盖时遇到太多问题。

    关于python - 有没有办法检索在随机 Torchvision 变换中使用的特定参数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65906171/

    相关文章:

    python - 将数组映射到更高维度 - Python

    java - Google Wave 沙盒

    python - 如何从图像列表 Pytorch 加载数据集

    python-3.x - 类型错误 : backward() got an unexpected keyword argument 'grad_tensors' in pytorch

    OpenCV 从一个三角形变形到另一个三角形

    python - 当参数是 Pandas DataFrame 等包对象时的 Docstrings Python 函数

    python - 检查窗口是否在后台 Tkinter

    c++ - 将找到的变换矩阵推广到整个图像

    python - _C.cpython-38-x86_64-linux-gnu.so : undefined symbol: _ZN6caffe28TypeMeta21_typeMetaDataInstanceIdEEPKNS_6detail12TypeMetaDataEv

    c++ - OpenCV:estimateAffine3D 抛出一个难以破译的异常