python - 如何在pytorch对象检测中添加转换

标签 python pytorch object-detection

我是 PyTorch 新手,正在阅读 PyTorch 对象检测文档教程 pytorch docx . 在他们的协作版本中,我进行了以下更改以添加一些转换技术。

  1. 首先修改类PennFudanDataset(torch.utils.data.Dataset)的__getitem__方法
if self.transforms is not None:
   img = self.transforms(img)     
   target = T.ToTensor()(target)
   return img, target

In actual documentation it is 
if self.transforms is not None:
   img, target = self.transforms(img, target)  

其次,在 get_transform(train) 函数处。

def get_transform(train):
  if train:
    transformed = T.Compose([             
           T.ToTensor(),
           T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
          T.ColorJitter(brightness=[0.1, 0.2], contrast=[0.1, 0.2], saturation=[0, 0.2], hue=[0,0.5])
    ])
    return transformed

  else:
    return T.ToTensor()

**In the documentation it is-** 
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

在执行代码时,出现以下错误。我无法理解我做错了什么。

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataset.py", line 272, in __getitem__
    return self.dataset[self.indices[idx]]
  File "<ipython-input-41-94e93ff7a132>", line 72, in __getitem__
    target = T.ToTensor()(target)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 104, in __call__
    return F.to_tensor(pic)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 64, in to_tensor
    raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
TypeError: pic should be PIL Image or ndarray. Got <class 'dict'>

最佳答案

我相信 Pytorch 转换仅适用于图像(在本例中为 PIL 图像或 np 数组),而不适用于标签(根据跟踪是字典)。因此,我认为您不需要像 __getitem__ 函数中的这一行 target = T.ToTensor()(target) 那样“张紧”标签。

关于python - 如何在pytorch对象检测中添加转换,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64905441/

相关文章:

python - io.StringIO 与 Python 3 中的 open()

Google Colab上的PyTorch Geometric CUDA安装问题

python - PyTorch autograd——只能为标量输出隐式创建 grad

opencv - opencv hog.cpp中的 Gamma 校正

python - 文件中的通用字符串替换

python - 在 Python 中,如何避免在从其 __new__ : 中具有 super() 的类派生的类中调用 __init__ 两次

python - 尝试连接到 Azure Sql 数据库时,用户 'user' 登录失败

python - Huggingface Trainer 抛出 AttributeError :'Namespace' 对象没有属性“get_process_log_level”

tensorflow - 重新训练 Tensorflow 对象检测 API

object-detection - 将 .csv 文件转换为 yolo darknet 格式