algorithm - EM 算法代码不起作用

标签 algorithm matlab statistics

我正在尝试使用 EM 算法估计高斯混合的平均值、权重和协方差,但我得到的是所有向量值的“NaN”(不是数字)。这是我的代码:

function [nw,ns,nu]=est(wo,so,uo,w,s,u,o,J)
% o is the data points vector (size=nx2)
% J is the number of Gaussian curve
% maximum number of Gaussians=5
% wo is the old weight vector (size=1x5) and "w" is the actual weight vector (size=1x5)
% so is the old covariance vector (size=1x5) and "s" is the actual covariance value (size=1x5)
% uo is the old mean vector (size=5x2) and "u" is the actual mean vector (size=5x2)

o=unique(o,'rows');
t=length(o);

for i=1:5
for j=1:J
    un=0;ud=0;s2n=0;
    for T=1:t
        ud=ud+probab(wo,o(T,:),uo,so,j,J);
    end
    for T=1:t
        un=un+probab(wo,o(T,:),uo,so,j,J)*(o(T,:));
    end
    for T=1:t
        s2n=s2n+probab(wo,o(T,:),uo,so,j,J)*(o(T,:)-uo(j,:))*(((o(T,:)-uo(j,:))'));
    end
    wtemp(j)=ud/t;
    utemp(j,:)=un/ud;
    stemp(j)=sqrt(s2n/(ud));
end
wo=wtemp;
uo=utemp;
so=stemp;
end
    nw=wo;
    ns=so;
    nu=uo;

函数“probab”是:

function [pro]=probab(w,o,u,s,j,J)
pro=w(j).*(1/(2*pi*s(j))).*exp((-1/(2.*(s(j).^2)))*((o(1)-u(j,1)).^2+(o(2)-u(j,2)).^2))/(gaussiandistribution(w,u,s,o(1),o(2),J));

函数“高斯分布”是

function [z]=gaussiandistribution(w,u,s,X,Y,J)
z1=w(1).*(1/(2*pi*s(1))).*exp(-(((X-u(1,1)).^2)/(2.*s(1).^2)+((Y-u(1,2)).^2)/(2.*s(1).^2)));
z2=w(2).*(1/(2*pi*s(2))).*exp(-(((X-u(2,1)).^2)/(2.*s(2).^2)+((Y-u(2,2)).^2)/(2.*s(2).^2)));
z3=w(3).*(1/(2*pi*s(3))).*exp(-(((X-u(3,1)).^2)/(2.*s(3).^2)+((Y-u(3,2)).^2)/(2.*s(3).^2)));
z4=w(4).*(1/(2*pi*s(4))).*exp(-(((X-u(4,1)).^2)/(2.*s(4).^2)+((Y-u(4,2)).^2)/(2.*s(4).^2)));
z5=w(5).*(1/(2*pi*s(5))).*exp(-(((X-u(5,1)).^2)/(2.*s(5).^2)+((Y-u(5,2)).^2)/(2.*s(5).^2)));
if J==5
z=z1+z2+z3+z4+z5;
elseif J==4
z=z1+z2+z3+z4;
elseif J==3
z=z1+z2+z3;
elseif J==2
z=z1+z2;
elseif J==1
z=z1;
end

最佳答案

我已经运行了您的代码(在 Octave 中),但没有看到任何 NaN。以下是我的值(value)观:

octave:13> wo
wo =

   0.20000   0.20000   0.20000   0.20000   0.20000

octave:14> so
so =

   1   1   1   1   1

octave:15> uo
uo =

   0   0
   0   0
   0   0
   0   0
   0   0

octave:16> w
w =

   0.20000   0.20000   0.20000   0.20000   0.20000

octave:17> s
s =

   1   1   1   1   1

octave:18> u
u =

   0   0
   0   0
   0   0
   0   0
   0   0

octave:19> 
octave:19> o
o =

   1.00000   2.00000
   3.00000   1.00000
  -1.00000   2.00000
   2.00000  -2.00000
  -2.00000  -1.00000
   3.00000  -2.00000
   2.00000   3.00000
   1.50000  -0.25000

这是我得到的输出:

octave:20> [nw, ns, nu] = est(wo, so, uo, w, s, u, o, 5)
nw =

   0.20000   0.20000   0.20000   0.20000   0.20000

ns =

   2.4770   2.4770   2.4770   2.4770   2.4770

nu =

   1.18750   0.34375
   1.18750   0.34375
   1.18750   0.34375
   1.18750   0.34375
   1.18750   0.34375

我没有检查这些是否正确,但无论如何它们都不是 NaN。

您的输入值是什么?

关于algorithm - EM 算法代码不起作用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29979808/

相关文章:

algorithm - 从 Excel 导入中查找包含的边界区域

javascript - 我怎样才能将相机的位置合并到这个算法中来计算 Sprite 在屏幕上的位置?

performance - ppval的高效替代品

matlab - 如何实现面部 Action 编码系统(FACS)?

c++ - 在相机图像中的监视器上查找白色像素

java - 在散点图中查找位于 X 轴和 Y 轴附近的点

r - 如何让 R 以指数表示法读取一列数字?

algorithm - 遍历从 (0,0,0) 到 (5,5,5) 的所有路径

performance - MergeSort 中的两个递归调用是什么?

python - 通过 Python 但通过 Linux 命令终端调用 MATLAB