PyTorch:如何将 torchvision.transforms.AugMix 与 torch.float32 张量一起使用?
我尝试使用 torchvision.transforms.AugMix 对图像数据集进行数据增强,但出现以下错误:TypeError:仅支持 torch.uint8 图像张量,但发现的是 torch.float32。我尝试把张量转换为整数类型,结果又出现了另一个错误。
我使用 AugMix 的代码如下:
# Preprocessing/augmentation pipeline applied to every patch.
#
# Order matters: AugMix only accepts PIL images or torch.uint8 tensors, while
# ToTensor() converts to float32 in [0, 1] and Normalize() then shifts values
# to arbitrary floats. Running AugMix after either of them raises
# "TypeError: Only torch.uint8 image tensors are supported, but found
# torch.float32". So AugMix runs on the (resized) PIL image first.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),  # resize to 224x224 (PIL image in, PIL image out)
        torchvision.transforms.AugMix(),  # must see uint8/PIL data, i.e. before ToTensor()
        torchvision.transforms.ToTensor(),  # PIL -> float32 tensor in [0, 1], shape [C, H, W]
        torchvision.transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        ),  # per-channel normalization
    ]
)
# Remove PIL's decompression-bomb size limit so very large source images can
# be opened. NOTE(review): this disables a safety check; only do this when the
# image files are trusted.
Image.MAX_IMAGE_PIXELS = None
# Stand-alone PIL -> tensor converter (not referenced elsewhere in this snippet).
to_tensor = torchvision.transforms.ToTensor()
class BreastDataset(torch.utils.data.Dataset):
    """Dataset of patient "bags": each item is a stack of image patches plus metadata.

    The JSON file at ``json_path`` is loaded as a list of records. Each record
    is read for the keys ``"label"`` (converted with ``int``), ``"id"`` and
    ``"patch_paths"`` (paths relative to ``data_dir_path``).

    Args:
        json_path: Path to the JSON index file.
        data_dir_path: Directory that patch paths are resolved against.
        clinical_data_path: Currently unused; kept for interface compatibility.
        is_preloading: If True, every bag is loaded and transformed once in
            ``__init__`` and cached in ``self.bag_tensor_list``; otherwise each
            bag is loaded from disk on every ``__getitem__`` call.
    """

    def __init__(self, json_path, data_dir_path='./dataset', clinical_data_path=None, is_preloading=True):
        self.data_dir_path = data_dir_path
        self.is_preloading = is_preloading
        # JSON text is UTF-8 by specification; don't depend on the platform locale.
        with open(json_path, encoding="utf-8") as f:
            print(f"load data from {json_path}")
            self.json_data = json.load(f)
        if self.is_preloading:
            # __getitem__ reads self.bag_tensor_list when preloading, but it was
            # never built, so the default configuration raised AttributeError.
            # NOTE(review): preloading applies the (random) module-level
            # `transform` only once per patch, so augmentations are frozen for
            # the whole run; prefer is_preloading=False when training with AugMix.
            self.bag_tensor_list = [
                self.load_bag_tensor(self._full_patch_paths(record["patch_paths"]))
                for record in self.json_data
            ]

    def __len__(self):
        """Return the number of bags (records in the JSON index)."""
        return len(self.json_data)

    def __getitem__(self, index):
        """Return a dict with keys bag_tensor, label, patient_id and patch_paths."""
        record = self.json_data[index]
        label = int(record["label"])
        patient_id = record["id"]
        patch_paths = record["patch_paths"]
        if self.is_preloading:
            bag_tensor = self.bag_tensor_list[index]
        else:
            bag_tensor = self.load_bag_tensor(self._full_patch_paths(patch_paths))
        return {
            "bag_tensor": bag_tensor,
            "label": label,
            "patient_id": patient_id,
            "patch_paths": patch_paths,  # the relative paths as stored in the JSON
        }

    def _full_patch_paths(self, patch_paths):
        """Resolve patch paths (relative to ``data_dir_path``) to full paths."""
        return [os.path.join(self.data_dir_path, p_path) for p_path in patch_paths]

    def load_bag_tensor(self, patch_paths):
        """Load a bag of patches as a single tensor of shape [N, C, H, W].

        Each image is opened, converted to RGB and passed through the
        module-level ``transform`` (assumed to return a [C, H, W] tensor of a
        fixed size — TODO confirm against the pipeline). An empty
        ``patch_paths`` makes torch raise, since there is nothing to stack.
        """
        patch_tensor_list = []
        for p_path in patch_paths:
            # Context manager releases the file handle deterministically.
            with Image.open(p_path) as img:
                patch = img.convert("RGB")
            patch_tensor_list.append(transform(patch))  # [C, H, W]
        return torch.stack(patch_tensor_list, dim=0)  # [N, C, H, W]
感谢任何帮助!预先感谢您!
最佳答案
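对我来说,先应用 AugMix、再使用 ToTensor() 就能正常工作。原因是 AugMix 只接受 PIL 图像或 torch.uint8 张量,而 ToTensor() 会把图像转换为取值在 [0, 1] 的 torch.float32 张量,因此 AugMix 必须放在 ToTensor()(以及 Normalize)之前: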
# Answer's pipeline: AugMix runs first, while the input is still a PIL image
# (AugMix rejects float tensors). ToTensor() then produces a float tensor,
# which the tensor-based RandomErasing / RandomGrayscale steps consume.
transformation = transforms.Compose(
    transforms=[
        transforms.AugMix(severity=6, mixture_width=2),  # non-default strength/width
        transforms.ToTensor(),
        transforms.RandomErasing(),
        transforms.RandomGrayscale(p=0.35),
    ]
)
关于“PyTorch:如何将 torchvision.transforms.AugMix 与 torch.float32 一起使用?”,我们在 Stack Overflow 上找到一个类似的问题: https://stackoverflow.com/questions/73754867/