%  UPDATE_S.M
%
%  Static model for linear factor analysis.
%  Update the 'full' posterior approximation of the sources Q(s).
%
%  Ilin A.N.
%
function s = update_s( s, vs, A, vx, x )
% UPDATE_S  Recompute the posterior mean and covariance of the sources.
%
%   s = update_s( s, vs, A, vx, x )
%
%   Inputs (as used by the code below):
%     s   struct array with fields .mean and .covar; element t is
%         overwritten for every t = 1..length(x).
%     vs  struct whose field .mex is placed on the diagonal of the source
%         precision (one entry per column of A.mean).
%     A   struct with fields .mean (xdim-by-sdim), .var (indexed like
%         .mean) and .clamped (when true, .var is ignored).
%     vx  struct whose field .mex weights each row of A.mean (one entry
%         per row of A.mean).
%     x   struct array with field .value, one element per sample.
%         NOTE(review): .value is presumably an xdim-by-1 column; confirm
%         against the caller.
%
%   Output:
%     s   the input struct array with .covar and .mean updated.
%
%   The names "mex" / "<exp v>" in the comments follow the original file's
%   notation; their exact statistical meaning is defined by the caller.

tdim = length(x);
[ xdim, sdim ] = size(A.mean);

% Nothing to update when there are no samples; return before touching any
% other field, exactly as the per-sample loops would.
if tdim == 0
    return
end

% Hard-coded switch: 0 selects the full-covariance update, anything else
% the per-component (diagonal) update further below.
diag_qs = 0;
if ~diag_qs
    % Full approximation.
    %
    % The precision matrix does not depend on the sample index, so it is
    % assembled and inverted once and then shared by all samples.

    % \Lambda(S) = < A^T \Sigma_x^{-1} A + \Sigma_s^{-1} >
    L = zeros(sdim,sdim);

    for i = 1:xdim

        % <exp{v_{x,i}}> [ <A_{i,:}>^T <A_{i,:}> + diag(~A_{i,:}~) ]
        if A.clamped
            % Clamped mixing matrix: its variance term is left out.
            L = L + vx.mex(i) *...
                ( A.mean(i,:)' * A.mean(i,:) );
        else
            L = L + vx.mex(i) *...
                ( A.mean(i,:)' * A.mean(i,:) + diag( A.var(i,:) ) );
        end
    end
    % + diag( <exp v> )
    L = L + diag( vs.mex );

    % \Sigma(S) = \Lambda(S)^{-1}
    % The explicit inverse is required because it is stored as the
    % posterior covariance.
    covar = inv(L);

    % \mu(St) = \Sigma(S) <A>^T diag( <exp Vx> ) x_t
    % The sample-independent factor is formed once, in the same
    % left-to-right order as a direct evaluation of the full product.
    gain = covar * A.mean' * diag( vx.mex );

    for t = 1:tdim
        s(t).covar = covar;
        s(t).mean = gain * x(t).value;
    end

else

    % Diagonal approximation (unreachable while diag_qs is fixed to 0).
    % NOTE(review): only the diagonal entries of s(t).covar are written
    % here; any off-diagonal entries already stored are left as they are.
    for t = 1:tdim
    
        % Components are visited from last to first; each mean update uses
        % the current means of all other components.
        for j = sdim:-1:1
            
            % Scalar precision of component j.
            L = 0;
            for i = 1:xdim
                if A.clamped
                    L = L + vx.mex(i) *...
                        ( A.mean(i,j)^2 );
                else
                    L = L + vx.mex(i) *...
                        ( A.mean(i,j)^2 + A.var(i,j) );
                end
            end
            L = L + vs.mex(j);
            
            s(t).covar(j,j) = inv(L);
            
            % Weighted residual of x_t after removing the contribution of
            % every component except j.
            M = 0;
            k = [ 1:j-1 j+1:sdim ];
            for i = 1:xdim
                M = M + vx.mex(i) * A.mean(i,j) *...
                    ( x(t).value(i) - A.mean(i,k) * s(t).mean(k) );
            end
            s(t).mean(j) = s(t).covar(j,j) * M;
        end
        
    end
    
end   
