function [sources, net, oldgrads, status] = update_everything(...
    sources, net, dcp_dsm, dcp_dsv, params, dcp_dnetm, dcp_dnetv, ...
    data, status, oldgrads)
% UPDATE_EVERYTHING  Perform a line search to find optimal step length
%     and update all parameters
%
%   [sources, net, oldgrads, status] = update_everything(sources, net, ...
%       dcp_dsm, dcp_dsv, params, dcp_dnetm, dcp_dnetv, data, status, ...
%       oldgrads)
%
%   The update is done in two stages:
%     1. The variances (.v) of the sources and of net.w1, net.w2, net.b1
%        and net.b2 are moved towards candidate values computed from
%        dcp_dsv and dcp_dnetv, using BISECTION_SEARCH with
%        search_helper_var as the cost evaluator.
%     2. A step direction for the means (.m) is built from the negated
%        dcp_dsm and dcp_dnetm, optionally rescaled by the variances and/or
%        combined with the previous direction, and applied with LINESEARCH
%        using search_helper_mean as the cost evaluator.
%
%   Fields of STATUS that are read here:
%     updatealg      'grad', 'conjgrad', 'ncg' or 'ng'; any other value
%                    raises the error 'Unknown updatealg'
%     varalpha       optional, defaults to 1; overwritten with twice the
%                    second output of BISECTION_SEARCH
%     updatesrcvars  if negative, source variances are kept as they are
%     updatesrcs     if negative, the source mean step is zero
%     updatenet      if negative, network variances are kept and the
%                    network mean step is zero
%     t0             passed to and updated by LINESEARCH
%     approximation  read by the search helpers
%
%   OLDGRADS carries the conjugate gradient state between calls (fields
%   norm, s, net, grad and either ngrad/fnorm or p).  It is only modified
%   when STATUS.updatealg is 'conjgrad' or 'ncg'.

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Translate the algorithm name into two independent switches.
switch status.updatealg,
 case 'grad',
  alg.conjugate = 0;
  alg.natgrad = 0;
 case 'conjgrad',
  alg.conjugate = 1;
  alg.natgrad = 0;
 case 'ncg',
  alg.conjugate = 1;
  alg.natgrad = 1;
 case 'ng',
  alg.conjugate = 0;
  alg.natgrad = 1;
 otherwise,
  error('Unknown updatealg');
end

if ~isfield(status, 'varalpha'),
  status.varalpha = 1;
end

% ---- Stage 1: variances -------------------------------------------------
x0 = struct('s', sources, 'net', net);
% Candidate variances.  The max(...) bounds the candidate from above by
% (current variance) / 0.9, so a single update cannot grow a variance by
% more than that factor.
newvars = struct('s', .5 ./ max(dcp_dsv, .45 ./ sources.v), ...
                 'w1', .5 ./ max(dcp_dnetv.w1, .45 ./ net.w1.v), ...
                 'w2', .5 ./ max(dcp_dnetv.w2, .45 ./ net.w2.v), ...
                 'b1', .5 ./ max(dcp_dnetv.b1, .45 ./ net.b1.v), ...
                 'b2', .5 ./ max(dcp_dnetv.b2, .45 ./ net.b2.v));
% Using the current variance as the candidate makes the interpolation in
% search_helper_var a no-op for that parameter group.
if status.updatesrcvars < 0,
  newvars.s = sources.v;
end
if status.updatenet < 0,
  newvars.w1 = net.w1.v;
  newvars.w2 = net.w2.v;
  newvars.b1 = net.b1.v;
  newvars.b2 = net.b2.v;
end

fargs = struct('data', data, 'params', params, 'status', status);

[x, alpha, c] = bisection_search(x0, newvars, ...
                                 min([1, 2*sqrt(status.varalpha)]), ...
                                 @search_helper_var, fargs);
status.varalpha = 2*alpha;

% ---- Stage 2: means -----------------------------------------------------
% Basic step direction: the negated input gradients, or zero for
% parameter groups whose update is disabled.
if status.updatesrcs < 0,
  dc_dsm = zeros(size(dcp_dsm));
else
  dc_dsm = -dcp_dsm;
end

if status.updatenet < 0,
  dc_dnetm.w1 = zeros(size(dcp_dnetm.w1));
  dc_dnetm.b1 = zeros(size(dcp_dnetm.b1));
  dc_dnetm.w2 = zeros(size(dcp_dnetm.w2));
  dc_dnetm.b2 = zeros(size(dcp_dnetm.b2));
else
  dc_dnetm.w1 = -dcp_dnetm.w1;
  dc_dnetm.b1 = -dcp_dnetm.b1;
  dc_dnetm.w2 = -dcp_dnetm.w2;
  dc_dnetm.b2 = -dcp_dnetm.b2;
end

% Flattened direction; the order (s, w1, w2, b1, b2) must match the order
% used for vars and for the pop calls below.
grad = [dc_dsm(:); ...
        dc_dnetm.w1(:); dc_dnetm.w2(:); ...
        dc_dnetm.b1(:); dc_dnetm.b2(:)];

if alg.natgrad,
  % Variances after the stage 1 update, flattened in the same order.
  vars = [x.s.v(:); ...
          x.net.w1.v(:); x.net.w2.v(:); ...
          x.net.b1.v(:); x.net.b2.v(:)];

  % Scale each component by its variance; fn equals
  % riemann_prod(ngrad, ngrad, vars).
  ngrad = vars .* grad;
  fn = sum((1 ./ vars) .* (ngrad .^ 2));

  % Unpack the scaled vector back into the individual step arrays.
  [ngrad_rest dc_dsm] = pop(ngrad, dc_dsm);
  [ngrad_rest dc_dnetm.w1] = pop(ngrad_rest, dc_dnetm.w1);
  [ngrad_rest dc_dnetm.w2] = pop(ngrad_rest, dc_dnetm.w2);
  [ngrad_rest dc_dnetm.b1] = pop(ngrad_rest, dc_dnetm.b1);
  [ngrad_rest dc_dnetm.b2] = pop(ngrad_rest, dc_dnetm.b2);
end

if alg.conjugate,
  gn = sum(grad.^2);

  % Only combine with the previous direction when there is one and its
  % shape matches.  isequal is used instead of all(size == size) so that a
  % differing number of dimensions counts as a mismatch instead of
  % raising an error.
  if (oldgrads.norm ~= 0) & isequal(size(oldgrads.s), size(dc_dsm)) ...
        & isequal(size(oldgrads.net.w1), size(dc_dnetm.w1)),
    % Powell-Beale restarts: restart when successive gradients are far
    % from orthogonal.  The magnitude of the inner product is tested in
    % both branches.
    if alg.natgrad,
      do_restart = abs(riemann_prod(oldgrads.ngrad, ngrad, vars)) >= .2*fn;
    else,
      do_restart = abs(oldgrads.grad' * grad) >= .2*gn;
    end
    if do_restart,
      fprintf('Resetting CG (Powell-Beale)\n');
      % The stored direction must have the shape of the source step
      % array, not of the sources container.
      oldgrads.s = zeros(size(dc_dsm));
      oldgrads.net = netgrad_zeros(net);
      oldgrads.norm = 0;
      oldgrads.grad = zeros(size(grad));
      beta = 0;
    else
      % Fletcher-Reeves formula
      % beta = gn / oldgrads.norm;

      if alg.natgrad,
        % Polak-Ribiere formula
        beta = riemann_prod(ngrad, ngrad - oldgrads.ngrad, ...
                            vars) / oldgrads.fnorm;
      else
        % Hestenes-Stiefel formula
        beta = (gn - grad'*oldgrads.grad) / (oldgrads.p' * (grad - oldgrads.grad));
      end
    end

    dc_dsm = dc_dsm + beta * oldgrads.s;
    dc_dnetm.w1 = dc_dnetm.w1 + beta * oldgrads.net.w1;
    dc_dnetm.w2 = dc_dnetm.w2 + beta * oldgrads.net.w2;
    dc_dnetm.b1 = dc_dnetm.b1 + beta * oldgrads.net.b1;
    dc_dnetm.b2 = dc_dnetm.b2 + beta * oldgrads.net.b2;
  end
  % Remember the state needed by the next call.
  oldgrads.norm = gn;
  oldgrads.s = dc_dsm;
  oldgrads.net = dc_dnetm;
  oldgrads.grad = grad;
  if alg.natgrad,
    oldgrads.ngrad = ngrad;
    oldgrads.fnorm = fn;
  else
    % Negated flattened search direction, used in the Hestenes-Stiefel
    % denominator.
    oldgrads.p = -[dc_dsm(:); ...
                   dc_dnetm.w1(:); dc_dnetm.w2(:); ...
                   dc_dnetm.b1(:); dc_dnetm.b2(:)];
  end
end

step = struct('s', dc_dsm, 'net', dc_dnetm);
% Rebuilt so that the helper sees the status with the updated varalpha.
fargs = struct('data', data, 'params', params, 'status', status);

% c is the third output of bisection_search above (presumably the cost at
% the variance-updated point x -- confirm against bisection_search).
[x, status.t0] = linesearch(x, step, status.t0, @search_helper_mean, ...
                            fargs, c);
sources = x.s;
net = x.net;



function [x, c] = search_helper_mean(x, step, lambda, d),
% SEARCH_HELPER_MEAN  Move all means along a step and evaluate the cost
%
%  [x, c] = search_helper_mean(x, step, lambda, d)
%
%  Adds lambda * step to the means (.m) of x.s and of x.net.w1, x.net.b1,
%  x.net.w2 and x.net.b2, and returns the modified x together with the
%  cost c, the sum of kl_batch and kl_static_split at the new point.
%  d supplies the fields data, params and status.approximation.

src = x.s;
nw = x.net;

% Shift the means; variances are left untouched.
src.m = src.m + lambda * step.s;
nw.w1.m = nw.w1.m + lambda * step.net.w1;
nw.b1.m = nw.b1.m + lambda * step.net.b1;
nw.w2.m = nw.w2.m + lambda * step.net.w2;
nw.b2.m = nw.b2.m + lambda * step.net.b2;

x.s = src;
x.net = nw;

% Cost at the shifted point; the fourth cell of the feedfw result is what
% kl_batch takes as its first argument.
fwd = feedfw(src, nw, d.status.approximation);
batch_cost = kl_batch(fwd{4}, src, d.data, d.params);
static_cost = kl_static_split(nw, d.params, 0);
c = batch_cost + static_cost;



function [x, c] = search_helper_var(x, newvar, lambda, d),
% SEARCH_HELPER_VAR  Interpolate all variances and evaluate the cost
%
%  [x, c] = search_helper_var(x, newvar, lambda, d)
%
%  Replaces the variances (.v) of x.s and of x.net.w1, x.net.w2, x.net.b1
%  and x.net.b2 by a log-domain interpolation between the current value
%  (lambda = 0) and the candidate in newvar (lambda = 1).  Returns the
%  modified x and the cost c, the sum of kl_batch and kl_static_split at
%  the new point.  d supplies the fields data, params and
%  status.approximation.

src = x.s;
nw = x.net;

src.v = loginterp_var(newvar.s, src.v, lambda);
nw.w1.v = loginterp_var(newvar.w1, nw.w1.v, lambda);
nw.w2.v = loginterp_var(newvar.w2, nw.w2.v, lambda);
nw.b1.v = loginterp_var(newvar.b1, nw.b1.v, lambda);
nw.b2.v = loginterp_var(newvar.b2, nw.b2.v, lambda);

x.s = src;
x.net = nw;

% Cost at the interpolated point.
fwd = feedfw(src, nw, d.status.approximation);
batch_cost = kl_batch(fwd{4}, src, d.data, d.params);
static_cost = kl_static_split(nw, d.params, 0);
c = batch_cost + static_cost;


function v = loginterp_var(target, current, lambda),
% LOGINTERP_VAR  Interpolate between two variances in the log domain
%
%  Returns current for lambda = 0 and target for lambda = 1.
v = exp(lambda * log(target) + (1-lambda) * log(current));


function [A C] = pop(A, B),
% POP  Take a matrix shaped like B off the front of the vector A
%
%  [A, C] = pop(A, B)
%
%  The first numel(B) elements of the vector A are removed and returned
%  in C, reshaped (column-major) to the size of B.  The first output is
%  what is left of A.  Only the size of B is used, not its contents.
shape = size(B);
count = numel(B);
head = A(1:count);
C = reshape(head, shape);
A = A(count+1:end);


function n = riemann_prod(x, y, vars),
% RIEMANN_PROD  Inner product of x and y weighted by 1 ./ vars
%
%  n = riemann_prod(x, y, vars)
%
%  x, y and vars are column vectors of equal length.  Each component of y
%  is multiplied by the reciprocal of the matching entry of vars before
%  the inner product with x is taken.

inv_vars = vars.^(-1);
scaled_y = inv_vars .* y;
n = x' * scaled_y;
