I am trying to figure out how to compute the KL divergence between two torch.distribution.Distribution
objects. So far I have not been able to find a function that does this. Here is what I tried:
import torch as t
from torch import distributions as tdist
import torch.nn.functional as F
def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)
a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)
print(kl_divergence(a, b)) # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal
Best answer

torch.nn.functional.kl_div computes the KL-divergence loss, which operates on tensors,
not Distribution objects. To compute the KL divergence between two distributions,
use torch.distributions.kl.kl_divergence instead.
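A minimal sketch of the fix, replacing the F.kl_div call with torch.distributions.kl.kl_divergence (which dispatches on the registered pair of distribution types):

```python
import torch as t
from torch import distributions as tdist

a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)

# kl_divergence accepts Distribution objects directly and returns a tensor.
kl = tdist.kl.kl_divergence(a, b)
print(kl)  # tensor(0.5000)
```

For two univariate normals this matches the closed form log(σ₂/σ₁) + (σ₁² + (μ₁ − μ₂)²)/(2σ₂²) − 1/2, which here evaluates to 0.5.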
Regarding "pytorch - KL divergence of two torch.distribution.Distribution objects", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/72726304/