python - 如何在 PyTorch 中计算批量样本协方差?

标签 python pytorch

假设我有 data,它是大小为 (B, N, D) 的数据点集合的批量张量,其中 B 是我的批量大小,N 是每个集合中的数据样本数量,D 是我的数据向量的长度。我想计算每个数据点集合的样本均值和协方差,但要批量执行。

要计算平均值,我可以执行data.mean(dim=1),并得到一个大小为(B, D)的张量,表示每个集合的平均值。我以为我可以用 torch.cov 做类似的事情但它不提供批量执行此操作的能力。还有其他方法可以实现这一目标吗?我期望获得一批形状为 (B, D, D) 的协方差矩阵。

最佳答案

这可以解决问题:

def batch_cov(points):
    B, N, D = points.size()
    mean = points.mean(dim=1).unsqueeze(1)
    diffs = (points - mean).reshape(B * N, D)
    prods = torch.bmm(diffs.unsqueeze(2), diffs.unsqueeze(1)).reshape(B, N, D, D)
    bcov = prods.sum(dim=1) / (N - 1)  # Unbiased estimate
    return bcov  # (B, D, D)

这是一个脚本,用于测试它计算的内容是否与非批处理 PyTorch 版本计算的内容相同:

import time
import torch

B = 10000
N = 50
D = 2
points = torch.randn(B, N, D)
start = time.time()
my_covs = batch_cov(points)
print("My time:   ", time.time() - start)

start = time.time()
torch_covs = torch.zeros_like(my_covs)
for i, batch in enumerate(points):
    torch_covs[i] = batch.T.cov()

print("Torch time:", time.time() - start)
print("Same?", torch.allclose(my_covs, torch_covs, atol=1e-7))

这给了我:

My time:    0.00251793861318916016
Torch time: 0.2459864616394043
Same? True

我不能声称我的速度本质上会比迭代计算它们更快,似乎随着D变得更大,我的速度会慢得多,所以可能有一种更好的方法来随着数据维度大小进行扩展.

关于python - 如何在 PyTorch 中计算批量样本协方差?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71357619/

相关文章:

python - 在 Python 中动态覆盖 __functions__

pytorch - 错误 : Some NCCL operations have failed or timed out

python - 为什么这个 tensorflow 训练需要这么长时间?

machine-learning - 如何在PyTorch中实现学习率的随机对数空间搜索?

python - 在 Windows 10 上 pip 安装 torchvision 时出错

python - 我如何通过pytorch加载预训练模型? (时尚)

python - 如何在 Pandas 数据框中转换时间 2020-09-25T00 :20:00. 000Z

python - def next() 适用于 Python pre-2.6? (而不是 object.next 方法)

python - 在Windows 98中通过网络访问py2exe程序会引发ImportErrors

python - 使用 Python 播放音频