pytorch - pytorch 代码中的 KL 散度与公式有何关系?

标签 pytorch autoencoder loss-function

在 VAE 教程中,两个正态分布的 kl-divergence 定义为: enter image description here

而且在很多代码中,比如here , herehere ,代码实现为:

 KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

def latent_loss(z_mean, z_stddev):
    mean_sq = z_mean * z_mean
    stddev_sq = z_stddev * z_stddev
    return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)

它们有什么关系?为什么代码中没有任何“tr”或“.transpose()”?

最佳答案

您发布的代码中的表达式假定 X 是一个不相关 多变量高斯随机变量。这在协方差矩阵的行列式中缺少交叉项是显而易见的。因此均值向量和协方差矩阵的形式为

enter image description here

使用它我们可以快速推导出以下原始表达式组件的等价表示

enter image description here

将这些代回原始表达式得到

enter image description here

关于pytorch - pytorch 代码中的 KL 散度与公式有何关系?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61597340/

相关文章:

python - 在 Python 中实现梯度下降并收到溢出错误

使用 mini-batch 时累积的 pytorch 损失

python - 根据纪元更改多输出损失权重

deep-learning - 如何通过一次操作合并两个 torch.utils.data 数据加载器

python - Keras 自动编码器

python - 梯度计算所需的变量之一已通过就地操作修改

machine-learning - mlpack稀疏编码解决方案未找到

python - 变分自动编码器 : InvalidArgumentError: Incompatible shapes: [100, 5] 与 [100]

pytorch - ONNX 和 pytorch 的输出不同

pytorch - 激活梯度惩罚