你好,欢迎访问远方教程PC版!
广告位招租
网站首页 >> 统计之窗 >> MATLAB专区 >> 文章内容

Matlab技巧04:EM算法分析与实现(基于高斯混合模型)(第3页)

[日期:2015-07-15]   来源:远方教程  作者:远方教程   阅读:15003次[字体: ] 访问[旧版]
 捐赠远方教程 



3.Matlab代码

   下面就是我基于matlab实现的EM算法,真心觉得理解原理不容易,在写代码的过程中也要很仔细,在实现的过程中还是参照了一位大牛的代码,理解思想,看了他的代码之后自己重新写的,思想多少有点受他影响,但是个人觉得代码还是有优化过。

%EM
M=3;          % M个高斯分布混合
N=600;        % 样本数
th=0.000001;  % 收敛阈值
K=2;          % 样本维数
% 待生成数据的参数
a_real =[2/3;1/6;1/6];%混合模型中基模型高斯密度函数的权重
mu_real=[3 4 6;5 3 7];%均值
cov_real(:,:,1)=[5 0;0 0.2];%协方差
cov_real(:,:,2)=[0.1 0;0 0.1];
cov_real(:,:,3)=[0.1 0;0 0.1];                    
%生成符合标准的样本数据(每一列为一个样本)
x=[ mvnrnd( mu_real(:,1) , cov_real(:,:,1) , round(N*a_real(1)) )' ,...
    mvnrnd( mu_real(:,2) , cov_real(:,:,2) , round(N*a_real(2)) )' ,...
    mvnrnd( mu_real(:,3) , cov_real(:,:,3) , round(N*a_real(3)) )' ];
%初始化参数
a=[1/3;1/3;1/3];
mu=[1 2 3;2 1 4];
cov(:,:,1)=[1 0;0 1];
cov(:,:,2)=[1 0;0 1];
cov(:,:,3)=[1 0;0 1];
t=inf;
while t>=th
    a_old  = a;
    mu_old = mu;
    cov_old= cov;     
    rznk_temp=zeros(M,N);
    for k=1:M
        for n=1:N
            %计算P(x|mu_cm,cov_cm)
            rznk_temp(k,n)=exp(-1/2*(x(:,n)-mu(:,k))'*inv(cov(:,:,k))*(x(:,n)-mu(:,k)));
        end
        rznk_temp(k,:)=rznk_temp(k,:)/sqrt(det(cov(:,:,k)));
    end
    rznk_temp=rznk_temp*(2*pi)^(-K/2);
%E step
    %求rznk
    rznk=zeros(M,N);
    for n=1:N
        for k=1:M
            rznk(k,n)=a(k)*rznk_temp(k,n);
        end
        rznk(:,n)=rznk(:,n)/sum(rznk(:,n));
    end
% M step
    %求Nk
    nk=zeros(1,M);
    nk=sum(rznk');
   
    % 求a
    a=nk/N;
       
    % 求MU
    for k=1:M
        mu_k_sum=0;
        for n=1:N
            mu_k_sum=mu_k_sum+rznk(k,n)*x(:,n);
        end
        mu(:,k)=mu_k_sum/nk(k);
    end
   
    % 求COV  
    for k=1:M
        cov_k_sum=0;
        for n=1:N
            cov_k_sum=cov_k_sum+rznk(k,n)*(x(:,n)-mu(:,k))*(x(:,n)-mu(:,k))';
        end
        cov(:,:,k)=cov_k_sum/nk(k);
    end
      
    t=max([norm(a_old(:)-a(:))/norm(a_old(:));norm(mu_old(:)-mu(:))/norm(mu_old(:));norm(cov_old(:)-cov(:))/norm(cov_old(:))]);
end

%输出结果并比较

a_real
a

mu_real
mu

cov_real
cov

%结果

a_real =

    0.6667
    0.1667
    0.1667


a =

    0.6657    0.1681    0.1662


mu_real =

     3     4     6
     5     3     7


mu =

    3.0366    3.9987    6.0406
    4.9941    2.9888    7.0190


cov_real(:,:,1) =

    5.0000         0
         0    0.2000


cov_real(:,:,2) =

    0.1000         0
         0    0.1000


cov_real(:,:,3) =

    0.1000         0
         0    0.1000


cov(:,:,1) =

    5.4894   -0.0389
   -0.0389    0.1939


cov(:,:,2) =

    0.0682    0.0038
    0.0038    0.0959


cov(:,:,3) =

    0.0866   -0.0033
   -0.0033    0.0761
   通过输出结果发现算法的准确性还是比较高的,算法迭代得到的值与实际值出入不是很大。

 

图片展示
 
相关评论
站长推荐