在 VAE 教程中,两个正态分布的 kl-divergence 定义为:
而且在很多代码中,比如here , here和 here ,代码实现为:
KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
或
def latent_loss(z_mean, z_stddev):
mean_sq = z_mean * z_mean
stddev_sq = z_stddev * z_stddev
return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
它们有什么关系?为什么代码中没有任何“tr”或“.transpose()”?
最佳答案
您发布的代码中的表达式假定 X 是一个不相关 多变量高斯随机变量。这在协方差矩阵的行列式中缺少交叉项是显而易见的。因此均值向量和协方差矩阵的形式为
使用它我们可以快速推导出以下原始表达式组件的等价表示
将这些代回原始表达式得到
关于pytorch - pytorch 代码中的 KL 散度与公式有何关系?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61597340/