%  UPDATE_SV.M
%
%  ICA with super-Gaussian sources. Update the posterior approximation
%  for the sources q(s). The posterior correlations are either
%  modelled (diag_qs=0) or not modelled (diag_qs=1).
%
function s = update_sv( s, u, A, vx, x, diag_qs )
%  S = UPDATE_SV( S, U, A, VX, X, DIAG_QS ) recomputes the Gaussian
%  posterior approximation q(s_t) of the sources for every sample t.
%
%  Inputs (only the fields listed here are accessed):
%    s        struct array; s(t).mean and s(t).covar are overwritten.
%             In the diagonal case the current s(t).mean is also read
%             and only the diagonal entries of s(t).covar are written.
%    u        struct array; u(t).mex is added to the diagonal of the
%             source precision ( <exp v> in the notation used below ).
%    A        struct with A.mean (xdim-by-sdim), A.var (indexed like
%             A.mean, used only when A.clamped is false) and the flag
%             A.clamped.
%    vx       struct; vx.mex(i) weights observation dimension i
%             ( <exp v_{x,i}> in the notation used below ).
%    x        struct array; x(t).value is the observation of sample t.
%    diag_qs  false/0: full posterior covariance per sample,
%             otherwise: element-wise (diagonal) posterior.
%
%  Output:
%    s        the input struct array with updated mean/covar fields.

tdim = numel(x);
[ xdim, sdim ] = size(A.mean);

%const0.mean = 0;
%const0.var = 0;
%const0.mex = meanexp( const0 );
%disp( 'Before update_sv' )
%fprintf( 'cost_sv: %f\n', cost_sv( s, const0, u, diag_qs ) )
%fprintf( 'cost_xv: %f\n', cost_xv( x, A, s, vx, diag_qs ) )

if ~diag_qs
    % Full approximation

    % Sample-independent part of the precision:
    %   sum_i <exp{v_{x,i}}> [ <A_{i,:}>^T <A_{i,:}> + diag(~A_{i,:}~) ]
    % The variance term is dropped when A is clamped. The sum of scaled
    % outer products is kept as a loop so that L stays exactly symmetric.
    L = zeros(sdim,sdim);
    for i = 1:xdim
        outer_i = A.mean(i,:)' * A.mean(i,:);
        if A.clamped
            L = L + vx.mex(i) * outer_i;
        else
            L = L + vx.mex(i) * ( outer_i + diag( A.var(i,:) ) );
        end
    end

    % <A>^T diag( <exp Vx> ) does not depend on t, so it is formed once
    % instead of once per sample.
    AtW = A.mean' * diag( vx.mex );

    for t = 1:tdim

        % \Sigma(St) = \Lambda(St)^{-1},  \Lambda(St) = L + diag( <exp v> )
        s(t).covar = inv( L + diag( u(t).mex ) );

        % \mu(St) = \Sigma(St) <A>^T diag( <exp Vx> ) x_t
        s(t).mean = s(t).covar * ( AtW * x(t).value );

    end

else

    % Diagonal approximation

    % Sample-independent part of each source's precision:
    %   L(j) = sum_i <exp{v_{x,i}}> ( <A_{i,j}>^2 + ~A_{i,j}~ )
    L = zeros(1,sdim);
    for j = 1:sdim
        for i = 1:xdim
            if A.clamped
                L(j) = L(j) + vx.mex(i) *...
                    ( A.mean(i,j)^2 );
            else
                L(j) = L(j) + vx.mex(i) *...
                    ( A.mean(i,j)^2 + A.var(i,j) );
            end
        end
    end

    for t = 1:tdim

        % Sources are updated one at a time, last to first; each update
        % uses the most recent means of all the other sources.
        for j = sdim:-1:1

            s(t).covar(j,j) = 1 / ( L(j) + u(t).mex(j) );

            % Means of the other sources, forced to a column so that
            % the product below is well defined for any orientation of
            % s(t).mean and evaluates to 0 when there are no other
            % sources (sdim == 1).
            k = [ 1:j-1 j+1:sdim ];
            others = s(t).mean(k);
            others = others(:);

            % Weighted projection of the residual that the other
            % sources leave unexplained onto column j of <A>.
            M = 0;
            for i = 1:xdim
                M = M + vx.mex(i) * A.mean(i,j) *...
                    ( x(t).value(i) - A.mean(i,k) * others );
            end
            s(t).mean(j) = s(t).covar(j,j) * M;
        end

    end

end

%disp( 'After update_sv' )
%fprintf( 'cost_sv: %f\n', cost_sv( s, const0, u, diag_qs ) )
%fprintf( 'cost_xv: %f\n', cost_xv( x, A, s, vx, diag_qs ) )
