function [dc_dsm, dc_dsv, dc_dnetm, dc_dnetv, dx] =...
    feedback(x, net, sources, data, noiseparam, status)
% FEEDBACK Do feedback phase calculations
%
%   Walks backwards through the layer values stored in the cell array X,
%   from the output layer x{4} down to the input layer x{1}, and fills
%   DX{4}..DX{1} with the matching gradient terms.
%
%   Inputs, as they are used here:
%     x          - cell array; x{1}..x{4} are read through their .e, .var,
%                  .multi and .extra fields, x{5} is handed unchanged to
%                  feedback_hermite as its auxiliary argument
%     net        - supplies w1, b1, w2, b2 (read through .e and .var) and
%                  nonlin (name of the nonlinearity)
%     sources    - read for its size and its .var field
%     data       - observations; size(data, 2) is used as the sample count
%     noiseparam - passed to normalvar to obtain the noise variance
%     status     - status.approximation selects 'hermite' or 'taylor';
%                  the whole struct is forwarded to feedback_hermite
%
%   Outputs:
%     dc_dsm   - dx{1}.e
%     dc_dsv   - dx{1}.var plus the contribution collected from x{4}.multi
%     dc_dnetm - struct with fields w2, b2, w1, b1 (first and third
%                outputs of netgrads / netgradstop)
%     dc_dnetv - struct with fields w2, b2, w1, b1 (second and fourth
%                outputs of netgrads / netgradstop)
%     dx       - cell array of the per-layer gradient structs

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

obs_noise = normalvar(noiseparam);

n_samples = size(data, 2);
n_src = size(sources, 1);

% Noise variance replicated across all samples
obs_var = obs_noise * ones(1, n_samples);

% Gradient terms at the output layer
dx{4}.var = .5 ./ obs_var;
dx{4}.e = (x{4}.e - data(:,1:n_samples)) ./ obs_var;
dx{4}.extra = dx{4}.var;

% Tile the source variances and the output variance gradients to the
% shape of x{4}.multi before combining them element by element
src_var_tiled = repmat(shiftdim(sources.var, -1), ...
                       [size(x{4}.multi, 1) 1 1]);
out_var_tiled = repmat(reshape(dx{4}.var, [size(data, 1) 1 n_samples]), ...
                       [1 n_src 1]);
dx{4}.multi = src_var_tiled .* out_var_tiled .* (2 * x{4}.multi);

src_var_bonus = zeros(size(sources));

% The first layer (linear)
multi_sq = (x{4}.multi).^2;

% Row k below holds, for every sample t,
%   sum over j of multi_sq(j,k,t) * dx{4}.var(j,t)
% i.e. src_var_bonus(:,t) = multi_sq(:,:,t)' * dx{4}.var(:,t),
% computed without looping over the samples.
for k=1:n_src
  src_var_bonus(k,:) = ...
      sum(reshape(multi_sq(:,k,:), size(dx{4}.var)) .* dx{4}.var, 1);
end

dx{3}.var = net.w2.var' * dx{4}.extra;
dx{3}.e = net.w2.e' * dx{4}.e + (2*net.w2.var' * dx{4}.extra) .* x{3}.e;
dx{3}.extra = net.w2.e' .^2 * dx{4}.extra;

% Equivalent to dx{3}.multi(:,:,i) = net.w2.e' * dx{4}.multi(:,:,i) for
% every i, done as one matrix product on the flattened array.
w2_cols = size(net.w2, 2);
[m_rows m_cols m_pages] = size(dx{4}.multi);
flat_multi = reshape(dx{4}.multi, [m_rows m_cols*m_pages]);
dx{3}.multi = reshape(net.w2.e' * flat_multi, [w2_cols m_cols m_pages]);


[dc_dnetm.w2, dc_dnetv.w2, dc_dnetm.b2, dc_dnetv.b2] = ...
    netgrads(x{3}, dx{4}, net.w2, net.b2);

% The second layer (nonlinear)

approx_name = status.approximation;
if strcmp(approx_name, 'hermite')
  [dx{2}.e, dx{2}.var, dx{2}.multi, dx{2}.extra] = ...
      feedback_hermite(dx{3}.e, dx{3}.var, dx{3}.multi, dx{3}.extra, ...
                       x{2}.e, x{2}.var, x{2}.multi, x{2}.extra, ...
                       x{3}.e, x{3}.var, net.nonlin, x{5}, status);
elseif strcmp(approx_name, 'taylor')
  [dx{2}.e, dx{2}.var, dx{2}.multi, dx{2}.extra] = ...
      feedback_taylor(dx{3}.e, dx{3}.var, dx{3}.multi, dx{3}.extra, ...
                      x{2}.e, x{2}.var, x{2}.multi, x{2}.extra, ...
                      net.nonlin);
else
  error('Unsupported approximation')
end

% Gradient terms at the input layer
dx{1}.e = net.w1.e' * dx{2}.e + ...
    (2 * net.w1.var' * (dx{2}.var + dx{2}.extra)) .* x{1}.e;
dx{1}.var = (net.w1.e'.^2 + net.w1.var') * dx{2}.var + ...
    net.w1.var' * dx{2}.extra;


[dc_dnetm.w1, dc_dnetv.w1, dc_dnetm.b1, dc_dnetv.b1] = ...
    netgradstop(x{1}, dx{2}, net.w1, net.b1);

% Gradients with respect to the source means and variances
dc_dsm = dx{1}.e;
dc_dsv = dx{1}.var + src_var_bonus;


function [dm, dv, dmv, dev] = ...
    feedback_hermite(dgm0, dgv0, dgmv, dgev, m_in, v_in, mv_in, ev_in, ...
                     m_out, v_out, nonlin, aux, status)
% FEEDBACK_HERMITE  Evaluate the gradients of Gauss-Hermite quadrature
%   approximation of nonlinearity
%
%   dgm0, dgv0, dgmv, dgev - incoming gradient terms for the output mean,
%                            var, multi and extra
%   m_in, v_in, mv_in, ev_in - input mean, var, multi and extra
%   m_out, v_out           - output mean and var (m_out is only used for
%                            its size here)
%   nonlin                 - name of the nonlinearity; the function
%                            ['d3' nonlin] is called for its derivative
%   aux                    - struct read for the fields f_vardevs,
%                            f_evdevs, v_args, f_varvals, ev_args and
%                            f_evvals
%   status                 - status.updatealg == 'old' clips the mean
%                            contribution to dv at zero
%
%   Returns dm, dv, dmv, dev: gradient terms for the input mean, var,
%   multi and extra.
%
%   NOTE(review): aux presumably holds values cached by the feedforward
%   phase. An earlier inline version of this code built the basis points
%   as m_in + xi(k) * sqrt(2 * v_in) (and with ev_in in place of v_in for
%   the "ev" variants) -- confirm against the code that fills aux.

% The order of approximation and related abscissas and weights
n = 3;
xi = [0, sqrt(6)/2, -sqrt(6)/2];
wi = [2/3, 1/6, 1/6];

% Weights / abscissas tiled along the third dimension to the size of the
% output so they can be combined element-wise with the per-point arrays
tiling = [size(m_out), 1];
w_tiled = repmat(reshape(wi, [1, 1, n]), tiling);
x_tiled = repmat(reshape(xi, [1, 1, n]), tiling);
wx_tiled = repmat(reshape(wi .* xi, [1, 1, n]), tiling);

% Normalised sum components
v_sum0 = (w_tiled .* aux.f_vardevs);
ev_sum0 = (w_tiled .* aux.f_evdevs);

% Compensate the use of output mean in evaluation of var and extravar
dgm = dgm0 - 2 * sum(ev_sum0, 3) .* dgev;

% Compensate the use of output variance in evaluation of multivar.
% This value cannot be used in gradients with respect to variance as
% it breaks the fixed point update rule used for the variances, so the
% plain dgv0 is used there instead.
multi_term = reshape(sum(dgmv .* mv_in, 2), size(dgv0));
dgv_mv = dgv0 + 1 ./ (2 * sqrt(v_out .* v_in) + 1e-20) .* multi_term;

% Easy case first: the multivars are scaled by sqrt(v_out ./ v_in)
std_ratio = sqrt(v_out ./ v_in);
dmv = repmat(reshape(std_ratio, [size(m_in, 1), 1, size(dgmv, 3)]),...
             [1, size(dgmv, 2), 1]) .* dgmv;

% Evaluate the derivative of the nonlinearity at basis points
der_vargs = feval(['d3' nonlin], aux.v_args, aux.f_varvals);
der_evargs = feval(['d3' nonlin], aux.ev_args, aux.f_evvals);

% Partial derivative with respect to input extravar
dev = sum(ev_sum0 .* der_evargs .* x_tiled, 3) .* ...
      (ev_in .^ -0.5) .* dgev;

% Partial derivative with respect to input var
inv_std_in = v_in .^ -0.5;
mean_part = .5 * sum(wx_tiled .* der_vargs, 3) .* inv_std_in .* dgm;
var_part = sum(v_sum0 .* der_vargs .* x_tiled, 3) .* inv_std_in .* dgv0;

if strcmp(status.updatealg, 'old')
  % Old update rule: keep only the positive part of the mean term
  dv = var_part + mean_part .* (mean_part > 0);
else
  dv = var_part + mean_part;
end

% Partial derivative with respect to input mean
dm = sum(w_tiled .* der_vargs, 3) .* dgm + ...
     2 * sum(v_sum0 .* der_vargs, 3) .* dgv_mv + ...
     2 * sum(ev_sum0 .* der_evargs, 3) .* dgev;


function [dm, dv, dmv, dev] = ...
    feedback_taylor(dgm0, dgv0, dgmv, dgev, m_in, v_in, mv_in, ev_in, nonlin)
% FEEDBACK_TAYLOR  Evaluate the gradients of Taylor
%   approximation of nonlinearity
%
%   dgm0, dgv0, dgmv, dgev   - incoming gradient terms for the output
%                              mean, var, multi and extra
%   m_in, v_in, mv_in, ev_in - input mean, var, multi and extra
%   nonlin                   - name of the nonlinearity; ['d3' nonlin] is
%                              called at m_in and its three outputs are
%                              used as the derivative terms below
%
%   Returns dm, dv, dmv, dev: gradient terms for the input mean, var,
%   multi and extra.

[f1, f2, f3] = feval(['d3' nonlin], m_in);

% Mean-gradient contribution to the variance; only its positive part is
% used, and the same mask gates the third-derivative term in dm
half_curv = .5 * f2 .* dgm0;
is_pos = (half_curv > 0);

f1_sq = f1 .^ 2;

dv = half_curv .* is_pos + f1_sq .* dgv0;
dev = f1_sq .* dgev;

% Contribution of the multi gradients, summed over the second dimension
multi_part = reshape(sum(dgmv .* mv_in, 2), size(f2)) .* f2;

dm = dgm0 .* (f1 + .5*v_in .* f3 .* is_pos) + ...
     2 * dgv0 .* v_in .* f2 .* f1 + ...
     2 * dgev .* ev_in .* f2 .* f1 + ...
     multi_part;

% The multi gradients are scaled by the first derivative term
f1_tiled = repmat(reshape(f1, [size(m_in, 1) 1 size(dgmv, 3)]),...
                  [1 size(dgmv, 2) 1]);
dmv = f1_tiled .* dgmv;


function [dcp_dwm, dcp_dwv, dcp_dbm, dcp_dbv] = netgrads(x, dx, w, b)
% NETGRADS Calculate partial derivatives of kldiv with respect to
%   network weights
%
%   x  - layer input; fields e, var, extra and multi are read
%   dx - gradient terms of the layer output; fields e, extra and multi
%        are read
%   w  - weights; only w.e is read
%   b  - biases; not used in the computation
%
%   Returns the gradients for the weight means (dcp_dwm), weight
%   variances (dcp_dwv), bias means (dcp_dbm) and bias variances
%   (dcp_dbv).

% Contribution of the multi terms.  Flattening the trailing two
% dimensions and taking a single product is equivalent to accumulating
%   dx.multi(:,:,i) * x.multi(:,:,i)'
% over the third index i.
x_rows = size(x.multi, 1);
[g_rows g_cols g_pages] = size(dx.multi);
flat_dx = reshape(dx.multi, [g_rows g_cols*g_pages]);
flat_x = reshape(x.multi, [x_rows g_cols*g_pages]);
bonus = flat_dx * flat_x';

% Weight means
mean_term = dx.e * x.e';
extra_term = 2 * (dx.extra * x.extra') .* w.e;
dcp_dwm = mean_term + extra_term + bonus;

% Weight variances
second_moment = x.var + x.e .^ 2;
dcp_dwv = dx.extra * second_moment';

% Biases: sum the gradient terms over the second dimension
dcp_dbm = sum(dx.e, 2);
dcp_dbv = sum(dx.extra, 2);


function [dcp_dwm, dcp_dwv, dcp_dbm, dcp_dbv] = netgradstop(x, dx, w, b)
% NETGRADSTOP Calculate partial derivatives of kldiv with respect to
%   network weights
%
%   x  - layer input; fields e and var are read
%   dx - gradient terms of the layer output; fields e, var, extra and
%        multi are read
%   w  - weights; only w.e is read
%   b  - biases; not used in the computation
%
%   Returns the gradients for the weight means (dcp_dwm), weight
%   variances (dcp_dwv), bias means (dcp_dbm) and bias variances
%   (dcp_dbv).

% Weight means: mean term, variance term and the multi gradients summed
% over their third dimension
mean_term = dx.e * x.e';
var_term = 2 * (dx.var * x.var') .* w.e;
multi_term = sum(dx.multi, 3);
dcp_dwm = mean_term + var_term + multi_term;

% Weight variances
second_moment = x.var + x.e .^ 2;
dcp_dwv = (dx.extra + dx.var) * second_moment';

% Biases: sum the gradient terms over the second dimension
dcp_dbm = sum(dx.e, 2);
dcp_dbv = sum(dx.var + dx.extra, 2);
