%  UPDATE_VX.M
%
%  Linear factor analysis model. Update the posterior approximation
%  for the observation noise parameter vx.
%
function vx = update_vx( vx, Mvx, Vvx, x, A, s, diag_qs )
% UPDATE_VX  Update q(vx_i) for every observation dimension i with one
% var_newton step each, then refresh vx.mex.
%
%   vx      - struct with fields mean and var (one entry per observation
%             dimension i); vx.mex is recomputed with meanexp at the end.
%   Mvx     - struct; Mvx.mean is used as <Mvx>, the prior mean of vx.
%   Vvx     - struct; Vvx.mex is used as <exp Vvx>.
%   x       - struct array of length N; x(t).value(i) is observation i at t.
%   A       - struct with fields mean and var (row i is used for dimension
%             i) and a clamped flag; A.var is only read when A is not
%             clamped.
%   s       - struct array of length N with fields mean and covar.
%   diag_qs - if false, the full s(t).covar is used; otherwise only its
%             diagonal.
%
% For each i the terms gme, gva, gex (presumably the gradients of the cost
% w.r.t. <vx_i>, ~vx_i~ and <exp vx_i> -- confirm against var_newton) are
% summed from two sources:
%
%   the prior on vx_i:
%     gme = <exp Vvx> ( <vx_i> - <Mvx> ),   gva = 0.5 <exp Vvx>,   gex = 0
%
%   the children x(t), t = 1..N, with f_{t,i} = A_{i,:} S_t:
%     gme = -N/2,   gva = 0,
%     gex = sum_t 0.5 [ ( x_{t,i} - <f_{t,i}> )^2 + ~f_{t,i}~ ]
%     <f_{t,i}> = <A_{i,:}> <S_t>
%     ~f_{t,i}~ = <A_{i,:}> Sigma(S_t) <A_{i,:}>^T
%                 + ~A_{i,:}~ ( <S_t>.^2 + diag Sigma(S_t) )
%
%   With diag_qs set, Sigma(S_t) is replaced by its diagonal; when A is
%   clamped the ~A_{i,:}~ term is omitted.

tdim = length(x);
xdim = size(A.mean, 1);

% The prior's contribution to gva does not depend on i.
gva_prior = 0.5 * Vvx.mex;

for i = xdim:-1:1

    a_mean = A.mean(i,:);

    % Contribution from the prior on vx_i.
    gme_prior = Vvx.mex * ( vx.mean(i) - Mvx.mean );

    % Contribution from the children x(t); gva from them is zero.
    gme_data = - tdim/2;
    gex_data = 0;

    for t = 1:tdim
        st     = s(t);
        st_var = diag( st.covar );

        % Part of ~f_{t,i}~ caused by the uncertainty of S_t. This is the
        % only place where the full and diagonal approximations differ.
        if ~diag_qs
            f_var_s = a_mean * st.covar * a_mean';
        else
            f_var_s = a_mean.^2 * st_var;
        end

        % Add the part caused by the uncertainty of A unless A is clamped.
        if A.clamped
            f_var = f_var_s;
        else
            f_var = f_var_s + A.var(i,:) * ( st.mean.^2 + st_var );
        end

        gex_data = gex_data + ...
            0.5 * ( ( x(t).value(i) - a_mean*st.mean )^2 + f_var );
    end

    [vx.mean(i), vx.var(i)] = ...
        var_newton( vx.mean(i), vx.var(i), ...
                    gme_prior + gme_data, gva_prior, gex_data );

end

vx.mex  = meanexp( vx );

