%This script compares different algorithms for learning the
%mixture-of-Gaussians (MoG) model. As configured, it learns the cluster
%dataset with the following gradient-based algorithms: L-BFGS, gradient
%descent, conjugate gradient (CG), natural gradient (NG) and natural
%conjugate gradient (NCG). Other combinations and experiment setups can be
%studied with minor modifications of the script.
%Copyright (C) 2008-2010 Mikael Kuusela.

% Experiment identifier; it becomes part of the name of the results file
% written at the end of the script.
experiment = 1;

% Configuration switches
imageData   = false; % true: fit pixel data read from an image file; false: fit generated data
plot_MoG    = false; % true: plot the MoG model (presumably read by the MoG_VB_* scripts -- confirm)
randomMeans = true;  % true: draw the initial means randomly; false: place them on randomly chosen data points

% Seed both (legacy-syntax) random number generators so that the whole
% script is reproducible.
seed = 1974296;
rand('state', seed);
randn('state', seed);

% Silence every warning. Nothing in this file switches them back on, so the
% setting remains in effect in the MATLAB session after the script ends.
warning('off');

if ~imageData
    % Generate an artificial dataset.
    N = 500; % number of data points
    D = 2;   % dimensionality of each data point
    X = generateData(3, N, D);
else
    % Use the pixels of an image as the dataset.
    filename = 'swan.jpg';
    type = 'jpg';
    [X, N, D, x_pixs, y_pixs] = getImageData(filename, type);
end

% Model size and number of repetitions
initK = 5;  % number of Gaussian components at initialization
nRuns = 30; % number of times the whole comparison is repeated

% Termination criteria (the threshold scales with the number of data points)
epsilon = 1e-7*N;
requiredTermCount = 2;

% Component removal criterion
removalCriteria = 0;

% Initial values of the per-component parameters: one entry per component,
% or one D-by-D slice per component for initW.
initAlpha = repmat(1, 1, initK);
initBeta  = repmat(10, 1, initK);
initV     = repmat(D, 1, initK);
initW     = repmat(4/D*eye(D), [1 1 initK]);

% Hyperparameters of the priors
alpha_0 = 1;
beta_0  = 1;
v_0     = D;
W_0     = 4/D*eye(D);
mean_0  = zeros(D, 1);
invW_0  = inv(W_0); % inverse of W_0, computed once here

% Reset the global results containers (the costValues* globals are
% presumably filled in by the MoG_VB_* scripts -- confirm), then allocate
% one excess_time entry per run.
clear('global', 'costValues*');
clear('global', 'excess_time');
global excess_time;
excess_time = zeros(1, nRuns);

% Repeat the comparison nRuns times. Each run draws a fresh set of initial
% means and then trains the model from that initialization with every
% algorithm in turn. The MoG_VB_* files are scripts, so they run in this
% workspace and presumably read initMean, initAlpha, ..., m_bfgs and initT3
% from it -- confirm against those scripts.
for run = 1:nRuns
    if randomMeans
        % Draw every coordinate of every initial mean from a normal
        % distribution with mean 0 and standard deviation 0.4. This draws the
        % same values as random('Normal',0,0.4,D,initK) (a wrapper around
        % randn) without requiring the Statistics Toolbox.
        initMean = 0.4*randn(D, initK);
    else
        % Place each initial mean on a distinct, randomly chosen data point.
        % The points must be distinct: all other initial parameters
        % (initAlpha, initBeta, initV, initW) are the same for every
        % component, so two components starting at the same point are
        % symmetric and gradient updates cannot separate them. (The previous
        % randint call sampled with replacement and no longer exists in
        % MATLAB.)
        if initK > N
            error('Cannot place %d initial means on distinct points of a dataset with %d points.', initK, N);
        end
        randDataPoint = randperm(N);
        randDataPoint = randDataPoint(1:initK);
        initMean = X(:, randDataPoint);
    end
    
    %save(strcat('./temp/experiment_',num2str(experiment),'_initial_',num2str(run)),'X','initMean');
    
    % L-BFGS; m_bfgs is its memory length (set on every run in case a
    % training script changes it).
    m_bfgs = 15;
    MoG_VB_m_gamma_bfgs;
    %save(strcat('./temp/experiment_',num2str(experiment),'_m_gamma_bfgs_end_',num2str(run)),'cost','K','mean','W','v','alpha','beta');
    
    % Plain gradient and Polak-Ribiere conjugate gradient share the
    % line-search step length initT3.
    initT3 = 0.002;
    MoG_VB_m_gamma_grad;
    %save(strcat('./temp/experiment_',num2str(experiment),'_m_gamma_grad_end_',num2str(run)),'cost','K','mean','W','v','alpha','beta');
    MoG_VB_m_gamma_conj_grad_PR;
    %save(strcat('./temp/experiment_',num2str(experiment),'_m_gamma_conj_grad_PR_end_',num2str(run)),'cost','K','mean','W','v','alpha','beta');
    
    % Natural gradient and natural conjugate gradient use a larger
    % line-search step length.
    initT3 = 2;
    MoG_VB_m_gamma_nat_grad;
    %save(strcat('./temp/experiment_',num2str(experiment),'_m_gamma_nat_grad_end_',num2str(run)),'cost','K','mean','W','v','alpha','beta');
    MoG_VB_m_gamma_nat_conj_grad_PR_Riemann;
    %save(strcat('./temp/experiment_',num2str(experiment),'_m_gamma_nat_conj_grad_PR_Riemann_end_',num2str(run)),'cost','K','mean','W','v','alpha','beta');
    
    %save(strcat('./temp/experiment_',num2str(experiment),'_',num2str(run)), 'costValues*');
end

% Store every costValues* variable (the cost function values together with
% their training times) in experiment_<experiment>.mat in the current folder.
save(['experiment_' num2str(experiment)], 'costValues*');

%{
for run=1:nRuns
    values = reshape(nonzeros(costValues_m_gamma_bfgs(:,:,run)),[],2);
    bfgs_compensation_factor(run) = (values(end,1)-excess_time(run))/values(end,1);
end
%}