function [sources, net, tnet, oldgrads, status] = update_natgrad(...
    sources, net, tnet, dcp_dsm, dcp_dsvn, fs, tfs, params, ...
    dcp_dnetm, dcp_dnetv, dcp_dtnetm, dcp_dtnetv, newac, dcp_ac, ...
    data, oldc, status, oldgrads, clamped, missing, notimedep)
% UPDATE_NATGRAD  Perform a line search to find optimal step length
%     and update all parameters based on an approximation of the natural 
%     conjugate gradient
%
%  The update proceeds in two stages:
%    1. Candidate variances (and the candidate sources.ac = newac) are
%       built and a step length towards them is chosen with
%       backtracking_search / search_helper_var.
%    2. The mean gradients are scaled by the (updated) variances to form
%       the approximate natural gradient, combined with the previous
%       search direction (Polak-Ribiere beta, clipped at 0, with
%       Powell-Beale restarts), and the means are moved along that
%       direction with linesearch / search_helper_mean.
%
%  Inputs used by this function
%    sources    struct; fields e, var, nvar and ac are read here
%    net, tnet  structs with fields w1, w2, b1, b2, each having .e and .var
%    dcp_dsm, dcp_dnetm, dcp_dtnetm
%               gradients w.r.t. the means; negated to get the step
%    dcp_dsvn, dcp_dnetv, dcp_dtnetv
%               quantities used to form the candidate variances
%    newac      candidate value for sources.ac
%    data, params, missing, notimedep
%               forwarded unchanged to kldiv
%    status     struct; fields updatesrcs, updatesrcvars, updatenet,
%               updatetnet (negative value disables that update),
%               varalpha (defaults to 1 when absent), cgreset, kls,
%               debug and t0 are read; cgreset and t0 are written
%    oldgrads   conjugate gradient state from the previous call (fields
%               norm, fnorm, s, net, tnet, grad, var, f, cgreset)
%    clamped    selects the sources whose mean gradient is forced to zero
%
%  fs, tfs, dcp_ac and oldc are accepted for interface compatibility but
%  are not used by the active code in this function.
%
%  Outputs: the updated sources, net and tnet, the conjugate gradient
%  state for the next call, and the updated status.

% Copyright (C) 1999-2006 Antti Honkela, Harri Valpola,
% Xavier Giannakopoulos and Matti Tornio
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

x0 = struct('s', sources, 'net', net, 'tnet', tnet);

% Candidate variances.  For positive values the max() bounds each
% candidate from above by (current variance) / 0.9.
newvars = struct('s', .5 ./ max(dcp_dsvn, .45 ./ sources.nvar), ...
		 'w1', .5 ./ max(dcp_dnetv.w1, .45 ./ net.w1.var), ...
		 'w2', .5 ./ max(dcp_dnetv.w2, .45 ./ net.w2.var), ...
		 'b1', .5 ./ max(dcp_dnetv.b1, .45 ./ net.b1.var), ...
		 'b2', .5 ./ max(dcp_dnetv.b2, .45 ./ net.b2.var), ...
		 'tw1', .5 ./ max(dcp_dtnetv.w1, .45 ./ tnet.w1.var), ...
		 'tw2', .5 ./ max(dcp_dtnetv.w2, .45 ./ tnet.w2.var), ...
		 'tb1', .5 ./ max(dcp_dtnetv.b1, .45 ./ tnet.b1.var), ...
		 'tb2', .5 ./ max(dcp_dtnetv.b2, .45 ./ tnet.b2.var), ...
		 'ac', newac);

% A disabled update keeps the current values as the "candidate", so the
% interpolation in search_helper_var leaves them unchanged.
if status.updatesrcvars < 0,
  newvars.s = sources.nvar;
  newvars.ac = sources.ac;
end

if status.updatenet < 0,
  newvars.w1 = net.w1.var;
  newvars.w2 = net.w2.var;
  newvars.b1 = net.b1.var;
  newvars.b2 = net.b2.var;
end
if status.updatetnet < 0,
  newvars.tw1 = tnet.w1.var;
  newvars.tw2 = tnet.w2.var;
  newvars.tb1 = tnet.b1.var;
  newvars.tb2 = tnet.b2.var;
end

% Extra arguments handed to the search helpers
fargs = struct('data', data, 'params', params, 'status', status, ...
               'missing', missing, 'notimedep', notimedep);

if ~isfield(status, 'varalpha'),
  status.varalpha = 1;
end

%[x, alpha, c] = bisection_search(x0, newvars, ...
%				    min([1, 2*sqrt(status.varalpha)]), ...
%				    @search_helper_var, fargs);


% NOTE(review): alpha is not written back to status.varalpha here, so the
% initial step length only changes if the caller updates it -- confirm.
[x, alpha, c] = backtracking_search(x0, newvars, ...
				    min([1, 2*sqrt(status.varalpha)]), ...
				    @search_helper_var, fargs);

sources = x.s;
net = x.net;
tnet = x.tnet;

% Re-evaluate kldiv after the variance update; newc and restc are passed
% on to the line search over the means below.
[newc, datac, restc, dync, netc] = kldiv(...
    [], [], sources, data, net, tnet, params, missing, notimedep, status);

if status.updatesrcs < 0,
  dc_dsm = zeros(size(dcp_dsm));
else
  dc_dsm = -dcp_dsm;
  % Set gradient for the clamped sources to zeros
  dc_dsm(find(clamped)) = 0;
end

if status.updatenet < 0,
  dc_dnetm.w1 = zeros(size(dcp_dnetm.w1));
  dc_dnetm.b1 = zeros(size(dcp_dnetm.b1));
  dc_dnetm.w2 = zeros(size(dcp_dnetm.w2));
  dc_dnetm.b2 = zeros(size(dcp_dnetm.b2));
else
  dc_dnetm.w1 = -dcp_dnetm.w1;
  dc_dnetm.b1 = -dcp_dnetm.b1;
  dc_dnetm.w2 = -dcp_dnetm.w2;
  dc_dnetm.b2 = -dcp_dnetm.b2;
end

if status.updatetnet < 0,
  dc_dtnetm.w1 = zeros(size(dcp_dtnetm.w1));
  dc_dtnetm.b1 = zeros(size(dcp_dtnetm.b1));
  dc_dtnetm.w2 = zeros(size(dcp_dtnetm.w2));
  dc_dtnetm.b2 = zeros(size(dcp_dtnetm.b2));
else
  dc_dtnetm.w1 = -dcp_dtnetm.w1;
  dc_dtnetm.b1 = -dcp_dtnetm.b1;
  dc_dtnetm.w2 = -dcp_dtnetm.w2;
  dc_dtnetm.b2 = -dcp_dtnetm.b2;
end

% Collect the gradients and variances of the sources and weights.
% Both vectors use the same ordering; pop() below relies on it.
grad = [dc_dsm(:); ...
	dc_dnetm.w1(:);  dc_dnetm.w2(:); ...
	dc_dnetm.b1(:);  dc_dnetm.b2(:); ...
	dc_dtnetm.w1(:); dc_dtnetm.w2(:); ...
	dc_dtnetm.b1(:); dc_dtnetm.b2(:)];
% Named pvar rather than var to avoid shadowing the built-in VAR function
pvar = [sources.var(:); ...
        net.w1.var(:);  net.w2.var(:); ...
        net.b1.var(:);  net.b2.var(:); ...
        tnet.w1.var(:); tnet.w2.var(:); ...
        tnet.b1.var(:); tnet.b2.var(:)];


% Norm in Riemannian space
gn = sum(pvar.^(-1).*grad.^2);

% Natural gradient
f = pvar .* grad;
f_temp = f;
fn = sum(pvar.^(-1).*f.^2);

% Extract the individual components of the gradient
[f dc_dsm] = pop(f, dc_dsm);
[f dc_dnetm.w1] = pop(f, dc_dnetm.w1);
[f dc_dnetm.w2] = pop(f, dc_dnetm.w2);
[f dc_dnetm.b1] = pop(f, dc_dnetm.b1);
[f dc_dnetm.b2] = pop(f, dc_dnetm.b2);
[f dc_dtnetm.w1] = pop(f, dc_dtnetm.w1);
[f dc_dtnetm.w2] = pop(f, dc_dtnetm.w2);
[f dc_dtnetm.b1] = pop(f, dc_dtnetm.b1);
[f dc_dtnetm.b2] = pop(f, dc_dtnetm.b2);

% pop() consumed f; restore the full natural gradient vector
f = f_temp;

% Use the previous direction only when usable state of matching shape
% exists.  isequal() is used for the shape tests because comparing size
% vectors with == raises an error when the number of dimensions differs.
if (oldgrads.norm ~= 0) & isfield(oldgrads, 'grad') & ...
      isfield(oldgrads, 'f') & ...
      isequal(size(oldgrads.s), size(dc_dsm)) & ...
      isequal(size(oldgrads.net.w1), size(dc_dnetm.w1)) & ...
      isequal(size(oldgrads.tnet.w1), size(dc_dtnetm.w1)) & ...
      isequal(size(oldgrads.grad), size(grad)),
  % Powell-Beale restarts
  if abs((oldgrads.f'*(pvar.^(-1).*f))) >= .2*fn && ...
	status.cgreset > 10,
    fprintf('Resetting CG (Powell-Beale restart)\n');
    % sources is a struct, so the direction is sized from dc_dsm
    oldgrads.s = zeros(size(dc_dsm));
    oldgrads.net = netgrad_zeros(net);
    oldgrads.tnet = netgrad_zeros(tnet);
    oldgrads.norm = 0;
    oldgrads.grad = zeros(size(grad));
    beta = 0;
    status.cgreset = 0;
    oldgrads.cgreset(size(status.kls, 2)) = true;
  else
    % Polak-Ribiere formula
    beta = f' * (pvar.^(-1).*(f-oldgrads.f)) / oldgrads.fnorm;
    beta = max([0 beta]);

    % Fletcher-Reeves formula (gives inferior results in most cases)
    %beta = fn / oldgrads.fnorm;
    
    if status.debug,
      fprintf('Beta=%f  ratio=%f\n',beta,abs((oldgrads.f'*(pvar.^(-1).*f)))/fn);
    end
    status.cgreset = status.cgreset + 1;
  end

  % New search direction = natural gradient + beta * previous direction
  dc_dsm       = dc_dsm       + beta * oldgrads.s;
  dc_dnetm.w1  = dc_dnetm.w1  + beta * oldgrads.net.w1;
  dc_dnetm.w2  = dc_dnetm.w2  + beta * oldgrads.net.w2;
  dc_dnetm.b1  = dc_dnetm.b1  + beta * oldgrads.net.b1;
  dc_dnetm.b2  = dc_dnetm.b2  + beta * oldgrads.net.b2;
  dc_dtnetm.w1 = dc_dtnetm.w1 + beta * oldgrads.tnet.w1;
  dc_dtnetm.w2 = dc_dtnetm.w2 + beta * oldgrads.tnet.w2;
  dc_dtnetm.b1 = dc_dtnetm.b1 + beta * oldgrads.tnet.b1;
  dc_dtnetm.b2 = dc_dtnetm.b2 + beta * oldgrads.tnet.b2;
end

% Store the gradient and state information for the next iteration
oldgrads.norm = gn;
oldgrads.fnorm = fn;
oldgrads.s = dc_dsm;
oldgrads.net = dc_dnetm;
oldgrads.tnet = dc_dtnetm;
oldgrads.grad = grad;
oldgrads.var = pvar;
oldgrads.f = f;

step.s    = dc_dsm;
step.net  = dc_dnetm;
step.tnet = dc_dtnetm;
step.restc = restc;

% Line search over the means along the search direction
[x, status.t0] = linesearch(x, step, status.t0, @search_helper_mean, ...
                            fargs, newc);

sources = x.s;
net = x.net;
tnet = x.tnet;

%[sources, net, tnet, status] = cubic_search(...
%    sources, net, tnet, newc, restc, step, ...
%    fs, tfs, data, params, status.t0, status, missing, notimedep);


function [x, c] = search_helper_mean(x, step, lambda, d),
% SEARCH_HELPER_MEAN  Search helper function for updating means
%
%  [x, c] = search_helper_mean(x, step, lambda, d)
%
%  Adds lambda * step to the mean (.e) of the sources and of every
%  weight and bias of net and tnet in x, then returns the first output
%  of kldiv evaluated at the moved point as c.  d carries the extra
%  kldiv arguments (data, params, missing, notimedep, status).

% Sources
x.s.e = x.s.e + lambda * step.s;

% Weights and biases of both networks
layers = {'w1', 'b1', 'w2', 'b2'};
for k = 1:numel(layers),
  name = layers{k};
  x.net.(name).e  = x.net.(name).e  + lambda * step.net.(name);
  x.tnet.(name).e = x.tnet.(name).e + lambda * step.tnet.(name);
end

c = kldiv([], [], x.s, d.data, x.net, x.tnet, d.params, ...
          d.missing, d.notimedep, d.status);


function [x, c] = search_helper_var(x, newvar, lambda, d),
% SEARCH_HELPER_VAR  Search helper function for updating variances
%
%  [x, c] = search_helper_var(x, newvar, lambda, d)
%
%  Moves every variance in x a fraction lambda of the way towards the
%  candidate in newvar, interpolating linearly in the log domain.
%  x.s.ac is interpolated linearly.  x.s is then passed through
%  updatevar, and c is the first output of kldiv at the resulting point.
%  d carries the extra kldiv arguments (data, params, missing,
%  notimedep, status).

% Log-domain interpolation between the current and the candidate value
blend = @(target, current) ...
    exp(lambda * log(target) + (1-lambda) * log(current));

x.s.nvar = blend(newvar.s, x.s.nvar);

% Candidates for tnet are stored in newvar under a 't' prefix
layers = {'w1', 'w2', 'b1', 'b2'};
for k = 1:numel(layers),
  name = layers{k};
  x.net.(name).var  = blend(newvar.(name), x.net.(name).var);
  x.tnet.(name).var = blend(newvar.(['t' name]), x.tnet.(name).var);
end

%s.ac(:,2:end) = alpha * step.ac(:,2:end) + (1-alpha)*s0.ac(:,2:end);
x.s.ac = lambda * newvar.ac + (1-lambda)*x.s.ac;
x.s = updatevar(x.s);

c = kldiv([], [], x.s, d.data, x.net, x.tnet, d.params, ...
          d.missing, d.notimedep, d.status);


function [A C] = pop(A, B),
% POP  Take the leading elements off a vector and reshape them
%
%  [A, C] = pop(A, B)
%
%  Removes the first numel(B) elements from the vector stack A and
%  returns them as C, reshaped to the size of B.  The returned A holds
%  the elements that remain.  Only the size of B is used, not its values.
shape = size(B);
count = numel(B);
head = A(1:count);
C = reshape(head, shape);
A = A(count+1:end);
