%  UPDATE_SD.M
%
%  ICA with temporally correlated sources. Update the posterior
%  approximation for the sources q(s). The posterior correlations are
%  either modelled (diag_qs=0) or not modelled (diag_qs=1).
%
function s = update_sd( s, B, vs, A, vx, x, diag_qs )
% Update the Gaussian posterior approximation q(s(t)) of the sources for
% every time step, sweeping backwards from t = tdim to t = 1. Both
% s(t).mean and s(t).covar are overwritten (in the diagonal mode only the
% diagonal entries of s(t).covar are written).
%
% Inputs (shapes as implied by the indexing below):
%   s       - struct array of length tdim, fields .mean (sdim vector) and
%             .covar (sdim x sdim).
%   B       - struct, fields .mean and .var: per-source dynamics
%             coefficients; transition terms use them elementwise
%             (B.mean .* s(t-1)).
%   vs      - struct, field .mex: per-source <exp v_s>, i.e. the diagonal
%             of < \Sigma_s^{-1} >.
%   A       - struct, fields .mean (xdim x sdim), .var (xdim x sdim) and
%             .clamped; A.var is ignored when .clamped is true.
%   vx      - struct, field .mex: per-observation <exp v_x>, i.e. the
%             diagonal of < \Sigma_x^{-1} >.
%   x       - struct array, field .value (xdim column vector);
%             tdim = length(x).
%   diag_qs - 0: full covariance per time step;
%             1: diagonal approximation, updated one source at a time.
%
% Precision of q(s(t)):
%   < A^T \Sigma_x^{-1} A >                  for every t
%   + I (prior on s1)  or  < \Sigma_s^{-1} > for t = 1 / t > 1
%   + < B^T \Sigma_s^{-1} B >                only if t < tdim
% Mean of q(s(t)) = covariance * ( <A>^T diag(<exp v_x>) x_t
%   + <exp v_s> .* <B> .* <s_{t-1}>          only if t > 1
%   + <B> .* <exp v_s> .* <s_{t+1}> )        only if t < tdim
% Because the sweep runs backwards, <s_{t+1}> is the value already updated
% in this call while <s_{t-1}> is still the previous one. With tdim == 1
% no transition term is used at all.

tdim = length(x);
[ xdim, sdim ] = size(A.mean);

% Transition precision < B^T \Sigma_s^{-1} B > (diagonal), same for all t
bterm = vs.mex .* ( B.mean.^2 + B.var );

if ~diag_qs
    % ---------------- Full approximation ----------------

    % Data term: sum_i <exp v_{x,i}> [ <A_{i,:}>^T <A_{i,:}> + diag(~A_{i,:}~) ]
    L = zeros(sdim,sdim);
    for i = 1:xdim
        Ai = A.mean(i,:);
        Li = Ai' * Ai;
        if ~A.clamped
            Li = Li + diag( A.var(i,:) );
        end
        L = L + vx.mex(i) * Li;
    end

    % <A>^T diag(<exp v_x>) does not depend on t
    AtVx = A.mean' * diag( vx.mex );

    for t = tdim:-1:1

        % Prior precision: identity for s1, < \Sigma_s^{-1} > afterwards
        if t == 1
            P = L + eye(sdim);
        else
            P = L + diag( vs.mex );
        end
        % s(t) influences s(t+1) only when a next step exists
        if t < tdim
            P = P + diag( bterm );
        end
        s(t).covar = inv( P );

        M = AtVx * x(t).value;
        if t > 1
            M = M + vs.mex .* B.mean .* s(t-1).mean;
        end
        if t < tdim
            M = M + B.mean .* vs.mex .* s(t+1).mean;
        end
        s(t).mean = s(t).covar * M;
    end

else
    % ---------------- Diagonal approximation ----------------
    % Off-diagonal entries of s(t).covar are left untouched; presumably
    % the caller initialises them to zero -- TODO confirm.

    % Diagonal of the data-term precision, one entry per source
    Ld = zeros(1,sdim);
    for i = 1:xdim
        if A.clamped
            Ld = Ld + vx.mex(i) * A.mean(i,:).^2;
        else
            Ld = Ld + vx.mex(i) * ( A.mean(i,:).^2 + A.var(i,:) );
        end
    end

    for t = tdim:-1:1
        for j = sdim:-1:1

            % Prior precision: 1 for s1, <exp v_{s,j}> afterwards
            if t == 1
                prec = Ld(j) + 1;
            else
                prec = Ld(j) + vs.mex(j);
            end
            if t < tdim
                prec = prec + bterm(j);
            end
            s(t).covar(j,j) = 1 / prec;

            % Data term uses the current means of the other sources k ~= j
            % (sources with index > j were already updated at this t)
            k = [ 1:j-1 j+1:sdim ];
            M = 0;
            for i = 1:xdim
                M = M + vx.mex(i) * A.mean(i,j) *...
                    ( x(t).value(i) - A.mean(i,k) * s(t).mean(k) );
            end

            % Temporal neighbours that actually exist
            nb = 0;
            if t > 1
                nb = nb + s(t-1).mean(j);
            end
            if t < tdim
                nb = nb + s(t+1).mean(j);
            end
            M = M + vs.mex(j) * B.mean(j) * nb;

            s(t).mean(j) = s(t).covar(j,j) * M;
        end
    end

end

