function [sources, net, tnet, oldgrads, status] = update_everything(...
    sources, net, tnet, dcp_dsm, dcp_dsvn, fs, tfs, params, ...
    dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, ...
    data, oldc, status, oldgrads, clamped, missing, notimedep)
% UPDATE_EVERYTHING  Perform a line search to find optimal step length
%     and update all parameters
%
%  [sources, net, tnet, oldgrads, status] = update_everything(...)
%
%  The update is done in three stages:
%    1. UPDATE_VARIANCES is called first; it returns the updated
%       sources, net and tnet together with the costs NEWC and RESTC.
%    2. A search direction for the means is formed from the negated
%       gradients DCP_DSM, DCP_DNETM and DCP_DTNETM.  When
%       status.updatealg is 'conjgrad' the direction is combined with
%       the previous direction stored in OLDGRADS (Hestenes-Stiefel
%       formula with a Powell-Beale style restart test).
%    3. CUBIC_SEARCH is called with that direction and status.t0.
%
%  Inputs used directly in this function:
%    dcp_dsm     gradient for the sources; its negation is the step
%    dcp_dnetm   struct with fields w1, b1, w2, b2 (gradient for net)
%    dcp_dtnetm  struct with fields w1, b1, w2, b2 (gradient for tnet)
%    clamped     mask/indices of source entries whose step is zeroed
%    oldgrads    struct holding the previous CG state (fields norm,
%                grad, p, s, net, tnet); updated and returned
%    status      fields read here: updatesrcs, updatenet, updatetnet,
%                updatealg, debug, kls, t0.  A negative update* value
%                disables the corresponding mean update.  With
%                debug >= 2 the fields history.beta, history.t0 and
%                history.linestep are written at index size(kls, 2).
%  All remaining inputs are only passed through to UPDATE_VARIANCES
%  and/or CUBIC_SEARCH.

% Copyright (C) 1999-2005 Antti Honkela, Harri Valpola,
% Xavier Giannakopoulos and Matti Tornio
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

[sources, net, tnet, newc, restc, status] = update_variances(...
    sources, net, tnet, dcp_dsvn, fs, tfs, params, dcp_dnetv, dcp_dtnetv, ...
    newac, data, oldc, status, clamped, missing, notimedep);

% Descent direction for the sources (all zeros when updates are disabled)
if status.updatesrcs < 0,
  dc_dsm = zeros(size(dcp_dsm));
else
  dc_dsm = -dcp_dsm;
  % Set gradient for the clamped sources to zeros
  dc_dsm(find(clamped)) = 0;
end

% Descent direction for net
if status.updatenet < 0,
  dc_dnetm.w1 = zeros(size(dcp_dnetm.w1));
  dc_dnetm.b1 = zeros(size(dcp_dnetm.b1));
  dc_dnetm.w2 = zeros(size(dcp_dnetm.w2));
  dc_dnetm.b2 = zeros(size(dcp_dnetm.b2));
else
  dc_dnetm.w1 = -dcp_dnetm.w1;
  dc_dnetm.b1 = -dcp_dnetm.b1;
  dc_dnetm.w2 = -dcp_dnetm.w2;
  dc_dnetm.b2 = -dcp_dnetm.b2;
end

% Descent direction for tnet
if status.updatetnet < 0,
  dc_dtnetm.w1 = zeros(size(dcp_dtnetm.w1));
  dc_dtnetm.b1 = zeros(size(dcp_dtnetm.b1));
  dc_dtnetm.w2 = zeros(size(dcp_dtnetm.w2));
  dc_dtnetm.b2 = zeros(size(dcp_dtnetm.b2));
else
  dc_dtnetm.w1 = -dcp_dtnetm.w1;
  dc_dtnetm.b1 = -dcp_dtnetm.b1;
  dc_dtnetm.w2 = -dcp_dtnetm.w2;
  dc_dtnetm.b2 = -dcp_dtnetm.b2;
end


% CG for networks and CG or Kalman for sources
if strcmp(status.updatealg, 'conjgrad'),
  % Stack all steepest-descent components into one column vector
  grad = [dc_dsm(:); ...
          dc_dnetm.w1(:); dc_dnetm.w2(:); ...
          dc_dnetm.b1(:); dc_dnetm.b2(:); ...
          dc_dtnetm.w1(:); dc_dtnetm.w2(:); ...
          dc_dtnetm.b1(:); dc_dtnetm.b2(:)];

  gn = sum(grad.^2);

  % Reuse the previous direction only if one exists and its shapes still
  % match the current model.  isequal is used for the size comparisons so
  % that arrays with a different number of dimensions compare as
  % "different" instead of raising an error.  The & operators rely on the
  % short-circuit evaluation that applies inside an if condition.
  if (oldgrads.norm ~= 0) & isfield(oldgrads, 'grad') & ...
        isequal(size(oldgrads.s), size(dc_dsm)) & ...
        isequal(size(oldgrads.net.w1), size(dc_dnetm.w1)) & ...
        isequal(size(oldgrads.tnet.w1), size(dc_dtnetm.w1)) & ...
        isequal(size(oldgrads.grad), size(grad)),

    % Powell-Beale restarts
    if oldgrads.grad'*grad >= .2*gn,
      fprintf('Resetting CG (Powell-Beale)\n');
      oldgrads.s = zeros(size(sources));
      oldgrads.net = netgrad_zeros(net);
      oldgrads.tnet = netgrad_zeros(tnet);
      oldgrads.norm = 0;
      oldgrads.grad = zeros(size(grad));
      beta = 0;
    else
      % Fletcher-Reeves formula (from Antti's version)
      %beta = gn / oldgrads.norm;
      % Polak-Ribiere formula
      %beta = grad' * (grad - oldgrads.grad) / oldgrads.norm;
      %beta = max([0 beta]);
      % Hestenes-Stiefel formula
      hs_denom = oldgrads.p' * (grad - oldgrads.grad);
      beta = (gn - grad'*oldgrads.grad) / hs_denom;
      if ~isfinite(beta),
        % Degenerate denominator (zero or non-finite): an Inf/NaN beta
        % would corrupt every direction below, so fall back to plain
        % steepest descent for this step.
        beta = 0;
      end
    end

    if status.debug >= 2,
      Nh = size(status.kls, 2);
      status.history.beta(Nh) = beta;
    end

    % New direction = steepest descent + beta * previous direction
    dc_dsm       = dc_dsm       + beta * oldgrads.s;
    dc_dnetm.w1  = dc_dnetm.w1  + beta * oldgrads.net.w1;
    dc_dnetm.w2  = dc_dnetm.w2  + beta * oldgrads.net.w2;
    dc_dnetm.b1  = dc_dnetm.b1  + beta * oldgrads.net.b1;
    dc_dnetm.b2  = dc_dnetm.b2  + beta * oldgrads.net.b2;
    dc_dtnetm.w1 = dc_dtnetm.w1 + beta * oldgrads.tnet.w1;
    dc_dtnetm.w2 = dc_dtnetm.w2 + beta * oldgrads.tnet.w2;
    dc_dtnetm.b1 = dc_dtnetm.b1 + beta * oldgrads.tnet.b1;
    dc_dtnetm.b2 = dc_dtnetm.b2 + beta * oldgrads.tnet.b2;
  end

  % Store the CG state for the next call
  oldgrads.norm = gn;
  oldgrads.s = dc_dsm;
  oldgrads.net = dc_dnetm;
  oldgrads.tnet = dc_dtnetm;
  % p is the negated search direction, stacked in the same order as grad
  oldgrads.p = -[dc_dsm(:); ...
                 dc_dnetm.w1(:); dc_dnetm.w2(:); ...
                 dc_dnetm.b1(:); dc_dnetm.b2(:); ...
                 dc_dtnetm.w1(:); dc_dtnetm.w2(:); ...
                 dc_dtnetm.b1(:); dc_dtnetm.b2(:)];
  oldgrads.grad = grad;
end

% Line search along the chosen direction
step.s = dc_dsm;
step.net = dc_dnetm;
step.tnet = dc_dtnetm;
[sources, net, tnet, status] = cubic_search(...
    sources, net, tnet, newc, restc, step, ...
    fs, tfs, data, params, status.t0, status, missing, notimedep);

if status.debug >= 2,
  Nh = size(status.kls, 2);
  status.history.t0(Nh) = status.t0;
  % sqrt(oldgrads.norm) is the norm of the stacked steepest-descent vector
  status.history.linestep(Nh) = sqrt(oldgrads.norm) * status.t0;
end


function [sources, net, tnet, newc, restc, status] = update_variances(...
    sources0, net0, tnet0, dcp_dsvn, fs, tfs, params, dcp_dnetv, dcp_dtnetv, ...
    newac, data, oldcost, status, clamped, missing, notimedep)
% UPDATE_VARIANCES  Update the variances of the network and the sources
%
%  A variance step is built from the gradients DCP_DSVN, DCP_DNETV and
%  DCP_DTNETV and applied with UPDATE_VAR using a step length alpha.
%  The resulting cost is evaluated with KLDIV; alpha is shrunk until the
%  cost no longer exceeds OLDCOST by more than a small tolerance.  If
%  alpha falls below a minimum value the original parameters are
%  restored, NEWC is set to OLDCOST and a warning is issued.
%
%  status.varalpha (created with value 1 if missing) carries the step
%  length between calls.

costtol = 1e-6;      % largest tolerated increase over oldcost
minalpha = 1e-10;    % give up when the step length drops below this

% Outputs start out as the unmodified inputs
sources = sources0;
net = net0;
tnet = tnet0;

if ~isfield(status, 'varalpha'),
  status.varalpha = 1;
end

newc = inf;
alpha = min([1, 10*sqrt(status.varalpha)]);

% Assemble the step; disabled parts are simply left out of the struct
groups = {'w1', 'w2', 'b1', 'b2'};

if status.updatesrcvars >= 0,
  step.s = .5 ./ max(dcp_dsvn, .45 ./ sources0.nvar);
end

step.ac = newac;

if status.updatenet >= 0,
  for k = 1:length(groups),
    g = groups{k};
    step.net.(g) = .5 ./ max(dcp_dnetv.(g), .45 ./ net0.(g).var);
  end
end

if status.updatetnet >= 0,
  for k = 1:length(groups),
    g = groups{k};
    step.tnet.(g) = .5 ./ max(dcp_dtnetv.(g), .45 ./ tnet0.(g).var);
  end
end

% History of tried step lengths and their costs; entry 1 is the
% starting point (alpha = 0, cost = oldcost)
ntrials = 0;
costs(1) = oldcost;
alphas(1) = 0;

% Variation of backtracking linesearch (with quadratic interpolation)
while newc > (oldcost + costtol),
  ntrials = ntrials + 1;

  if alpha < minalpha,
    % No acceptable step found: restore the inputs and stop
    warning('Variance update failed');
    sources = sources0;
    net = net0;
    tnet = tnet0;
    newc = oldcost;
    break;
  end

  % Always step from the original parameters, not from the last trial
  [sources, net, tnet] = update_var(...
      sources0, net0, tnet0, alpha, step, status);

  [newc, datac, restc, dync, netc] = kldiv(...
      [], [], sources, data, net, tnet, params, missing, notimedep, status);

  costs(ntrials + 1) = newc;
  alphas(ntrials + 1) = alpha;

  if status.debug,
    fprintf('Variance update:             c=%.16g, alpha=%f\n', newc, alpha);
  end

  if ntrials == 1,
    % First trial: plain backtracking
    alpha = .1 * alpha;
  else
    % Later trials: minimum of the parabola through the starting point
    % and the two most recent trials, limited to [.1, .5] * alpha
    x1 = alphas(1);
    x2 = alphas(ntrials);
    x3 = alphas(ntrials + 1);
    c1 = costs(1);
    c2 = costs(ntrials);
    c3 = costs(ntrials + 1);
    numer = x1.^2 .* (c2-c3) + x2.^2 .* (c3-c1) + x3.^2 .* (c1-c2);
    denom = 2*(x1 .* (c2-c3) + x2 .* (c3-c1) + x3 .* (c1-c2));
    newalpha = numer ./ denom;
    alpha = max([min([newalpha, .5 * alpha]), .1 * alpha]);
  end
end

% Remember the step length for the next call
status.varalpha = min([10*alpha, 1]);

function [A, C] = pop(A, B)
% POP  Implements stack style pop operation with reshaping
%
%  [A, C] = pop(A, B)
%
%  Takes the first prod(size(B)) elements of A (in linear index order),
%  returns them in C reshaped to the size of B, and returns the
%  remaining elements of A.  Only the size of B is used.
dims = size(B);
count = prod(dims);
C = reshape(A(1:count), dims);
A = A((count + 1):end);
