python - 计算 torch 张量数组的平均值和标准差

标签 python numpy deep-learning pytorch torch

我正在尝试计算 torch 张量数组的平均值和标准差。我的数据集有 720 张训练图像,每张图像都有 4 个地标,其中 X 和 Y 代表图像上的 2D 点。

to_tensor = transforms.ToTensor()

landmarks_arr = []

for i in range(len(train_dataset)):
    landmarks_arr.append(to_tensor(train_dataset[i]['landmarks']))
                     
mean = torch.mean(torch.stack(landmarks_arr, dim=0))#, dim=(0, 2, 3))
std = torch.std(torch.stack(landmarks_arr, dim=0)) #, dim=(0, 2, 3))



print(mean.shape)
print("mean is {} and std is {}".format(mean, std))

结果:

torch.Size([])
mean is nan and std is nan

上面有几个问题:

  1. 为什么 to_tensor 没有转换 0 到 1 之间的值?
  2. 如何正确计算平均值?
  3. 我应该除以 255 吗?

我有:

len(landmarks_arr)
    
720

landmarks_arr[0].shape

torch.Size([1, 4, 2])

landmarks_arr[0]

tensor([[[502.2869, 240.4949],
         [688.0000, 293.0000],
         [346.0000, 317.0000],
         [560.8283, 322.6830]]], dtype=torch.float64)

最佳答案

  1. 来自 ToTensor() 的 pytorch 文档:

Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8

In the other cases, tensors are returned without scaling.

由于您的 Landmark 值不是 PIL 图像,且不在 [0, 255] 范围内,因此不会应用缩放。

  • 您的计算看来是正确的。看来您的数据中可能有一些 NaN 值。
  • 你可以尝试类似的事情

    for i in range(len(train_dataset)):
        landmarks = to_tensor(train_dataset[i]['landmarks'])
        landmarks[landmarks != landmarks] = 0  # this will set all nan to zero
        landmarks_arr.append(landmarks)
    

    在你的循环中。或者在循环中断言 for nan 以找到罪魁祸首:

    for i in range(len(train_dataset)):
        landmarks = to_tensor(train_dataset[i]['landmarks'])
        assert(not torch.isnan(landmarks).any()), f'nan encountered in sample {i}'  # will trigger if a landmark contains nan
        landmarks_arr.append(landmarks)
    
  • 不,请参阅 1)。如果您愿意,您可以除以地标的最大坐标,将它们限制为 [0, 1]。
  • enter image description here

    关于python - 计算 torch 张量数组的平均值和标准差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64366121/

    相关文章:

    python - Cygwin 新手 : How do I uninstall Python 2. 6.x 来自 Cygwin 并安装 Python 2.7.x?

    python - 在内置函数中赋值

    python - 任意重复 numpy 数组的内容

    python - 推荐系统 - 基于 Softmax 的深度神经网络模型中的用户嵌入

    tensorflow - 用图像和多标签编写 tfrecords 进行分类

    python - 如何将图像数据从存储桶加载到AWS sagemaker笔记本?

    Python/Excel - IOError : [Errno 2] No such file or directory:

    python - 改变 python 描述符

    python - NumPy 的 : read data from CSV having numerals as string

    python - Numpy:根据索引列表访问多维数组的值