pytorch - 两个 torch.distribution.Distribution 对象的 KL 散度

标签 pytorch pytorch-distributions

我正在尝试确定如何计算两个 torch.distribution.Distribution 对象的 KL 散度。到目前为止,我找不到执行此操作的功能。这是我尝试过的:

import torch as t
from torch import distributions as tdist
import torch.nn.functional as F

def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)  

a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)

print(kl_divergence(a, b))  # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal

最佳答案

torch.nn.functional.kl_div 正在计算 KL-divergence 损失。可以使用 torch.distributions.kl.kl_divergence 计算两个分布之间的 KL 散度.

关于pytorch - 两个 torch.distribution.Distribution 对象的 KL 散度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72726304/

相关文章:

python - BERT 使用拥抱面部模型的分数将 'SpanAnnotation' 转换为答案

python - 哪种模型/技术用于特定句子提取?

python - 如何将 HuggingFace 的 Seq2seq 模型转换为 onnx 格式

python - 为什么我的完全卷积自动编码器不对称?

machine-learning - 对CNN中跳跃层的实现感到困惑