%This script learns the VB-MoG model using the VB-EM algorithm. Refer to
%the demonstration scripts named MoG_comparison_experiment_# for usage of
%this script. The parameters of the learned model are stored in variables
%mean, W, v, alpha and beta. The cost function values and corresponding
%CPU times are stored in the variable costValues_VB_EM.
%Copyright (C) 2008-2010 Mikael Kuusela.

%Print a banner identifying the algorithm in the console output
disp('-----------------');
disp('VB EM');

%Bookkeeping for the main loop: iteration counter, accumulated CPU time
%and the current cost function value (infinite until first evaluated)
iter = 0;
totTime = 0;
cost = Inf;

%Copy the externally supplied initial model into the working variables:
%first the number of components, then the distribution parameters
K = initK;
alpha = initAlpha;
beta = initBeta;
mean = initMean;
v = initV;
W = initW;

%Initial pi_eff: exp(psi(alpha_k) - psi(sum of all alpha))
pi_eff = exp(psi(alpha) - psi(sum(alpha)));

%Initial lambda_eff, one value per component:
%exp( sum_{i=1..D} psi((v_k + 1 - i)/2) + D*log(2) + log(det(W_k)) )
lambda_eff = zeros(1,K);
for k = 1:K
    %Accumulate the digamma terms one dimension at a time
    temp = 0;
    for i = 1:D
        temp = temp + psi((v(k) + 1 - i)/2);
    end

    lambda_eff(k) = exp(temp + D*log(2) + log(det(W(:,:,k))));
end

%Main VB-EM loop. Every pass first checks the termination criteria and
%then, if the loop continues, recomputes the responsibilities, prunes
%components with too little support, updates the distribution parameters,
%evaluates the cost function and records [accumulated CPU time, cost] in
%costValues_VB_EM(iter,:,run).
while 1
    iter = iter + 1;
    
    %{
    if plot_MoG == true
        plot_final_MoG(alpha, v, W, mean, beta, K, D, X, cost);
    end
    %}
    
    %Terminate once the cost has changed by less than epsilon on
    %requiredTermCount consecutive iterations. The first iteration always
    %takes the else branch, which is what initializes termCount.
    if iter ~= 1 && abs(prevCost - cost) < epsilon 
        termCount = termCount + 1;
        if termCount == requiredTermCount            
            break;
        end
    else
        termCount = 0;
    end

    t = cputime;
    
    %Calculate responsibilities.
    %The unnormalized values are formed in the log domain and shifted by
    %the row maximum before exponentiation. Evaluating
    %pi_eff*sqrt(lambda_eff)*exp(...) directly underflows to an all-zero
    %row for a data point far from every component; the normalization
    %then yields 0/0 = NaN, which makes the cost NaN and prevents the
    %termination test from ever succeeding.
    res = zeros(N,K);
    for n=1:N
        logRes = zeros(1,K);
        for k=1:K
            logRes(k) = log(pi_eff(k)) + 0.5*log(lambda_eff(k)) - D/(2*beta(k)) - v(k)/2*((X(:,n)-mean(:,k))'*W(:,:,k)*(X(:,n)-mean(:,k)));
        end
        %The extra -realmax entry only keeps max() a scalar when K is 0;
        %it never exceeds a finite log-responsibility.
        res(n,:) = exp(logRes - max([logRes, -realmax]));
        res(n,:) = res(n,:) / sum(res(n,:)); %Normalization
        
        %Floor tiny responsibilities to avoid domain errors due to limited
        %accuracy of floating point numbers, renormalizing if anything
        %was changed
        normalization = false;
        for k=1:K
            if res(n,k) < 1e-10
                res(n,k) = 1e-10;
                normalization = true;
            end
        end
        if normalization == true
            res(n,:) = res(n,:) / sum(res(n,:)); %Normalization
        end
    end

    %Calculate helper values: N_eff(k) is the sum of the responsibilities
    %of component k over all data points
    N_eff = sum(res,1);
    
    %Remove components for which N_eff < removalCriteria, then renormalize
    %the remaining responsibilities and recompute N_eff
    remove = find(N_eff < removalCriteria);
    removeN = size(remove,2);
    if removeN > 0
        K = K - removeN;
        
        res(:,remove) = [];
        
        for n=1:N
            res(n,:) = res(n,:) / sum(res(n,:)); %Normalization
        end
        
        N_eff = sum(res,1);
    end
    
    %Responsibility-weighted mean x_avg(:,k) and responsibility-weighted
    %scatter matrix S(:,:,k) of the data for each component
    x_avg = zeros(D,K);
    S = zeros(D,D,K);
    for k=1:K
        for n=1:N
            x_avg(:,k) = x_avg(:,k) + res(n,k)*X(:,n);
        end    
        x_avg(:,k) = x_avg(:,k) / N_eff(k);

        for n=1:N
            S(:,:,k) = S(:,:,k) + res(n,k)*(X(:,n) - x_avg(:,k))*(X(:,n) - x_avg(:,k))';
        end
        S(:,:,k) = S(:,:,k) / N_eff(k);
    end
    
    %Update distribution parameters for each Gaussian using the VB-EM
    %update equations. The arrays are rebuilt from scratch because K may
    %have shrunk in the removal step above.
    alpha = alpha_0 + N_eff;
    beta = beta_0 + N_eff;
    v = v_0 + N_eff;
    
    mean = zeros(D,K);
    W = zeros(D,D,K);
    for k=1:K
        mean(:,k) = 1/(beta_0 + N_eff(k)) * (beta_0*mean_0 + N_eff(k)*x_avg(:,k));
        W(:,:,k) = inv(invW_0 + N_eff(k) * S(:,:,k) + (beta_0 * N_eff(k)) / (beta_0 + N_eff(k)) * (x_avg(:,k) - mean_0) * (x_avg(:,k) - mean_0)');
    end

    %Calculate pi_eff and lambda_eff from the updated parameters
    pi_eff = exp(psi(alpha)-psi(sum(alpha)));
    lambda_eff = zeros(1,K);
    for k=1:K
        temp = 0;
        for i=1:D
            temp = temp + psi((v(k) + 1 - i)/2);
        end
        lambda_eff(k) = exp(temp + D*log(2) + log(det(W(:,:,k))));
    end
    
    %Calculate cost function, remembering the previous value for the
    %termination test and the sanity check below
    prevCost = cost;
    cost = costFunction(alpha_0, beta_0, v_0, W_0, invW_0, mean_0, res, N_eff, x_avg, S, alpha, beta, mean, v, W, pi_eff, lambda_eff, K, N, D);
    
    time = cputime - t;
    
    %Progress report; 'OK' is printed when the cost did not increase
    %compared to the previous iteration, 'ERROR' otherwise
    disp(strcat('Elapsed cpu time:', num2str(time)));
    disp(strcat('Cost function value:', num2str(cost)));
    disp(strcat('Previous cost function value:', num2str(prevCost)));
    disp(strcat('Run:', num2str(run)));    
    if prevCost >= cost
        disp('OK');
    else
        disp('ERROR');
    end
    
    %Store the accumulated CPU time and the cost of this iteration
    totTime = totTime + time;
    costValues_VB_EM(iter,:,run) = [totTime,cost];
end

%Report the accumulated CPU time of all iterations. Bracket concatenation
%is used instead of strcat, because strcat removes the trailing blank of
%the character array 'Total time: ' and would glue the number directly to
%the colon.
disp(['Total time: ', num2str(totTime)]);

%Plot the final model if plotting was requested
if plot_MoG == true
    plot_final_MoG(alpha, v, W, mean, beta, K, D, X, cost);
end

%Save image segments: one PNG file per remaining component, named
%<filename>_segmented_VB_EM_<run>_<k>.png
if imageData == true
    for k=1:K
        writeImageData(res,k,x_pixs,N,strcat(filename,'_segmented_VB_EM_',num2str(run),'_',num2str(k),'.png'),'png');
    end
end