python - 查找图像 channel PyTorch 的平均值和标准差

标签 python deep-learning pytorch mean standard-deviation

假设我有一批尺寸为 (B x C x W x H) 张量形式的图像,其中 B 是批量大小,C 是图像中的 channel 数,W 和 H 是宽度和图像的高度。我希望使用 transforms.Normalize() 函数根据 C 图像 channel 中数据集的平均值和标准差对图像进行标准化,这意味着我想要一个 1 x C 形式的结果张量。有没有一种简单的方法可以做到这一点?

我尝试了 torch.view(C, -1).mean(1)torch.view(C, -1).std(1) 但我得到错误:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

编辑

在研究了 view() 在 PyTorch 中的工作原理之后,我知道为什么我的方法不起作用;但是,我仍然不知道如何获得每个 channel 的平均值和标准差。

最佳答案

请注意,相加的是方差,而不是标准差。请参阅此处的详细说明:https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters

修改后的代码如下:

nimages = 0
mean = 0.0
var = 0.0
for i_batch, batch_target in enumerate(trainloader):
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    var += batch.var(2).sum(0)

mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)

关于python - 查找图像 channel PyTorch 的平均值和标准差,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60101240/

相关文章:

deep-learning - Keras 输入数组(x)不能有不同数量的样本吗?

python - Pytorch模型的超参数优化

python - 使用 Tensorflow/PyTorch 加速自定义函数的最小化

python - 修改数组不会影响其在外循环中的长度

python - 通过mlxtend包中的StackingCVClassifier将参数传递给底层分类器的fit方法

python - Keras 图像数据生成器 : how to use data augmentation with images paths

python - 如何在 Unet 架构 PyTorch 中处理奇数分辨率

python - 尝试/除了抓取 URL 末尾带有 3 个随机数字的网站

python - 当模式存在时匹配字符串(以模式开头除外)

r - 有没有办法在不同版本的 H2O 之间使用保存的模型?