%  UPDATE_SICA.M
%
%  ICA with the mixture-of-Gaussians model for the sources. Update the
%  posterior approximation for the sources q(s). The posterior
%  correlations are either modelled (diag_qs=0) or not modelled
%  (diag_qs=1). Miskin and MacKay's Gaussian approximation for the
%  source posterior is used.
%
function s = update_sica( s, ms, vs, lambda, A, vx, x, diag_qs )
%UPDATE_SICA  Update the source posterior q(s) for MoG-prior ICA.
%
%  s = update_sica( s, ms, vs, lambda, A, vx, x, diag_qs )
%
%  Inputs (shapes as implied by the indexing below):
%    s       - struct array, one element per time index t, with fields
%              .mean (sdim elements, treated as a column) and .covar.
%              In the diagonal mode the current .mean values are used as
%              the starting point of the source-by-source update.
%    ms, vs  - cell arrays of length sdim; ms{j}.mean and vs{j}.mex hold
%              <m_{j,k}> and <exp v_{j,k}> for the components k of source j.
%    lambda  - cell array of length sdim; lambda{j}(k,t) is lambda_{t,j,k}
%              in the formulas below.
%    A       - struct with .mean (xdim x sdim), .var (same size; only read
%              when .clamped is false) and .clamped.
%    vx      - struct; vx.mex(i) is <exp v_{x,i}>.
%    x       - struct array of length tdim; x(t).value holds the xdim
%              observations at time t.
%    diag_qs - false: full covariance for each s(t);
%              true:  diagonal covariance, means updated one source at a
%                     time (only the diagonal of s(t).covar is written).
%
%  Output:
%    s       - the input struct array with .mean and .covar updated.

tdim = length(x);
[ xdim, sdim ] = size(A.mean);

% <exp v_{x,i}> as a column vector.
wx = vx.mex(:);

% Variance-of-A contribution to the source precision:
%   varterm(j) = sum_i <exp v_{x,i}> ~a_{ij}~   (zero when A is clamped).
% Written as if/else so an empty A.clamped behaves as in "if A.clamped".
if A.clamped
    varterm = zeros(1, sdim);
else
    varterm = wx' * A.var;
end

if ~diag_qs
    % Full approximation
    %
    % \Sigma(S_t)^{-1} = < A^T \Sigma_x^{-1} A >
    %                    + diag( sum_k lambda_{t,j,k} <exp v_{j,k}> )
    %
    % with < A^T \Sigma_x^{-1} A > =
    %     sum_i <exp v_{x,i}> ( <A_{i,:}>^T <A_{i,:}> + diag(~A_{i,:}~) ),
    % which does not depend on t and is computed once. The outer products
    % are accumulated row by row so the matrix stays exactly symmetric.
    L = zeros(sdim, sdim);
    for i = 1:xdim
        L = L + wx(i) * ( A.mean(i,:)' * A.mean(i,:) );
    end
    L = L + diag(varterm);

    % <A>^T diag(<exp v_x>) is also independent of t.
    AtW = A.mean' * diag(wx);

    L_t = zeros(sdim, 1);
    M   = zeros(sdim, 1);
    for t = 1:tdim

        % Prior terms for each source j:
        %   L_t(j) = sum_k lambda_{t,j,k} <exp v_{j,k}>
        %   M(j)   = sum_k lambda_{t,j,k} <exp v_{j,k}> <m_{j,k}>
        for j = 1:sdim
            L_t(j) = sum( lambda{j}(:,t) .* vs{j}.mex );
            M(j)   = sum( lambda{j}(:,t) .* vs{j}.mex .* ms{j}.mean );
        end

        s(t).covar = inv( L + diag(L_t) );

        % \mu(S_t) = \Sigma(S_t) ( <A>^T diag(<exp v_x>) x_t + M )
        s(t).mean = s(t).covar * ( AtW * x(t).value(:) + M );
    end

else
    % Diagonal approximation, updated one source at a time:
    %
    %                   n
    % ~s_{t,j}~^{-1} = sum <exp v_{x,i}> ( <a_{ij}>^2 + ~a_{ij}~ )
    %                  i=1
    %                  + sum_k lambda_{t,j,k} <exp v_{j,k}>
    %
    % <s_{t,j}> = ~s_{t,j}~ (
    %     n
    %    sum <exp v_{x,i}> <a_{ij}> ( x_{t,i} - sum_{k~=j} <a_{ik}><s_{t,k}> )
    %    i=1
    %
    %    + sum_k lambda_{t,j,k} <exp v_{j,k}> <m_{j,k}> )
    %
    % The first sum in the precision is independent of t.
    Ld = wx' * ( A.mean.^2 ) + varterm;

    for t = 1:tdim
        xt = x(t).value(:);
        % Column copy of the current means; entries are overwritten as
        % they are updated so source j sees the new <s_{t,1..j-1}>.
        m = s(t).mean(:);

        for j = 1:sdim
            s(t).covar(j,j) = inv( ...
                Ld(j) + sum( lambda{j}(:,t) .* vs{j}.mex ) );

            % Column index of the other sources. A column (rather than a
            % row) keeps m(k) a column even when it is empty (sdim == 1),
            % so the product below is an xdim-by-1 zero vector instead of
            % a dimension error.
            k = [ 1:j-1, j+1:sdim ]';
            r = xt - A.mean(:,k) * m(k);

            M = sum( lambda{j}(:,t) .* vs{j}.mex .* ms{j}.mean ) ...
                + A.mean(:,j)' * ( wx .* r );

            m(j) = s(t).covar(j,j) * M;
            s(t).mean(j) = m(j);
        end
    end

end

