假设我有 data
,它是大小为 (B, N, D)
的数据点集合的批量张量,其中 B
是我的批量大小,N
是每个集合中的数据样本数量,D
是我的数据向量的长度。我想计算每个数据点集合的样本均值和协方差,但要批量执行。
要计算平均值,我可以执行data.mean(dim=1)
,并得到一个大小为(B, D)
的张量,表示每个集合的平均值。我以为我可以用 torch.cov
做类似的事情但它不提供批量执行此操作的能力。还有其他方法可以实现这一目标吗?我期望获得一批形状为 (B, D, D)
的协方差矩阵。
最佳答案
这可以解决问题:
def batch_cov(points):
B, N, D = points.size()
mean = points.mean(dim=1).unsqueeze(1)
diffs = (points - mean).reshape(B * N, D)
prods = torch.bmm(diffs.unsqueeze(2), diffs.unsqueeze(1)).reshape(B, N, D, D)
bcov = prods.sum(dim=1) / (N - 1) # Unbiased estimate
return bcov # (B, D, D)
这是一个脚本,用于测试它计算的内容是否与非批处理 PyTorch 版本计算的内容相同:
import time
import torch
B = 10000
N = 50
D = 2
points = torch.randn(B, N, D)
start = time.time()
my_covs = batch_cov(points)
print("My time: ", time.time() - start)
start = time.time()
torch_covs = torch.zeros_like(my_covs)
for i, batch in enumerate(points):
torch_covs[i] = batch.T.cov()
print("Torch time:", time.time() - start)
print("Same?", torch.allclose(my_covs, torch_covs, atol=1e-7))
这给了我:
My time: 0.00251793861318916016
Torch time: 0.2459864616394043
Same? True
我不能声称我的速度本质上会比迭代计算它们更快,似乎随着D
变得更大,我的速度会慢得多,所以可能有一种更好的方法来随着数据维度大小进行扩展.
关于python - 如何在 PyTorch 中计算批量样本协方差?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/71357619/