function [sources, net, tnet, params, status, missing, clamped, notimedep, fs, tfs] = ndfa(data, varargin),
% NDFA  Run nonlinear dynamical factor analysis
%
%  [sources, net, tnet, params, status, missing, clamped, notimedep, x_reco, s_pred] = NDFA(data, otherargs...)
%  runs the NDFA algorithm for given data (different samples in
%  different columns, different channels in different rows. Use
%  NaNs for missing values) and returns the resulting sources, 
%  net, tnet, parameters, status information and reconstructed 
%  observations.  The last two outputs (x_reco and s_pred) are
%  passed through from the underlying NDFA_ITER call.
%
%  If only one return value is requested, the values of sources,
%  clamped, notimedep, net, tnet, params and status are packed into
%  a single structure (with fields of the same names) that can be
%  further passed as argument to the algorithm.
%
%  The additional arguments can be given as a structure with
%  the field name indicating the name of the argument, or as a
%  list as in NDFA(data, 'name1', val1, 'name2', val2, ...),
%  or as a combination of these as in
%  NDFA(data, struct, 'name1', val, 'name2', val2, ...).
%  In the latter case the values given in the list take precedence
%  over the ones specified in the structure.
%
% --------------------------------------------------------------------
%  ACCEPTED ARGUMENTS
%
%  The recognised arguments are as follows:
%
%  'sources'            The values of sources returned by a previous
%                       run of the algorithm.  This option can be used
%                       to continue a previous simulation.
%  'initsources'        Initial values for the sources to be used.
%  'searchsources'      Number of sources to use, to be initialised
%                       with PCA.
%  'initclamped'        Initial values for clamped sources (use
%                       NaNs for missing values). This option can
%                       be used to supply control signals or
%                       otherwise known sources. The clamped
%                       sources are concatenated with the normal
%                       sources specified by 'sources',
%                       'initsources' or 'searchsources' argument.
%
%  'net'                The values of observation network weights
%                       returned by a previous run of the algorithm.
%                       This option can be used to continue a previous
%                       simulation.
%  'tnet'               The values of temporal network weights
%                       returned by a previous run of the algorithm.
%                       This option can be used to continue a previous
%                       simulation.
%  'hidneurons'         Number of hidden neurons to use in the
%                       observation MLP.
%  'thidneurons'        Number of hidden neurons to use in the
%                       temporal MLP.
%  'nonlin'             Nonlinear activation function to use in the
%                       MLPs.  Default value 'tanh' is currently the
%                       only supported value.
%  'observationmapping' If true, the observation mapping is flagged as
%                       an identity mapping and the sources, their
%                       variances and the observation network are not
%                       updated.  Default false.
%
%  'params'             Structure of hyperparameter values returned by
%                       a previous run of the algorithm.  This option
%                       can be used to continue a previous simulation.
%
%  'status'             Structure of status information returned by a
%                       previous run of the algorithm.  This option
%                       can be used to continue a previous simulation.
%
%  'clamped'            The information on clamped sources returned by a
%                       previous run of the algorithm.  This option can be 
%                       used to continue a previous simulation.
%  'notimedep'          Sources which should have no time dependence with
%                       their successor. Can be used to partition the data
%                       into multiple parts with no time
%                       dependencies between.
%
%  'iters'              Number of iterations to run the algorithm.
%                       The default value is 300.
%                       Use value 0 to only evaluate the reconstructed
%                       observations and cost function value.
%  'runtime'            Maximum runtime (in seconds, default Inf).
%                       If set to a finite value, overrides the
%                       iteration count above.
%  'embed'              Number of iterations during which embedded 
%                       sources are used.
%  'prune'              When continuing a previous simulation, sets all
%                       pruning settings (sources, hidneurons,
%                       thidneurons and iters) to the given value.
%  'freeinitial'        Should the sources at first time instant have
%                       a zero mean prior (value 0) or a hierarchical
%                       prior automatically adapted to the correct
%                       value (value 1, default).  Can also be changed
%                       when continuing a previous simulation.
%  'approximation'      Nonlinearity approximation scheme to use.
%                       Supported values are 'hermite' (default) and
%                       'taylor'.
%  'updatealg'          Update algorithm.  Supported values are
%                       'ncg' (default, natural conjugate gradient),
%                       'ng', 'cg', 'old', 'grad', 'natgraddiag', and
%                       'natgradfull'.
%  'verbose'            Print out a summary of all simulation
%                       parameters at the beginning of simulation.
%                       The default value is 1 (yes), other possible
%                       values are 0 (no) and 2 (always).  With 1
%                       (yes) the value is reset to 0 in the returned
%                       structure so further calls will not reprint
%                       the summary.
%  'debug'              Print out additional debug information for
%                       each iteration. Possible values are 0 (no
%                       debug information), 1 (normal information) 
%                       and 2 (extended information).
%  'epsilon'            Epsilon for stopping criterion.  The default
%                       value is 1e-6.
%  'nolearning'         Prevents updates to the observation and
%                       temporal networks; the sources and the
%                       parameters are updated from the first
%                       iteration using the preexisting mappings.
%                       Can be used to infer the sources with given
%                       observations and mappings.
%
% --------------------------------------------------------------------
%  RELATIONS BETWEEN ARGUMENTS
%
%  Exactly one of the arguments 'sources', 'initsources' and
%  'searchsources' must be set.
%
%  Either 'net' or 'hidneurons' must be set.  If 'net' is set,
%  'hidneurons' is not honoured.  Similarly either 'tnet' or
%  'thidneurons' must be set.  If 'tnet' is set, 'thidneurons' is not
%  honoured.
%
%
% --------------------------------------------------------------------
%  STOPPING CRITERIA
%
%  The iteration is stopped if the designated number of iterations is
%  reached, or the total number of iterations is greater than 400 and
%  a) the cost function increases for 10 iterations in a row; or
%  b) the cost function decreases by less than epsilon (see above)
%  on each iteration for 200 iterations in a row.
%
%
% --------------------------------------------------------------------
%  EXAMPLES
%
%  [sources, net, tnet, params, status] = ...
%      NDFA(data, 'searchsources', 5, 'hidneurons', 30, 'thidneurons', 20);
%            Extract 5 nonlinear factors from data using an
%            observation MLP with 30 hidden neurons and temporal MLP
%            with 20 hidden neurons, with the sources initialised
%            using PCA.
%  
%  result = NDFA(data, 'initsources', my_s, 'hidneurons', 30, 'thidneurons', 20, 'iters', 50);
%            Extract nonlinear factors from data using an observation
%            MLP with 30 hidden neurons and temporal MLP with 20
%            hidden neurons and custom initialisation given by my_s
%            using 50 iterations of the algorithm.
%
%  result = NDFA(data, result, 'iters', 500);
%            Continue the previous simulation for 500 more iterations.
%
%  result = NDFA(data, 'searchsources', 6, 'initclamped', my_u, 'hidneurons', 30, 'thidneurons', 20);
%            Extract 6 nonlinear factors from data with clamped sources
%            my_u using an observation MLP with 30 hidden neurons and
%            temporal MLP with 20 hidden neurons.
%
%  result = NDFA(data, 'searchsources', 4, 'hidneurons', 30, 'thidneurons', 20, 'notimedep', 500);
%            Extract 4 nonlinear factors from two part data with the 
%            first part containing 500 samples using an observation 
%            MLP with 30 hidden neurons and a temporal MLP with 20 
%            hidden neurons.

% Copyright (C) 2002-2005 Harri Valpola, Antti Honkela and Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Read the arguments.  Every argument that is used is removed from
% args (getargdef/rmfield) so that leftovers can be reported below.
args = readargs(varargin);

% Status
if ~isfield(args, 'status'),
  [status.iters, args] = getargdef(args, 'iters', 300);
  [status.runtime, args] = getargdef(args, 'runtime', Inf);
  if isfinite(status.runtime),
    status.iters = Inf;
  end
  status.prune.hidneurons = 0;
  status.prune.thidneurons = 0;
  status.prune.sources = 0;
  status.prune.iters = 0;
  status.history = {};
  [status.embed.iters, args] = getargdef(args, 'embed', 0);
  [status.approximation, args] = ...
      getargdef(args, 'approximation', 'hermite');
  [status.updatealg, args] = getargdef(args, 'updatealg', 'ncg');
  [status.epsilon, args] = getargdef(args, 'epsilon', 1e-6);
  [status.verbose, args] = getargdef(args, 'verbose', 1);
  [status.debug, args] = getargdef(args, 'debug', 0);
  [status.freeinitial, args] = getargdef(args, 'freeinitial', 1);
  
  status.kls = [];
  status.cputime = [cputime];
  status.version = 1.1;
  
  % How many iterations to wait before starting to update these values
  status.updatenet = 0;
  status.updatetnet = 0;
  if strcmp(status.updatealg, 'old'),
    [status.updatesrcs, args] = getargdef(args, 'updatesrcs', -50);
    [status.updatesrcvars, args] = getargdef(args, 'updatesrcvars', -50);
    [status.updateparams, args] = getargdef(args, 'updateparams', -100);
  else
    [status.updatesrcs, args] = getargdef(args, 'updatesrcs', -20);
    [status.updatesrcvars, args] = getargdef(args, 'updatesrcvars', -20);
    [status.updateparams, args] = getargdef(args, 'updateparams', -50);
  end
  status.t0 = 1;
else
  status = args.status;
  args = rmfield(args, 'status');
  % Support for obsolete models
  if ~isfield(status, 'version') || ...
     ~isnumeric(status.version) || ...
     (status.version < 1.1),
    status = convert_obsolote(status);
  end
  
  [status.iters, args] = getargdef(args, 'iters', status.iters);
  [status.runtime, args] = getargdef(args, 'runtime', status.runtime);
  if isfinite(status.runtime),
    status.iters = Inf;
  end
  [status.embed.iters, args] = ...
      getargdef(args, 'embed', status.embed.iters);
  % A single 'prune' value applies to all pruning settings.  Read it
  % only once: the argument is consumed on first use, so repeated
  % lookups would leave all but the first setting unchanged.
  if isfield(args, 'prune'),
    status.prune.sources = args.prune;
    status.prune.hidneurons = args.prune;
    status.prune.thidneurons = args.prune;
    status.prune.iters = args.prune;
    args = rmfield(args, 'prune');
  end
  [status.approximation, args] = ...
      getargdef(args, 'approximation', status.approximation);
  [status.updatealg, args] = ...
      getargdef(args, 'updatealg', status.updatealg);
  [status.epsilon, args] = getargdef(args, 'epsilon', status.epsilon);
  [status.verbose, args] = getargdef(args, 'verbose', status.verbose);
  [status.debug, args] = getargdef(args, 'debug', status.debug);
  % Allow changing the initial-source prior when continuing
  if isfield(args, 'freeinitial'),
    status.freeinitial = args.freeinitial;
    args = rmfield(args, 'freeinitial');
  end
end

% Map legacy algorithm names to their current equivalents
switch status.updatealg,
 case 'natgrad'
  status.updatealg = 'ncg';
 case 'natgraddiag'
  status.updatealg = 'ncgdiag';
 case 'natgradfull'
  status.updatealg = 'ncgfull';
 case 'conjgrad'
  status.updatealg = 'cg';
end

if status.verbose,
  if isempty(status.kls),
    fprintf('Starting an NDFA simulation with parameters:\n');
  else
    fprintf('Continuing an NDFA simulation with parameters:\n');
    fprintf('Number of iterations so far: %d\n', length(status.kls));
  end
  fprintf('Number of signals: %d\n', size(data, 1));
  fprintf('Number of samples: %d\n', size(data, 2));
  switch status.approximation,
   case 'taylor'
    fprintf('Using Taylor approximation for the nonlinearity.\n');
   case 'hermite'
    fprintf('Using Gauss-Hermite approximation for the nonlinearity.\n');
   otherwise
    fprintf('Using an unknown approximation for the nonlinearity.\n');
  end
  switch status.updatealg,
   case 'old'
    fprintf('Using the old heuristic NDFA update algorithm.\n');
   case 'cg'
    fprintf('Using the conjugate gradient update algorithm.\n');
   case 'ncg'
    fprintf('Using the natural conjugate gradient update algorithm for means.\n');
   case 'ncgdiag'
    fprintf('Using approximate natural conjugate gradient update algorithm with diagonal covariance for means.\n');
   case 'ng'
    fprintf('Using the natural gradient update algorithm for means.\n');
   case 'ncgfull'
    fprintf('Using the natural conjugate gradient update algorithm.\n');
   case 'grad'
    fprintf('Using a gradient update algorithm with line searches.\n');
   otherwise
    fprintf('Using an unknown update algorithm.\n');
  end
end

% Initialise the sources
if ~isfield(args, 'sources'),
  if isfield(args, 'initsources'),
    if status.verbose,
      fprintf('Using %d pre-initialised sources.\n', ...
              size(args.initsources, 1));
    end
    initsources = args.initsources;
    sources = acprobdist_alpha(...
        initsources, .0001 * ones(size(initsources)));
    args = rmfield(args, 'initsources');
  else
    if ~isfield(args, 'searchsources')
      error('Either sources, initsources or searchsources must be set!')
    end
    
    if status.verbose,
      fprintf('Using %d sources initialised with PCA.\n', ...
              args.searchsources);
    end
    [data_em, s_mean] = embed(data, [], args.searchsources);
    zs = ones(size(s_mean));
    sources = acprobdist_alpha(s_mean, .0001*zs);
    % A negative embed.iters marks that the data is currently embedded
    status.embed.iters = -abs(status.embed.iters);
    if status.embed.iters,
      if status.verbose,
        fprintf('Embedding data.\n');
      end
      status.embed.datadim = size(data, 1);
      status.embed.timedim = size(data, 2);
      data = data_em;
    end
    args = rmfield(args, 'searchsources');
  end
else
  if status.verbose,
    fprintf('Using %d previously used sources.\n', size(args.sources, 1));
  end
  sources = args.sources;
  if status.embed.iters < 0,
    if status.verbose,
      fprintf('Embedding data to continue a previous simulation.\n');
    end
    data = embed(data, [], size(sources, 1));
  end
  args = rmfield(args, 'sources');
end

% Modify sources to data length if necessary
[Nx T] = size(data);
[Ns Ts] = size(sources);

if Ts > T,
  if status.verbose,
    fprintf('Trimming sources to data length.\n');
  end
  sources = sources(:,1:T);
elseif Ts < T,
  if status.verbose,
    fprintf('Padding sources to data length.\n');
  end
  % Pad by repeating the last source mean; the variance matrix must have
  % the same Ns x (T - Ts) shape as the padded means.
  ns = ones(Ns, T - Ts);
  newsources = acprobdist_alpha( ...
     repmat(sources(:,end).e, 1, T - Ts), ns * .0001);
  sources = [sources newsources];
end

% Initialise the clamped sources
if isfield(args, 'initclamped'),
  control = args.initclamped;
  [Nu Tu] = size(control);
  % Remove the control channels appended by a previous run.  Their count
  % is stored in status.controlchannels and may differ from the new Nu
  % (it is 0 when the previous run had no controls).
  if isfield(status, 'controlchannels') && status.controlchannels > 0,
    Nold = status.controlchannels;
    sources = sources(1:(Ns - Nold),:);
    Ns = Ns - Nold;
  end

  controls = Ns + (1:Nu);

  % Check that data and control lengths match
  if T ~= Tu,
    if status.embed.iters,
      control = [control nan*ones(Nu, T - Tu)];
      Tu = T;
    else
      error('Data and control vectors must have equal number of columns!');
    end
  end

  % Force control into acprobdist_alpha
  if ~isa(control, 'acprobdist_alpha'),
    control = acprobdist_alpha(control, .0001 * ones(Nu, T));
  end

  % Find missing control values and set them to zeros
  missingcontrol = sparse(isnan(control.e));
  clamped = sparse([false(Ns, T); ~missingcontrol]);
  control.e(find(missingcontrol)) = 0;

  % Add control variables to sources
  sources = [sources; control];
  status.controlchannels = Nu;
  Ns = Ns + Nu;
  args = rmfield(args, 'initclamped');
  if isfield(args, 'clamped'),
    args = rmfield(args, 'clamped');
  end
else
  [clamped, args] = getargdef(args, 'clamped', sparse(false(Ns, T)));
  controls = [];
  if ~isfield(status, 'controlchannels'),
    status.controlchannels = 0;
  end
end

% Modify info on clamped sources to data length (pad by repeating the
% last column)
Tms = size(clamped, 2);
if Tms > T,
  clamped = clamped(1:Ns, 1:T);
elseif Tms < T,
  clamped = [clamped repmat(clamped(:,end), 1, T - Tms)];
end

% Initialise ignored time dependencies
if isfield(args, 'notimedep'),
  if ~isempty(args.notimedep),
    % If the number of columns doesn't match sample count, use matrix as
    % indices
    if (size(args.notimedep, 2) ~= size(sources, 2) - 1) && T > 1,
      notimedep = sparse(zeros(1, T - 1));
      notimedep(1,args.notimedep) = 1;
    else
      notimedep = args.notimedep;
    end
    % If only one row is provided, use it for all channels    
    if size(notimedep, 1) == 1,
      notimedep = repmat(notimedep, Ns, 1);
    end
  else
    notimedep = [];
  end
  args = rmfield(args, 'notimedep');
else
  notimedep = [];
end

% Initialise missing data values
missing = sparse(isnan(data));
data(find(missing)) = 0;

% Initialise constraints for data and control
if isfield(args, 'constraints'),
  if ~isempty(args.constraints),
    status.constraints = args.constraints;
  end
  args = rmfield(args, 'constraints');
end

% Initialise the observation network
if ~isfield(args, 'net'),
  if ~isfield(args, 'hidneurons'),
    error('Either net or hidneurons must be set!')
  end
  if status.verbose,
    fprintf('Initialising a new observation MLP network with %d hidden neurons.\n', args.hidneurons);
  end
  net = createnet_alpha(Ns, args.hidneurons, Nx, ...
                        'tanh', 1, 1, 1, 1, .01, .01);

  % Initialise the output bias to the mean of the observed values.
  % Missing entries were zeroed above and must not be counted; rows
  % without any observations get a zero mean.
  nobs = T - full(sum(missing, 2));
  net.b2.e = net.b2.e + sum(data, 2) ./ max(nobs, 1);
  args = rmfield(args, 'hidneurons');
  [net.identity, args] = getargdef(args, 'observationmapping', false);
else
  if status.verbose,
    fprintf('Using a previously used observation MLP network with %d hidden neurons.\n', size(args.net.w1, 1));
  end
  net = args.net;

  args = rmfield(args, 'net');

  % Special observation network flags
  net.identity = getargdef(net, 'identity', false);
  [net.identity, args] = getargdef(args, 'observationmapping', net.identity);
end

% If identity mapping is used for observation mapping, we should
% not waste time updating sources or the mapping
if net.identity,
  status.updatesrcs = -inf;
  status.updatesrcvars = -inf;
  status.updatenet = -inf;
end

% Initialise the temporal network
if ~isfield(args, 'tnet'),
  if ~isfield(args, 'thidneurons'),
    error('Either tnet or thidneurons must be set!')
  end
  if status.verbose,
    fprintf('Initialising a new temporal MLP network with %d hidden neurons.\n', args.thidneurons);
  end
  tnet = createnet_alpha(size(sources, 1), args.thidneurons, size(sources, 1), ...
                         'tanh', 1, 1, .1, .01, .01, .01);
  args = rmfield(args, 'thidneurons');
else
  if status.verbose,
    fprintf('Using a previously used temporal MLP network with %d hidden neurons.\n', size(args.tnet.w1, 1));
  end
  tnet = args.tnet;
  args = rmfield(args, 'tnet');
end

% Initialise parameters and hyperparameters
if ~isfield(args, 'params'),
  params.net.w2var = probdist(zeros(1, size(net.b1, 1)), ...
                              .5 / size(data, 1) * ones(1, size(net.b1, 1)));
  params.tnet.w2var = probdist(zeros(1, size(tnet.b1, 1)), ...
                               .5 / size(data, 1) * ones(1, size(tnet.b1, 1)));
  params.noise = probdist(.5 * log(.1) * ones(size(data, 1), 1), ...
                          .5 / size(data, 2) * ones(size(data, 1), 1));
  params.src = probdist(zeros(size(sources, 1), 1), ...
                        .5 / size(data, 2) * ones(size(sources, 1), 1));

  params.hyper.net.w2var = nlfa_inithyper(0, .1, 0, .1);
  params.hyper.net.b1 = nlfa_inithyper(0, .1, 0, .1);
  params.hyper.net.b2 = nlfa_inithyper(0, .1, 0, .1);
  params.hyper.tnet.w2var = nlfa_inithyper(0, .1, 0, .1);
  params.hyper.tnet.b1 = nlfa_inithyper(0, .1, 0, .1);
  params.hyper.tnet.b2 = nlfa_inithyper(0, .1, 0, .1);

  params.hyper.noise = nlfa_inithyper(0, .1, 0, .1);
  params.hyper.src = nlfa_inithyper(0, .1, 0, .1);

  params.prior.net.w2var = nlfa_initprior(0, log(100), 0, log(100));
  params.prior.net.b1 = nlfa_initprior(0, log(100), 0, log(100));
  params.prior.net.b2 = nlfa_initprior(0, log(100), 0, log(100));
  params.prior.tnet.w2var = nlfa_initprior(0, log(100), 0, log(100));
  params.prior.tnet.b1 = nlfa_initprior(0, log(100), 0, log(100));
  params.prior.tnet.b2 = nlfa_initprior(0, log(100), 0, log(100));
  params.prior.noise = nlfa_initprior(0, log(100), 0, log(100));
  params.prior.src = nlfa_initprior(0, log(100), 0, log(100));
else
  params = args.params;
  args = rmfield(args, 'params');
end

% No learning is done: the networks are never updated (-inf), while the
% sources, their variances and the parameters are updated from the
% beginning (0 iterations of waiting)
if isfield(args, 'nolearning'),
  status.updatenet = -inf;
  status.updatetnet = -inf;
  status.updateparams = 0;
  status.updatesrcs = 0;
  status.updatesrcvars = 0;
  args = rmfield(args, 'nolearning');
end

% Check that all the parameters were valid
otherargs = fieldnames(args);
if ~isempty(otherargs),
  fprintf('Warning: ndfa: unused arguments:\n');
  fprintf(' %s\n', otherargs{:});
end

% verbose == 1 prints the summary only once; 2 keeps printing it
if status.verbose == 1,
  status.verbose = 0;
end

status.cgreset = 0;

% Do the actual iteration
[sources, net, tnet, params, status, fs, tfs] = ...
ndfa_iter(data, sources, net, tnet, params, status, missing, clamped, notimedep);

% If only one return value is expected, pack everything to it
if nargout == 1,
  val.sources = sources;
  val.clamped = clamped;
  val.notimedep = notimedep;
  val.net = net;
  val.tnet = tnet;
  val.params = params;
  val.status = status;
  sources = val;
end


function hyper = nlfa_inithyper(mm, mv, vm, vv)
% NLFA_INITHYPER  Create a hyperparameter structure
%
%  hyper = NLFA_INITHYPER(mm, mv, vm, vv) returns a structure with
%  two fields: 'mean', set to PROBDIST(mm, mv), and 'var', set to
%  PROBDIST(vm, vv).
%
% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

meandist = probdist(mm, mv);
vardist = probdist(vm, vv);

hyper = struct();
hyper.mean = meandist;
hyper.var = vardist;


function prior = nlfa_initprior(mm, mv, vm, vv)
% NLFA_INITPRIOR  Create a fixed prior structure for a hyperparameter
%
%  prior = NLFA_INITPRIOR(mm, mv, vm, vv) returns a structure with
%  fields 'mean' and 'var', each being a structure with fields 'mean'
%  and 'var'.  The four leaves hold PROBDIST objects with values
%  mm, mv, vm and vv respectively, each with variance 0.
%
% Copyright (C) 2002 Harri Valpola and Antti Honkela.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

novar = 0;

meanpart = struct();
meanpart.mean = probdist(mm, novar);
meanpart.var = probdist(mv, novar);

varpart = struct();
varpart.mean = probdist(vm, novar);
varpart.var = probdist(vv, novar);

prior = struct();
prior.mean = meanpart;
prior.var = varpart;


function status = convert_obsolote(status)
% CONVERT_OBSOLOTE  Convert a structure returned by an older version
%
%  status = CONVERT_OBSOLOTE(status)
%
%  Convert a result structure of an older version of NDFA to the
%  current format by filling in any missing status fields with
%  their default values.  Fields that already exist are kept.
%
% Copyright (C) 2004-2006 Matti Tornio.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

status.controlchannels = getargdef(status, 'controlchannels', 0);

% Embedding settings
status.embed = getargdef(status, 'embed', struct);
status.embed.iters = getargdef(status.embed, 'iters', 0);

% Pruning settings, all defaulting to zero
status.prune = getargdef(status, 'prune', struct);
prunenames = {'sources', 'hidneurons', 'thidneurons', 'iters'};
for k = 1:length(prunenames),
  pname = prunenames{k};
  status.prune.(pname) = getargdef(status.prune, pname, 0);
end

% Remaining top-level settings as {name, default} pairs
topdefaults = {'approximation', 'hermite'; ...
               'updatealg',     'conjgrad'; ...
               'epsilon',       1e-6; ...
               'verbose',       1; ...
               'debug',         0; ...
               't0',            1; ...
               'history',       {}};
for k = 1:size(topdefaults, 1),
  fname = topdefaults{k, 1};
  status.(fname) = getargdef(status, fname, topdefaults{k, 2});
end
