function cost = evalcost(data, varargin),
% EVALCOST  Evaluate the cost of an NLFA model
%
%  cost = EVALCOST(data, otherargs...)
%  returns the cost function value for the given NLFA model.
%  The supported arguments are a subset of those accepted by NLFA,
%  those that are relevant to the problem at hand.
%
%  The arguments may be given as keyword/value pairs, as a single
%  structure, or as a structure followed by keyword/value pairs that
%  override its fields.  'sources', 'net' and 'params' are required.
%  'approximation' defaults to 'hermite' unless a 'status' structure
%  is given, in which case status.approximation is used unless an
%  explicit 'approximation' argument overrides it.
%
%  One additional value for 'approximation' is supported, namely
%  'mc' which uses Monte Carlo sampling to evaluate the cost.
%
%  Any other arguments are listed in a warning and otherwise ignored.
%
%  See also NLFA.

% Copyright (C) 1999-2004 Antti Honkela, Harri Valpola,
% and Xavier Giannakopoulos.
%
% This package comes with ABSOLUTELY NO WARRANTY; for details
% see License.txt in the program package.  This is free software,
% and you are welcome to redistribute it under certain conditions;
% see License.txt for details.

% Read the arguments.  An optional leading structure supplies the
% initial values; any keyword/value pairs after it override them.
% Nested ifs are used instead of relying on '&' short-circuiting,
% so varargin{1} is never touched when varargin is empty.
nargs = length(varargin);
firstkey = 1;
args = struct();
if nargs >= 1,
  if isstruct(varargin{1}),
    args = varargin{1};
    firstkey = 2;
  end
end

if mod(nargs - firstkey + 1, 2) ~= 0,
  error('Keyword arguments should appear in pairs');
end

for k = firstkey:2:nargs,
  if ~ischar(varargin{k}),
    error('Keyword argument names must be strings.');
  end
  % setfield rather than eval: the caller-supplied name is never
  % executed as code.  Unlike struct(varargin{:}), this also stores
  % cell-valued arguments as-is instead of expanding them into a
  % struct array.
  args = setfield(args, varargin{k}, varargin{k+1});
end

% Approximation for the feedforward pass: an explicit 'approximation'
% argument wins, then status.approximation, then 'hermite'.
if ~isfield(args, 'status'),
  [status.approximation, args] = getargdef(args, 'approximation', 'hermite');
else
  status = args.status;
  args = rmfield(args, 'status');
  if isfield(args, 'approximation'),
    status.approximation = args.approximation;
    args = rmfield(args, 'approximation');
  end
end

% Required arguments: each one is consumed from args so that whatever
% remains afterwards can be reported as unused.
if ~isfield(args, 'sources'),
  error('Sources must be set!');
else
  sources = args.sources;
  args = rmfield(args, 'sources');
end

if ~isfield(args, 'net'),
  error('Net must be set!');
else
  net = args.net;
  args = rmfield(args, 'net');
end

if ~isfield(args, 'params'),
  error('Params must be set!');
else
  params = args.params;
  args = rmfield(args, 'params');
end

otherargs = fieldnames(args);
if length(otherargs) > 0,
  fprintf('Warning: evalcost: unused arguments:\n');
  fprintf(' %s\n', otherargs{:});
end

% Do the actual thing...
cost = ...
    really_eval_cost(data, sources, net, params, status);



function [val, args] = getargdef(args, name, default),
% GETARGDEF  Extract an optional argument from an argument structure.
%
%  [val, args] = GETARGDEF(args, name, default) returns the value of
%  field NAME of ARGS and removes that field from ARGS when it is
%  present; otherwise it returns DEFAULT and leaves ARGS unchanged.

if isfield(args, name),
  % getfield instead of eval(sprintf(...)): no string is evaluated
  % as code, and it works on MATLAB versions without dynamic fields.
  val = getfield(args, name);
  args = rmfield(args, name);
else
  val = default;
end





function cost = really_eval_cost(data, sources, net, params, status)
% REALLY_EVAL_COST  Compute the cost once the arguments have been parsed.
%
%  cost = REALLY_EVAL_COST(data, sources, net, params, status) runs the
%  feedforward pass and returns
%    KL_STATIC(net, params) + KL_BATCH(fs, sources, data, params)
%  where fs is a PROBDIST built from the feedforward mean and variance.
%  When status.approximation is 'mc', MC_FEEDFW (Monte Carlo sampling)
%  is used; otherwise FEEDFW is called with status.approximation.

if strcmp(status.approximation, 'mc'),
  x = mc_feedfw(sources, net);
  fs = probdist(x.e, x.var);
else
  x = feedfw(sources, net, status.approximation);
  % NOTE(review): presumably x{4} is the network output layer
  % returned by FEEDFW -- confirm against feedfw.m.
  fs = probdist(x{4}.e, x{4}.var);
end

cost = kl_static(net, params) + kl_batch(fs, sources, data, params);



function x = mc_feedfw(s, net, npoints)
% MC_FEEDFW  Monte Carlo feedfw
%
%  x = MC_FEEDFW(s, net) draws 400 samples of the sources from
%  independent Gaussians with mean s.e and variance s.var and passes
%  each sample through MLPFW_MC.  It returns the sample mean of the
%  outputs in x.e and their sample variance in x.var.  The variance
%  is the biased 1/npoints estimate, i.e. the mean squared deviation
%  from x.e.
%
%  x = MC_FEEDFW(s, net, npoints) uses npoints samples instead of 400.

if nargin < 3,
  npoints = 400;
end

% Sample slice k of sp is one joint draw of all sources.
sp = repmat(s.e, [1, 1, npoints]) + ...
     repmat(sqrt(s.var), [1, 1, npoints]) .* ...
     randn(size(s.e, 1), size(s.e, 2), npoints);

xp = mlpfw_mc(sp, net);

x.e = mean(xp, 3);
x.var = mean((xp - repmat(x.e, [1, 1, npoints])).^2, 3);


function x = mlpfw_mc(s, net),
% MLPFW_MC  Sample through an MLP
%
%  x = MLPFW_MC(s, net) passes every slice s(:, :, k) through a
%  two-layer MLP.  For each slice, fresh weights and biases are drawn
%  from independent Gaussians with means net.w1.e, net.b1.e, net.w2.e
%  and net.b2.e and variances net.w1.var, net.b1.var, net.w2.var and
%  net.b2.var.  The hidden layer applies net.nonlin via FEVAL.
%  x(:, :, k) holds the output computed for slice k.

nslices = size(s, 3);
ncols = size(s, 2);
ones_row = ones(1, ncols);
x = zeros(size(net.w2.e, 1), ncols, nslices);

for k = 1:nslices,
  % Random draws happen in the order W1, b1, W2, b2 for every slice.
  W1 = net.w1.e + randn(size(net.w1.e)) .* sqrt(net.w1.var);
  B1 = net.b1.e + randn(size(net.b1.e)) .* sqrt(net.b1.var);
  hidden = feval(net.nonlin, W1 * s(:, :, k) + B1 * ones_row);
  W2 = net.w2.e + randn(size(net.w2.e)) .* sqrt(net.w2.var);
  B2 = net.b2.e + randn(size(net.b2.e)) .* sqrt(net.b2.var);
  x(:, :, k) = W2 * hidden + B2 * ones_row;
end
