python - 如何在 Unet 架构 PyTorch 中处理奇数分辨率

标签 python image-processing deep-learning pytorch hourglass

我正在 PyTorch 中实现基于 U-Net 的架构。在火车时间,我有大小 256x256 的补丁这不会造成任何问题。但是在测试时,我有全高清图像( 1920x1080 )。这会在跳过连接期间导致问题。
下采样 1920x1080 3次给240x135 .如果我再下采样一次,分辨率变为 120x68当上采样时给出 240x136 .现在,我无法连接这两个特征图。我该如何解决这个问题?
PS:我认为这是一个相当普遍的问题,但我没有得到任何解决方案,甚至在网络上的任何地方都没有提到这个问题。我错过了什么吗?

最佳答案

这是在解码过程中经常涉及跳跃连接的分割网络中非常常见的问题。网络通常(取决于实际架构)需要输入大小,其边长为最大步幅(8、16、32 等)的整数倍。
主要有两种方式:

  • 将输入调整为最接近的可行尺寸。
  • 将输入填充到下一个更大的可行尺寸。

  • 我更喜欢 (2) 因为 (1) 会导致所有像素的像素级别发生微小变化,从而导致不必要的模糊。请注意,我们通常需要在两种方法之后恢复原始形状。
    我最喜欢这个任务的代码片段(高度/宽度的对称填充):
    import torch
    import torch.nn.functional as F
    
    def pad_to(x, stride):
        h, w = x.shape[-2:]
    
        if h % stride > 0:
            new_h = h + stride - h % stride
        else:
            new_h = h
        if w % stride > 0:
            new_w = w + stride - w % stride
        else:
            new_w = w
        lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
        lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
        pads = (lw, uw, lh, uh)
    
        # zero-padding by default.
        # See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
        out = F.pad(x, pads, "constant", 0)
    
        return out, pads
    
    def unpad(x, pad):
        if pad[2]+pad[3] > 0:
            x = x[:,:,pad[2]:-pad[3],:]
        if pad[0]+pad[1] > 0:
            x = x[:,:,:,pad[0]:-pad[1]]
        return x
    
    一个测试片段:
    x = torch.zeros(4, 3, 1080, 1920) # Raw data
    x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network 
    x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape
    
    print('Original: ', x.shape)
    print('Padded: ', x_pad.shape)
    print('Recovered: ', x_unpad.shape)
    
    输出:
    Original:  torch.Size([4, 3, 1080, 1920])
    Padded:  torch.Size([4, 3, 1088, 1920])
    Recovered:  torch.Size([4, 3, 1080, 1920])
    
    引用:https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33

    关于python - 如何在 Unet 架构 PyTorch 中处理奇数分辨率,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66028743/

    相关文章:

    python - 按数据类型过滤/更新 Python 数据框

    python 正则表达式,检查变量的最后两个字符

    performance - MATLAB/ Octave : cut a lot of circles from a image

    java - 使用 Play 框架将生成的图像发送到浏览器

    tensorflow - Keras : How should I prepare input data for RNN?

    python - matplotlib:在条形图上绘制多列 Pandas 数据框

    python - 使用 NLP 进行地址拆分

    python - OpenCV 图像处理以在 Python 中裁剪图像的倾斜部分

    machine-learning - 如何在 Caffe 中使用与图像不同的数据?

    machine-learning - Caffe 错误 - 数据转换器检查失败 : datum_channels > 0 (0 vs. 0)