我正在为包含多对图像的数据集编写一个简单的转换。作为数据增强,我想对每一对应用一些随机变换,但该对中的图像应该以相同的方式进行变换。
例如,给定一对两个图像 A
和 B
, 如果 A
水平翻转,B
必须水平翻转为 A
.然后下一对C
和 D
应该与 A
不同地转换和 B
但是 C
和 D
以同样的方式转化。我正在以下面的方式尝试
import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
img_a = Image.open("sample_ajpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")
transform = transforms.RandomChoice(
[transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))
random.seed(1)
display(transform(img_c))
display(transform(img_d))
然而,上面的代码并没有选择相同的转换,正如我测试的那样,它取决于次数transform
叫做。有什么办法可以强制
transforms.RandomChoice
在指定时使用相同的转换?
最佳答案
通常的解决方法是对第一幅图像应用变换,检索该变换的参数,然后将这些参数应用到其余图像上的确定性变换。然而,这里 RandomChoice
不提供 API 来获取应用转换的参数,因为它涉及可变数量的转换。
在这些情况下,我通常会覆盖原始函数。
看着torchvision implementation ,就这么简单:
class RandomChoice(RandomTransforms):
def __call__(self, img):
t = random.choice(self.transforms)
return t(img)
这里有两种可能的解决方案。__init__
上的转换列表中进行采样。而不是在 __call__
:import random
import torchvision.transforms as T
class RandomChoice(torch.nn.Module):
def __init__(self):
super().__init__()
self.t = random.choice(self.transforms)
def __call__(self, img):
return self.t(img)
所以你可以这样做:transform = T.RandomChoice([
T.RandomHorizontalFlip(),
T.RandomVerticalFlip()
])
display(transform(img_a)) # both img_a and img_b will
display(transform(img_b)) # have the same transform
transform = T.RandomChoice([
T.RandomHorizontalFlip(),
T.RandomVerticalFlip()
])
display(transform(img_c)) # both img_c and img_d will
display(transform(img_d)) # have the same transform
import random
import torchvision.transforms as T
class RandomChoice(torch.nn.Module):
def __init__(self, transforms):
super().__init__()
self.transforms = transforms
def __call__(self, imgs):
t = random.choice(self.transforms)
return [t(img) for img in imgs]
这允许做:transform = T.RandomChoice([
T.RandomHorizontalFlip(),
T.RandomVerticalFlip()
])
img_at, img_bt = transform([img_a, img_b])
display(img_at) # both img_a and img_b will
display(img_bt) # have the same transform
img_ct, img_dt = transform([img_c, img_d])
display(img_ct) # both img_c and img_d will
display(img_dt) # have the same transform
关于python - PyTorch : How to apply the same random transformation to multiple image?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65447992/