python - Pytorch 中的图像翻译,使用 affine_grid 和 grid_sample 函数

标签 python image pytorch image-rotation affinetransform

我将把图像移动 1 或 2 个像素,因为我在仿射矩阵中指定了一个小数字 (1.25 , 1.9)。

但是,图像被移得很远,就像数百个像素:

enter image description here

(我的输入图像完全充满了黄色菠萝)

下面是一个工作示例。

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torch.nn.functional as F

rotation_simple = np.array([[1,0, 1.25],
                           [ 0,1, 1.9]])

#load image
transform = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor()])
dataloader = torch.utils.data.DataLoader(datasets.ImageFolder('/home/Pictures',transform=transform,), shuffle=True)
dtype =  torch.FloatTensor


i = 0
while i<3:
    img, labels = next(iter(dataloader))
    img = img#.double() # 有时候要转为double有时候不用转

    rotation_simple = torch.as_tensor(rotation_simple)[None]

    grid = F.affine_grid(rotation_simple, img.size()).type(dtype)
    x = F.grid_sample(img, grid)

    plt.imshow(x[0].permute(1, 2, 0))
    plt.show()
    i+=1

我想知道为什么该函数将图像移动这么远,而不是在 x 和 y 方向仅移动 1 个像素。

诗。设置“align_corners=True”对于这种情况没有帮助。

页。我的pytorch版本是1.4.0+cu100

最佳答案

网格和仿射变换的“度量单位”不是像素,而是标准化坐标:

grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. For example, values x = -1, y = -1 is the left-top pixel of input, and values x = 1, y = 1 is the right-bottom pixel of input.

因此,按[1.25, 1.9]进行平移实际上是对几乎整个图像尺寸进行平移。您需要将平移值除以 2*img.shape 才能获得逐像素平移。

请参阅文档 grid_sample了解更多信息。

关于python - Pytorch 中的图像翻译,使用 affine_grid 和 grid_sample 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66987451/

相关文章:

python - 如果我没有指定使用 CPU/GPU,我的脚本使用的是哪一个?

numpy - .nu​​mpy() 函数有什么作用?

python - 我收到错误 AttributeError : 'Updater' object has no attribute 'dispatcher'

linux - 创建硬盘镜像后可以再次使用

jquery - 预加载图像并在加载时显示微调器

javascript - 如何检索Cloudinary photoid?

python - 在 PyTorch 中,grad_fn 属性究竟存储了什么以及它是如何使用的?

python - 将一个元素列表交换为 int

python - 从网络缓存数据的好方法(和/或独立于平台的地方)

python - MongoDB 的纯 Python 实现?