%  MOGICA.M
%
%  VB independent component analysis with the mixture-of-Gaussians
%  model for the sources. Miskin and MacKay's Gaussian approximation
%  for the source posterior is used. The data must be centered. By
%  default, the sources are fixed to the principal components of the
%  data.
%
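%  Usage: To initialise the model and start learning ('searchsources'
%  gives the number of sources; sources given with 'initsources' or a
%  mixing matrix given with 'initA' are kept fixed during learning):
%
%  [ net, status ] = mogica( data, iter, 'searchsources', 2,...
%                                        'initsources', s_init,...
%                                        'diag_qs', 0 )
%  or
%  [ net, status ] = mogica( data, iter, 'searchsources', 2,...
%                                        'initA', A_init,...
%                                        'diag_qs', 0 )
%
%  To continue learning of the existing model (ITER is the total number
%  of iterations, including those already completed; STATUS must be the
%  second output of the previous call):
%  [ net, status ] = mogica( data, iter, net, status )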
%
function [ net, status ] = mogica( data, iter, varargin )
%MOGICA  VB ICA with mixture-of-Gaussians source models (see file header).
%
%  Inputs
%    data     - xdim-by-tdim data matrix; column t is observation t.
%    iter     - total number of learning iterations. When continuing,
%               only iterations status.iter+1 .. iter are run.
%    varargin - either ( net, status ) from a previous call, or
%               parameter/value pairs for a new model:
%                 'searchsources' - number of sources sdim (required)
%                 'initsources'   - sdim-by-tdim sources, kept clamped
%                 'initA'         - xdim-by-sdim mixing matrix, kept clamped
%                 'diag_qs'       - flag forwarded to update_vx and
%                                   update_sica (presumably selects a
%                                   diagonal source posterior - confirm)
%
%  Outputs
%    net      - model structure built by mogica2net.
%    status   - learning state (iteration count, clamping flags, CF,
%               diag_qs) needed to continue learning.

[ xdim, tdim ] = size( data );

% A previous model is being continued when the first optional argument
% is the network structure. Checked without indexing an empty varargin.
continuing = 0;
if ~isempty( varargin )
    continuing = isstruct( varargin{1} );
end

if continuing

    % The network is passed to the function
    [ A, s, vx, Vvx, Mvx,...
      Mmms, Vmms, Mvms, Vvms, Mmvs, Vmvs, Mvvs, Vvvs,...
      Mms, Vms, Mvs, Vvs,...
      ms, vs, mog_pi, lambda, x ] = net2mogica( varargin{1}, data );

    % The status structure of the previous call is mandatory.
    status = [];
    if length( varargin ) >= 2
        status = varargin{2};
    end
    if isempty( status )
        error( 'MOGICA: Unspecified status' )
    end

    CF = status.CF;
    complete_iters = status.iter;
    diag_qs = status.diag_qs;

    % Optional flags: a field missing from status defaults to 0
    % (not clamped / no extra lambda updates).
    A.clamped   = status_flag( status, 'A_clamped' );
    s_clamped   = status_flag( status, 's_clamped' );
    Mvx_clamped = status_flag( status, 'Mvx_clamped' );
    Vvx_clamped = status_flag( status, 'Vvx_clamped' );
    vx_clamped  = status_flag( status, 'vx_clamped' );
    lambda_many = status_flag( status, 'lambda_many' );

else

    % Default parameter values
    options = struct( 'searchsources', [], ...
                      'inita', [], ...
                      'initsources', [], ...
                      'diag_qs', 0 );

    if mod( length(varargin), 2 )
        error( 'MOGICA: Not enough arguments' )
    end
    % Parameter names are case-insensitive ('initA' -> 'inita').
    for i = 1:2:length(varargin)
        if ~isfield( options, lower(varargin{i}) )
            error( [ 'MOGICA: Unknown parameter: ' varargin{i} ] )
        end
        options = setfield( options, lower(varargin{i}),...
                                     varargin{i+1} );
    end

    sdim = options.searchsources;
    if isempty( sdim )
        error( 'MOGICA: Unspecified number of sources' )
    end
    diag_qs = options.diag_qs;

    %
    % Create the model
    %

    Mmms.mean = 0;
    Mmms.var  = 1;

    Vmms.mean = 0;
    Vmms.var  = 1;
    Vmms.mex  = meanexp( Vmms );

    Mvms.mean = 0;
    Mvms.var  = 1;

    Vvms.mean = 0;
    Vvms.var  = 1;
    Vvms.mex  = meanexp( Vvms );

    Mmvs.mean = 0;
    Mmvs.var  = 1;

    Vmvs.mean = 0;
    Vmvs.var  = 1;
    Vmvs.mex  = meanexp( Vmvs );

    Mvvs.mean = 0;
    Mvvs.var  = 1;

    Vvvs.mean = 0;
    Vvvs.var  = 1;
    Vvvs.mex  = meanexp( Vvvs );

    % Mms - the means of the MoG centers ms
    % Vms - the variances of the MoG centers ms
    % Mvs - the means of the MoG variances vs
    % Vvs - the variances of the MoG variances vs
    Mms.mean = repmat( Mmms.mean, sdim, 1 );
    Mms.var  = repmat( 1/Vmms.mex, sdim, 1 );

    Vms.mean = repmat( Mvms.mean, sdim, 1 );
    Vms.var  = repmat( 1/Vvms.mex, sdim, 1 );
    Vms.mex  = meanexp( Vms );

    Mvs.mean = repmat( Mmvs.mean, sdim, 1 );
    Mvs.var  = repmat( 1/Vmvs.mex, sdim, 1 );

    Vvs.mean = repmat( Mvvs.mean, sdim, 1 );
    Vvs.var  = repmat( 1/Vvvs.mex, sdim, 1 );
    Vvs.mex  = meanexp( Vvs );

    % ms - MoG centers
    % vs - MoG variances
    % mog_pi - MoG coefficients with Dirichlet prior
    % Every source uses a 3-component mixture; one entry per source so
    % that any number of sources is supported.
    mixdim = repmat( 3, 1, sdim );
    for j = 1:sdim
        % NOTE(review): the seed is reset for every source, so all
        % sources start from identical MoG centers. Behaviour kept
        % unchanged; move above the loop for distinct initialisations.
        randn('state',0)

        ms{j}.mean = randn( mixdim(j), 1 );
        ms{j}.var = repmat( 1/Vms.mex(j), mixdim(j), 1 );
        vs{j}.mean = zeros( mixdim(j), 1 );
        vs{j}.var = repmat( 1/Vvs.mex(j), mixdim(j), 1 );
        vs{j}.mex = meanexp( vs{j} );

        % Initialization for counts does not matter
        mog_pi{j}.counts = ones( mixdim(j), 1 );
        mog_pi{j}.mean = 1/mixdim(j);
        mog_pi{j}.explog = explog_pi( mog_pi{j} );
    end

    s = repmat( struct( 'mean',  zeros(sdim,1),...
                        'covar', eye(sdim) ),...
                1, tdim );

    % Approximation coefficients for p( s_t )
    for j = 1:sdim
        lambda{j} = zeros( mixdim(j), tdim );
    end
    lambda = update_lambda( lambda, mog_pi, s, ms, vs );
    lambda_many = 0;

    % A
    A.mean = zeros( xdim, sdim );
    A.var = ones( xdim, sdim );

    % vx, Mvx, Vvx
    Mvx.mean = 0;
    Mvx.var  = 1;

    Vvx.mean = 0;
    Vvx.var  = 1;
    Vvx.mex  = meanexp( Vvx );

    vx.mean = repmat( Mvx.mean, xdim, 1 );
    vx.var  = repmat( 1/Vvx.mex, xdim, 1 );
    vx.mex  = meanexp( vx );

    % One struct element per observation; x(t).value = data(:,t).
    x = struct( 'value', num2cell( data, 1 ) );

    %
    % Initialization
    %
    if isempty( options.initsources ) & isempty( options.inita )

        % Initialize sources with PCA: project the data onto the sdim
        % eigenvectors of the data covariance with largest eigenvalues.
        if sdim > xdim
            error( 'MOGICA: Cannot initialise more sources than data dimensions with PCA' )
        end
        C = data*data'/tdim;
        [ V, L ] = eig(C);

        [ L, id ] = sort( -diag(L) );   % descending eigenvalue order
        V = V(:,id);
        options.initsources = V( :, 1:sdim )' * data;

    end
    if ~isempty( options.initsources )
        if ~isequal( size( options.initsources ), [ sdim, tdim ] )
            error( 'MOGICA: initsources must be sdim-by-tdim' )
        end
        % Given sources are clamped: point estimates, zero covariance.
        s_clamped = 1;
        for t = 1:tdim
            s(t).mean = options.initsources(:,t);
            s(t).covar = zeros(sdim,sdim);
        end
    else
        s_clamped = 0;
    end
    if ~isempty( options.inita )
        if ~isequal( size( options.inita ), [ xdim, sdim ] )
            error( 'MOGICA: initA must be xdim-by-sdim' )
        end
        % Given mixing matrix is clamped: point estimate, zero variance.
        A.clamped = 1;
        A.mean = options.inita;
        A.var = zeros( xdim, sdim );
    else
        A.clamped = 0;
    end

    Mvx_clamped = 0;
    Vvx_clamped = 0;
    vx_clamped = 0;

    % CF is only carried through status; it is not computed here.
    CF = [];

    complete_iters = 0;
    status = struct( 'diag_qs', diag_qs );

end

%
% Learning
%
for i = complete_iters+1:iter

    fprintf( '%d-', i );

    if ~vx_clamped
        vx = update_vx( vx, Mvx, Vvx, x, A, s, diag_qs );
    end
    if ~Vvx_clamped
        Vvx = update_Vvx( Vvx, Mvx, vx );
    end
    if ~Mvx_clamped
        Mvx = update_Mvx( Mvx, Vvx, vx );
    end

    if ~A.clamped
        A = update_A( A, s, vx, x );
    end

    lambda = update_lambda( lambda, mog_pi, s, ms, vs );

    if ~s_clamped
        s = update_sica( s, ms, vs, lambda, A, vx, x, diag_qs );
    end

    % With lambda_many set, lambda is refreshed after every MoG update.
    if lambda_many
        lambda = update_lambda( lambda, mog_pi, s, ms, vs );
    end

    ms = update_mogm( ms, Mms, Vms, s, vs, lambda );
    if lambda_many
        lambda = update_lambda( lambda, mog_pi, s, ms, vs );
    end

    vs = update_mogv( vs, Mvs, Vvs, s, ms, lambda );
    if lambda_many
        lambda = update_lambda( lambda, mog_pi, s, ms, vs );
    end

    mog_pi = update_mogc( mog_pi, lambda );

    % Hyperparameters of the MoG centers (ms) and variances (vs) reuse the
    % same update routines.
    Mms = update_Mms( Mms, Mmms, Vmms, ms, Vms );
    Vms = update_Vms( Vms, Mvms, Vvms, ms, Mms );

    Mvs = update_Mms( Mvs, Mmvs, Vmvs, vs, Vvs );
    Vvs = update_Vms( Vvs, Mvvs, Vvvs, vs, Mvs );

    % Top-level hyperparameters reuse the vx-hyperprior updates.
    Mmms = update_Mvx( Mmms, Vmms, Mms );
    Vmms = update_Vvx( Vmms, Mmms, Mms );

    Mvms = update_Mvx( Mvms, Vvms, Vms );
    Vvms = update_Vvx( Vvms, Mvms, Vms );

    Mmvs = update_Mvx( Mmvs, Vmvs, Mvs );
    Vmvs = update_Vvx( Vmvs, Mmvs, Mvs );

    Mvvs = update_Mvx( Mvvs, Vvvs, Vvs );
    Vvvs = update_Vvx( Vvvs, Mvvs, Vvs );

end
% Terminate the 'i-' progress line.
if iter > complete_iters
    fprintf( '\n' );
end

net = mogica2net( A, s, vx, Vvx, Mvx,...
                  Mmms, Vmms, Mvms, Vvms,...
                  Mmvs, Vmvs, Mvvs, Vvvs,...
                  Mms, Vms, Mvs, Vvs, ms, vs, mog_pi, lambda );

% Never rewind the counter when called with iter below the completed count.
status.iter = max( iter, complete_iters );
status.CF = CF;
status.A_clamped = A.clamped;
status.s_clamped = s_clamped;

status.Mvx_clamped = Mvx_clamped;
status.Vvx_clamped = Vvx_clamped;
status.vx_clamped = vx_clamped;
status.lambda_many = lambda_many;

return


function val = status_flag( status, name )
% STATUS_FLAG  Return field NAME of STATUS, or 0 if the field is absent.
if isfield( status, name )
    val = getfield( status, name );
else
    val = 0;
end
