function [sources, net, params, status, fs] = ...
    nlfa_iter(data, sources, net, params, status)
% NLFA_ITER  Perform the NLFA iteration
%
%   [SOURCES, NET, PARAMS, STATUS, FS] = NLFA_ITER(DATA, SOURCES, NET,
%   PARAMS, STATUS) runs up to STATUS.iters iterations.  Each iteration
%   evaluates the cost (KL_STATIC plus KL_BATCH), computes its gradients
%   with respect to the sources and the network weights, updates the
%   sources and the network, and (unless suppressed) re-estimates the
%   variance and hyperparameter estimates in PARAMS.
%
%   Inputs / outputs:
%     DATA    - observations; samples are the columns (nsampl = size(DATA, 2)).
%     SOURCES - source estimates, indexed by sample column like DATA.
%     NET     - network structure (fields w1, b1, w2, b2 are used here).
%     PARAMS  - noise, source, network and hyperparameter estimates.
%     STATUS  - control and bookkeeping structure.  Fields used here:
%               iters, updatealg, cgreset, oldgrads (optional), kls,
%               cputime, epsilon, approximation, updatesrcs,
%               updatesrcvars, updatenet, updateparams.
%     FS      - PROBDIST built from the last FEEDFW output (x{4}.e, x{4}.var).
%
%   Two update schemes are supported: STATUS.updatealg == 'old' uses
%   UPDATESOURCES / UPDATENETWORK / SCALESOURCES; any other value uses
%   UPDATE_EVERYTHING, with conjugate-gradient ("CG") state carried in
%   STATUS.oldgrads between calls.
%
%   Negative values of updatesrcs, updatesrcvars, updatenet and
%   updateparams are countdowns: each is incremented by one per
%   iteration until it reaches zero.  While updateparams is negative,
%   PARAMS is not re-estimated.  In the 'old' scheme a negative
%   updatenet skips UPDATENETWORK and a negative updatesrcs keeps the
%   source means fixed.  UPDATE_EVERYTHING receives STATUS, so how it
%   treats these counters is decided there.
%
%   The loop ends early when the cost is NaN (immediate return) or when
%   the convergence test on STATUS.kls fires.

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Samples are the columns of DATA.
nsampl = size(data, 2);

%nlfa_batches = 1:status.batch_size:nsampl;
%
%nlfa_batch = [nlfa_batches', [nlfa_batches(2:end)-1, nsampl]'];

iters_left = status.iters;

% For the CG-based scheme, continue from the gradient state stored by a
% previous call unless it is missing or STATUS.cgreset == -1, in which
% case start from all-zero gradients.
if ~strcmp(status.updatealg, 'old'),
  if isfield(status, 'oldgrads') && status.cgreset ~= -1,
    oldgrads = status.oldgrads;
  else
    fprintf('Resetting CG\n');
    oldgrads.net = netgrad_zeros(net);
    oldgrads.s = zeros(size(sources));
    oldgrads.norm = 0;
  end
end

while iters_left > 0

  % Fresh gradient accumulators; FS starts as zero mean / unit variance
  % and is overwritten below for the samples in the current batch.
  dcp_dnetm = netgrad_zeros(net);
  dcp_dnetv = netgrad_zeros(net);
  fs = probdist(zeros(size(data)), ones(size(data)));

  % Cost term that depends only on the network and PARAMS.
  newkls = kl_static(net, params);
  
  % Mini-batching is disabled: the whole data set is a single batch.
  %  for k = 1:size(nlfa_batch, 1),
  %curbatch = nlfa_batch(k,1):nlfa_batch(k,2);
  curbatch = 1:nsampl;
  
  % Do feedforward calculations
  % x{4} holds the mean (.e) and variance (.var) of the network output.
  x = feedfw( sources(:, curbatch) , net, status.approximation);
  fs(:, curbatch) = probdist(x{4}.e, x{4}.var);

  % Calculate and possibly display current value of the cost function
  newkls = newkls + kl_batch(fs(:, curbatch), sources(:, curbatch), ...
			     data(:, curbatch), params);
    
  %if k == size(nlfa_batch, 1)
  % The iteration number printed is the count of costs already stored.
  fprintf('Iteration #%d: %f\n', size(status.kls, 2), newkls);
  % NaN cost: return immediately.  NOTE(review): on this path the NaN
  % cost is not appended to STATUS.kls, STATUS.oldgrads is not saved,
  % and FS is the output of this iteration's feedforward pass.  The
  % iters_left assignment has no effect because of the return.
  if isnan(newkls),
    iters_left = 0;
    %if size(nlfa_batch, 1) == 1,
    fprintf('Cost is NaN, bailing out...\n');
    return
    %end
  end

  % Convergence test on the stored history (before appending newkls),
  % only after more than 400 stored costs.  It stops if the last 11
  % stored costs are strictly increasing, or if none of the last 200
  % steps decreased the cost by STATUS.epsilon or more.  Setting
  % iters_left = 0 still lets the rest of this iteration (gradients and
  % updates) run.
  if (size(status.kls, 2) > 400 && ...
      ((min(diff(status.kls(end-10:end))) > 0) || ...
       (min(diff(status.kls(end-200:end))) > -status.epsilon))),
    fprintf('The iteration appears to have converged, bailing out...\n');
    iters_left = 0;
  end
  
  % Record cost and CPU time for this iteration.
  status.kls = [status.kls newkls];
  status.cputime = [status.cputime cputime];
  %end

  % Calculate partial derivatives for parameters to adapt
  % Gradient contributions are summed from: the data term (FEEDBACK),
  % the source priors and the network priors.
  [dcp_dsm, dcp_dsv, newdcp_dnetm, newdcp_dnetv] =...
      feedback(x, net, sources(:, curbatch), data(:, curbatch), ...
	       params.noise, status);

  [newdcp_dsm, newdcp_dsv] = ...
      feedback_srcpriors(sources(:, curbatch), params.src);

  dcp_dsm = dcp_dsm + newdcp_dsm;
  dcp_dsv = dcp_dsv + newdcp_dsv;

  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  [newdcp_dnetm, newdcp_dnetv] = ...
      feedback_netpriors(net, params.net, params.hyper.net);
  dcp_dnetm = sum_structs(dcp_dnetm, newdcp_dnetm);
  dcp_dnetv = sum_structs(dcp_dnetv, newdcp_dnetv);

  if strcmp(status.updatealg, 'old'),
    % Get new values for sources and alphas if appropriate
    % Sources are updated when at least one of updatesrcs /
    % updatesrcvars is non-negative.  While updatesrcs < 0 the means
    % (e, malpha, msign) are kept and only var, valpha, vsign are taken
    % from the update.  NOTE(review): when updatesrcs >= 0 the whole
    % update is accepted even if updatesrcvars < 0 — confirm intended.
    if max([status.updatesrcs, status.updatesrcvars]) >= 0
      sources = probdist_alpha(sources);
      newsources = ...
          updatesources(sources(:, curbatch), dcp_dsm, dcp_dsv, x{4}, ...
                        params.src, params.noise);

      if status.updatesrcs < 0
        sources = ...
            probdist_alpha(sources.e(:, curbatch), newsources.var, ...
                           sources.malpha(:, curbatch), newsources.valpha, ...
                           sources.msign(:, curbatch), newsources.vsign);
      else
        sources = newsources;
      end
    end

    if status.updatenet >= 0
      net = updatenetwork(net, dcp_dnetm, dcp_dnetv);
    end

  else % new updatealg
    % Joint update of sources and network; also returns the updated CG
    % state and STATUS.
    [sources, net, oldgrads, status] = update_everything(...
	sources, net, dcp_dsm, dcp_dsv, x{4}, params, dcp_dnetm, dcp_dnetv, ...
	data, newkls, status, oldgrads);
  
    % Periodic CG restart every STATUS.cgreset recorded iterations.
    if (status.cgreset > 0) && (mod(length(status.kls), status.cgreset) == 0),
      fprintf('Resetting CG\n');
      oldgrads.net = netgrad_zeros(net);
      oldgrads.s = zeros(size(sources));
      oldgrads.norm = 0;
    end
  end % updatealg
  
  % Advance the countdowns.  Under the CG scheme, CG is also restarted
  % when updatesrcs reaches zero, because source updates switch on then.
  if status.updatesrcs < 0
    status.updatesrcs = status.updatesrcs + 1;
    if (status.updatesrcs == 0) && (~strcmp(status.updatealg, 'old')),
      fprintf('Resetting CG\n');
      oldgrads.net = netgrad_zeros(net);
      oldgrads.s = zeros(size(sources));
      oldgrads.norm = 0;
    end
  end
  if status.updatesrcvars < 0
    status.updatesrcvars = status.updatesrcvars + 1;
  end

  if status.updatenet < 0
    status.updatenet = status.updatenet + 1;
  end

  % Update estimates for different parameters if appropriate
  % While updateparams < 0, only the countdown advances, including the
  % iteration in which it reaches zero.  Under the CG scheme, CG is
  % restarted at that point.  Otherwise the noise, source and w2
  % variances and their hyperparameters are re-estimated.
  if status.updateparams < 0
    status.updateparams = status.updateparams + 1;
    if (status.updateparams == 0) && (~strcmp(status.updatealg, 'old')),
      fprintf('Resetting CG\n');
      oldgrads.net = netgrad_zeros(net);
      oldgrads.s = zeros(size(sources));
      oldgrads.norm = 0;
    end
  else
    % Noise variance is estimated from the reconstruction error fs - data.
    params.noise = estimatevars(probdist(fs.e-data, fs.var), ...
				params.hyper.noise, params.noise);
    params.src   = estimatevars(sources, params.hyper.src, params.src);
    params.net.w2var = estimatevars(net.w2, params.hyper.net.w2var, ...
				    params.net.w2var, 1);
    
    [params.hyper.net.w2var.mean, params.hyper.net.w2var.var] = ...
	estimatemeanvars(params.net.w2var, params.prior.net.w2var.mean, ...
		       params.prior.net.w2var.var, params.hyper.net.w2var.var);
    [params.hyper.noise.mean, params.hyper.noise.var] = ...
	estimatemeanvars(params.noise, params.prior.noise.mean, ...
		       params.prior.noise.var, params.hyper.noise.var, 1);
    [params.hyper.net.b1.mean, params.hyper.net.b1.var] = ...
	estimatemeanvars(net.b1, params.prior.net.b1.mean, ...
		       params.prior.net.b1.var, params.hyper.net.b1.var, 1);
    [params.hyper.net.b2.mean, params.hyper.net.b2.var] = ...
	estimatemeanvars(net.b2, params.prior.net.b2.mean, ...
		       params.prior.net.b2.var, params.hyper.net.b2.var, 1);
    [params.hyper.src.mean, params.hyper.src.var] = ...
	estimatemeanvars(params.src, params.prior.src.mean, ...
		       params.prior.src.var, params.hyper.src.var, 1);
  end

  % The 'old' scheme rescales sources (and network / params) when
  % there is more than one source row.
  if strcmp(status.updatealg, 'old'),
    if (size(sources, 1) > 1),
      [sources, net, params] = ...
	  scalesources(sources, net, params);
    end
  end
  
  iters_left = iters_left - 1;
end

% Keep the CG state so a later call can continue from it.
if ~strcmp(status.updatealg, 'old'),
  status.oldgrads = oldgrads;
end

% Final evaluation after the last update: recompute FS and the cost.
% This cost is printed but NOT appended to STATUS.kls.
fs = probdist(zeros(size(data)), ones(size(data)));

newkls = kl_static(net, params);

% Do feedforward calculations
curbatch = 1:nsampl;
x = feedfw( sources(:, curbatch) , net, status.approximation);
fs(:, curbatch) = probdist(x{4}.e, x{4}.var);

newkls = newkls + kl_batch(fs(:, curbatch), sources(:, curbatch), ...
			   data(:, curbatch), params);

fprintf('Finally after %d iterations: %f\n', size(status.kls, 2), newkls);


function [dc_dsm, dc_dsv] = feedback_srcpriors(sources, srcparams)
% FEEDBACK_SRCPRIORS Calculate the contribution of source priors
%   to the gradients of the cost function with respect to source values
%
%   The prior variances NORMALVAR(SRCPARAMS) form a column, with one
%   entry per source row.  The column is copied across every sample
%   column of SOURCES by an outer product with a row of ones.  The mean
%   gradient is SOURCES.e divided elementwise by these variances, and the
%   variance gradient is one half of their reciprocal.

% Prior variance expanded to the size of the source matrix.
priorvar = normalvar(srcparams) * ones(1, size(sources, 2));

dc_dsv = .5 ./ priorvar;
dc_dsm = sources.e ./ priorvar;


function [dc_dnetm, dc_dnetv] = feedback_netpriors(net, params, hypers)
% FEEDBACK_NETPRIORS Calculate the contribution of network priors
%   to the gradients of the cost function with respect to network weights
%
%   The second-layer weights (w2) use the prior variances
%   NORMALVAR(params.w2var).  The first-layer weights (w1) use a fixed
%   unit prior variance, with one entry per column of net.w1.  Bias
%   priors come from HYPERS.b2 and HYPERS.b1.  The output structures get
%   their fields in the order w2, b2, w1, b1, the same order
%   NETGRAD_ZEROS uses.

% Second layer.
[dc_dnetm.w2, dc_dnetv.w2, dc_dnetm.b2, dc_dnetv.b2] = ...
    netgradsprior(net.w2, net.b2, normalvar(params.w2var), hypers.b2);

% First layer: unit prior variance for every weight column.
[dc_dnetm.w1, dc_dnetv.w1, dc_dnetm.b1, dc_dnetv.b1] = ...
    netgradsprior(net.w1, net.b1, ones(1, size(net.w1, 2)), hypers.b1);


function [dcp_dwm, dcp_dwv, dcp_dbm, dcp_dbv] = ...
    netgradsprior(w, b, wprior, bprior)
% NETGRADSPRIOR Calculate the contribution of priors to partial
%   derivatives of kldiv with respect to network weights
%
%   WPRIOR is copied down the rows of W, so it presumably holds one prior
%   variance per column of W (confirm against callers).  For the biases,
%   BPRIOR.mean.e is the prior mean and NORMALVAR(BPRIOR.var) the prior
%   variance; both are tiled with REPMAT(..., size(B)).  The weight mean
%   gradient has no prior-mean offset, whereas the bias gradient
%   subtracts the prior mean.

% Weights.
wvar_full = repmat(wprior, [size(w, 1) 1]);
dcp_dwv = .5 ./ wvar_full;
dcp_dwm = w.e ./ wvar_full;

% Biases.
bvar_full = repmat(normalvar(bprior.var), size(b));
dcp_dbv = .5 ./ bvar_full;
dcp_dbm = (b.e - repmat(bprior.mean.e, size(b))) ./ bvar_full;


function grad = netgrad_zeros(net)
% NETGRAD_ZEROS  All-zero gradient structure shaped like NET
%   GRAD has the fields w2, b2, w1, b1 (created in that order).  Each
%   field is a zero array the size of the matching field of NET.

grad = struct('w2', zeros(size(net.w2)), ...
              'b2', zeros(size(net.b2)), ...
              'w1', zeros(size(net.w1)), ...
              'b1', zeros(size(net.b1)));


function s = sum_structs(s1, s2)
% SUM_STRUCTS  Add all the fields of two structures together
%   S = SUM_STRUCTS(S1, S2) returns a structure with the fields of S1,
%   in S1's order, where each field holds S1.(name) + S2.(name).
%   Fields are paired by NAME, so S1 and S2 may have created their
%   fields in different orders.  An error is raised unless S1 and S2
%   have exactly the same set of field names.

f  = fieldnames(s1);
f2 = fieldnames(s2);

% Compare the field-name sets explicitly.  A test of the form
% "if size(c1) ~= size(c2)" is only true when EVERY dimension differs,
% so it lets structures with different numbers of fields through.
if numel(f) ~= numel(f2) || ~all(ismember(f, f2))
  error('sum_structs: Structures must be of same type')
end

c = cell(size(f));
for k = 1:numel(f),
  % Look the field up by name in S2 rather than by position: the
  % STRUCT2CELL order depends on the order the fields were created in.
  c{k} = s1.(f{k}) + s2.(f{k});
end

s = cell2struct(c, f, 1);
