function OutStruct = netresponse(datamatrix, network, varargin)

%  Global modeling of network responses.
%
%  res = netresponse(datamatrix, network, ['arg1', val1, 'arg2', val2, ...])
%  res = netresponse(datamatrix, network, argstruct, ['arg1', val1, ...])
%
%  INPUT:
%
%  datamatrix: 'samples x genes' matrix
%  network:    binary network matrix; network(a,b) is nonzero when the
%              variables (columns of datamatrix) a and b are linked.
%              It is indexed with the column indices of datamatrix.
%
%  This function returns a set of subnetworks and their estimated
%  context-specific co-expression patterns.
%
%  res.cost      total cost (BIC) at each state; the first entry is the
%                cost of the fully independent model
%  res.moves     has the indices of groups joined at each state in its
%                columns (the first column is [0; 0])
%  res.groupings holds the groupings at each level of the hierarchy
%  res.models    has the models from each step: first one model per single
%                variable, then the joint model of each merge
%  res.Clist     total cost after each merge (temporary/debug output)
%  res.BIC       diagonal of the BIC table of the remaining groups
%  res.Nparams   table of parameter counts of the remaining groups
%  res.H         model costs (free energies) of the remaining groups
%
%  The following optional keyword arguments are accepted:
%    'impl_noise' is the variance of the implicit noise (default: 0)
%    'update_hyperparams'
%                 specifies if the hyperparameters are updated in
%                 DP mixture modeller                   (default: no = 0)
%    'maxiters'   is the maximum number of iterations in classic mixture model
%                 training                              (default: 1000)
%    'initcomps'  is the initial number of mixture components for
%                 new classic models                    (default: 2)
%    'maxcomps'   is the maximum number of mixture components for
%                 classic models                        (default: 100)
%    'maxsubnetsize' is the maximum allowed subnetwork size (default: 20)
%
%  NOTE(review): 'maxiters', 'initcomps' and 'maxcomps' are parsed into the
%  parameter struct but are not forwarded to the mixture modeller by the
%  code in this function; confirm whether they are still needed.
%
% Copyright (C) 2008-2010 Leo Lahti
%
% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2, or (at your option)
% any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% This file is based on the Variational Dirichlet Process
% Gaussian Mixture Model implementation,
% Copyright (C) 2007 Kenichi Kurihara. All rights reserved.
% and
% AIVGA Agglomerative Independent Variable Group Analysis package (v. 1.0)
% Copyright (C) 2001-2007 Esa Alhoniemi, Antti Honkela, Krista Lagus,
% Jeremias Seppa, Harri Valpola, and Paul Wagner
% For more details on AIVGA, see http://www.cis.hut.fi/projects/ivga/

%  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
%  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
%  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
%  FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
%  COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
%  INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
%  (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
%  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
%  HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
%  STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
%  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
%  OF THE POSSIBILITY OF SUCH DAMAGE.



% Read the arguments. Three call forms are accepted:
%   (1) a single struct of options,
%   (2) a struct of options followed by 'name', value pairs,
%   (3) only 'name', value pairs (possibly none).
if (length(varargin) == 1) && (isstruct(varargin{1})),
  args = varargin{1};
elseif (mod(length(varargin), 2) ~= 0),
  if ~(isstruct(varargin{1})),
    error('Keyword arguments should appear in pairs');
  else
    args = varargin{1};
    for k=2:2:length(varargin),
      if ~ischar(varargin{k})
        error('Keyword argument names must be strings.');
      end
      % Dynamic field name instead of eval(): the keyword is used as data,
      % never executed as code.
      args.(varargin{k}) = varargin{k+1};
    end
  end
else
  % struct() expands cell-array values into struct arrays; wrap them in
  % another cell so that the value is stored as given.
  for k=1:length(varargin),
    if iscell(varargin{k}) && numel(varargin{k}) > 1,
      varargin{k} = {varargin{k}};
    end
  end
  args = struct(varargin{:});
end

% Get the values (each recognized option is removed from args)
[params.maxiter, args] = getargdef(args, 'maxiters', 1000);
[params.implicit_noisevar, args] = getargdef(args, 'impl_noise', 0);
[params.update_hyperparams, args] = getargdef(args, 'update_hyperparams', 0);
[params.initcomps, args] = getargdef(args, 'initcomps', 2);
[params.maxcomps, args] = getargdef(args, 'maxcomps', 100);
[params.maxsubnetsize, args] = getargdef(args, 'maxsubnetsize', 20);

% Report unused arguments
otherargs = fieldnames(args);
if length(otherargs) > 0,
  fprintf('Warning: unused arguments:\n');
  fprintf(' %s\n', otherargs{:});
end

% exist(..., 'file') returns 3 for MEX files; all four must be present
if exist('vdp_mk_log_lambda', 'file') == 3 && ...
      exist('vdp_mk_hp_posterior', 'file') == 3 && ...
      exist('vdp_sumlogsumexp', 'file') == 3 && ...
      exist('vdp_softmax', 'file') == 3,
  params.usemex = 1;
else
  warning('MEX files not found, try using COMPILE to create them and speed up the algorithm a lot');
  params.usemex = 0;
end


%
% Format the data
%

% (terminated with a semicolon so that the dataset is not echoed)
dataset = format_data(datamatrix);

%
% Initialize
%

[N,dim] = size(dataset.data);
Nlog = log(N);
dim0    = dim;              % original number of variables
C = 0;                      % total cost (sum of BIC values)
H = ones(1, dim)*Inf;       % Store costs
% diagonal contains number of parameters for corresponding model
% off-diagonal tells number of parameters for the joint model
% initial costs for the independent and joint models
models    = cell(dim);
costs = ones(dim)*Inf;
Nparams = ones(dim)*Inf;
bic_ind   = ones(dim)*Inf;
bic_joint = ones(dim)*Inf;
delta   = ones(dim)*Inf;
groupinghist    = {};

%
% compute cost for each variable
%

for k = 1:dim
  [cost model]  = cost_a(dataset, k, params);
  H(k) = cost;  % -cost is lower bound for log(P(D|H))
  modelhist{k} = model;
  models{k,k}  = model;
  Nparams(k,k) = model.Nparams; % save number of parameters for this model
  bic = Nparams(k,k)*Nlog - 2*(-H(k)); % BIC for this model
  C       = C + bic;  % Total cost, previously: C + H(k)
  fprintf('\rComputing model for variable %4d/%4d...', k, dim);
end
Clist = [C];
fprintf('done.\n');

%
% compute costs for combined variable pairs
%

for a = 1:(dim-1)

  fprintf('\rComputing delta values for variable %4d/%4d...', a, dim);

  for b = (a+1):dim
    models{a,b} = NaN;

    if network(a,b) % Require that the combined groups are connected in the network

      [cost m] = cost_a(dataset, [a b], params);

      models{a,b} = m;
      costs(a,b) = cost; % Store cost for the joint model.
      Nparams(a,b) = m.Nparams;       % number of parameters in the joint model

      % Compute BIC-value for two independent subnets vs. joint model
      % Negative free energy (-cost) is (variational) lower bound for P(D|H)
      % Use it as an approximation for P(D|H)
      % Cost for the independent and joint models
      bic_ind(a,b)   = (Nparams(a,a) + Nparams(b,b))*Nlog - 2*(-H(a)-H(b)); % -cost is sum of two independent models (H: appr. log-likelihoods)
      bic_joint(a,b) = Nparams(a,b)*Nlog - 2*(-costs(a,b));

      % Replace AIVGA delta with BIC change
      % Note that this is ok since also BIC is additive
      % change (increase) of the total BIC
      delta(a, b)  = bic_joint(a,b) - bic_ind(a,b);     % change (increase) of the total cost

      % Symmetric
      costs(b,a) = costs(a,b);
      Nparams(b,a) = Nparams(a,b);
      bic_ind(b,a) = bic_ind(a,b);
      bic_joint(b,a) = bic_joint(a,b);
      delta(b,a)   = delta(a,b);

    end
    models{b,a} = models{a,b};
  end

  G{a}        = a;
  delta(a, a) = Inf;

end
a = dim;
G{a}        = a;
delta(a, a) = Inf;

fprintf('done.\n');



%
% Compute the actual groupings
%

move_cost_hist  = [0; 0; C];
% NOTE(review): the initial grouping is stored wrapped in an extra cell,
% whereas later levels store G directly; kept as is for compatibility.
groupinghist{1} = {G};

%if off-diagonal is not zero there are still unconnected linked groups
for j = 2:dim0
  fprintf('\rMinimum delta %f',min(min(delta)))

  % if there are still groups sharing a link and
  % improvement can be obtained by combining groups (there are connected items that have delta<0)

  %
  % Identify the best group pair
  %

  if (sum(sum(network))-sum(diag(network)))>0 & sum(sum((delta.*network)<0))>0,
    fprintf('\rCombining groups, %4d group(s) left...', dim - 2);

    %
    % Identify the best neighbor pair in the network
    %

    a=0;
    b=0;
    mindelta = Inf;
    for i = 1:dim
      % Check neighbors
      iNeighs = network(i, :);
      if sum(network(i,[i+1:dim]))>0
        % Identify best neighbor
        neighInds = find(iNeighs);
        [x ztmp] = min(delta(i,neighInds));
        % Return ztmp to original index domain
        z = neighInds(ztmp);

        % require also x<0 since otherwise combining groups is
        % worse than keeping them separate
        if (x < 0 & x < mindelta & network(i,z) & length([G{z} G{i}])<=params.maxsubnetsize)
          mindelta = x;     % NOTE: here x from delta refers to BIC change
          b = z;            % we search for the indices (a,b) of the groups
          a = i;            % that are joined next ..
        end
      end
    end

    % order the pair so that a < b
    temp                 = sort([a b]);
    a                    = temp(1);
    b                    = temp(2);

    %
    % Store results
    %

    C = C + mindelta;
    Clist = [Clist C]; % TEMP
    move_cost_hist       = [move_cost_hist [a; b; C]];
    modelhist{dim0 +j-1} = models{a,b};


    % put the new group to a's place
    % done only for those variables for which this is needed.
    % For others, put Inf on the a neighbors, see further in the code

    H(a) = costs(a, b);

    % combine a and b in the network, remove self-link a-a
    network(a,:) = (network(a,:) | network(b,:)); % combine neighbors for a and b
    network(:,a) = network(a,:);
    network(a,a) = 0;

    % number of parameters for the subnets
    m = models{a,b};
    Nparams(a,a) = m.Nparams;

    bic_ind(a,a) = bic_joint(a,b);

    % remove group b
    % (a<b always; see above sort)

    removed_group_vars  = G{b};
    G    = [G(1:b-1) G(b+1:dim)];
    G{a} = sort([G{a} removed_group_vars]);
    groupinghist{j} = G;

    keepIndices = [1:(b-1) (b+1):dim];

    network = network(keepIndices,keepIndices);

    % remove the merged group

    H(b) = [];

    bic_ind(:,b)   = [];
    bic_ind(b,:)   = [];

    bic_joint(b,:) = [];
    bic_joint(:,b) = [];

    costs(b,:) = [];
    costs(:,b) = [];

    Nparams(b,:) = [];
    Nparams(:,b) = [];

    delta(:, b) = [];
    delta(b, :) = [];

    % Infinite joint costs etc with a for groups not linked to a
    % Note that for Nparams we need also a-a information

    % update dimensionality - NOTE: only after the network update!
    dim = dim - 1;

    nona = [(1:(a-1)) (a+1):dim];

    Nparams(a,nona) = Inf;
    Nparams(nona,a) = Inf;

    bic_ind(a,nona) = Inf;
    bic_ind(nona,a) = Inf;

    % Diagonal not needed anywhere, put Inf just to guarantee:

    costs(a,:) = Inf;
    costs(:,a) = Inf;

    bic_joint(a,:) = Inf;
    bic_joint(:,a) = Inf;

    delta(a,:) = Inf;
    delta(:,a) = Inf;

    % Compute new joint models for a and its neighbors
    % remove previous group b

    % New joint models table (old table without row/column b)
    models{a,a} = models{a,b};
    models2 = cell(dim);
    cnt1 = 0;
    for k1=keepIndices
      cnt1 = cnt1 + 1;
      cnt2 = 0;
      for k2=keepIndices
        cnt2=cnt2+1;
        m = models{k1,k2};
        models2{cnt1,cnt2} = m;
      end
    end
    models = models2;

    % NOTE: Links from a to other nodes need still be updated
    % and replaced by Inf for those that are not linked

    % Compute new joint costs etc

    for i = 1:dim  % not run at last iteration, update "distances"

      % NOTE: the loop also visits i == a; network(a,a) was set to 0 above,
      % so no joint model is computed for that index and only models{a,a}
      % is reset to NaN.

      % don't compute combined model unless they are neighbors
      models{a,i} = NaN;

      % compute combined model only if a and i are linked
      if (network(a,i) & length([G{a} G{i}])<=params.maxsubnetsize)

        [cost m] = cost_a(dataset, sort([G{a} G{i}]), params);

        models{a, i} = m;

        % Store cost for the joint model.
        costs(a,i) = cost;

        % number of parameters in the joint model
        Nparams(a,i) = m.Nparams;

        % Compute BIC-cost value for two independent subnets vs. joint model
        % Number of parameters in the 'two independent subnets' model
        % Negative free energy (-cost) is (variational) lower bound for P(D|H)
        % Use it as an approximation for P(D|H)
        % Cost for the independent and joint models

        % -cost is sum of two independent models (H: appr. log-likelihoods)
        bic_ind(a,i)   = (Nparams(a,a) + Nparams(i,i))*Nlog - 2*(-H(a)-H(i));
        bic_joint(a,i) = Nparams(a,i)*Nlog - 2*(-costs(a,i));

        % change (increase) of the total BIC (cost)
        delta(a, i)  = bic_joint(a,i) - bic_ind(a,i);

        % Symmetric
        costs(i,a)     = costs(a,i);
        Nparams(i,a)   = Nparams(a,i);
        bic_ind(i,a)   = bic_ind(a,i);
        bic_joint(i,a) = bic_joint(a,i);
        delta(i, a)    = delta(a,i);

      end

      % Symmetric
      models{i,a}    = models{a,i};

    end

  else % no groups with links any more.
    fprintf('\rNo groups having links any more or no improvement possible on level %i',j);
    break;
  end
end

fprintf('\ndone.\n');

OutStruct.cost=move_cost_hist(3,:);
OutStruct.moves=move_cost_hist(1:2,:);
OutStruct.groupings=groupinghist;
OutStruct.models=modelhist;
OutStruct.Clist=Clist; % TEMP
OutStruct.BIC=diag(bic_ind);
OutStruct.Nparams=Nparams;
OutStruct.H=H;


function [val, args] = getargdef(args, name, default)
  % GETARGDEF  Fetch an optional argument from a struct, with a default.
  %
  %   [val, args] = getargdef(args, name, default)
  %
  %   args     struct of options
  %   name     field name to look up (string)
  %   default  value returned when the field is absent
  %
  %   Returns the value of args.(name) and removes that field from args,
  %   so that the caller can report any leftover (unused) fields.
  %   When the field does not exist, default is returned and args is
  %   left untouched.
  if isfield(args, name),
    % dynamic field access; no need to build and eval() a code string
    val = args.(name);
    args = rmfield(args, name);
  else
    val = default;
  end



function [flg, value] = isarg(varargin)

% [flg, value] = isarg(Argumentlist, Argument)
% [flg, value] = isarg(Argument)
%
% Looks for the string Argument in a list of arguments.
%
% INPUT
%
%  Argumentlist   cell array of arguments (anything else is converted
%                 with cellstr). When omitted, the variable 'varargin'
%                 of the calling function is used.
%  Argument       the argument name to search for
%
% OUTPUT
%
%  flg            1 when Argument occurs in Argumentlist, otherwise 0
%  value          the item following the last occurrence of Argument
%                 ([] when Argument is not present)
%
% The list is scanned from the end, so with several occurrences the last
% one decides. An error is raised when a value is requested but the
% match is the final item of the list.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

if nargin == 1
  % single-argument form: read the list from the caller's workspace
  errstr = 'error(''Variable ''''varargin'''' does not exist in caller''''s workspace'')';
  Argumentlist = evalin('caller', 'varargin', errstr);
  Argument     = varargin{1};
else
  Argumentlist = varargin{1};
  Argument     = varargin{2};
end

if ~iscell(Argumentlist)
  Argumentlist = cellstr(Argumentlist);
end

n_items = length(Argumentlist);
flg     = 0;
pos     = n_items;

% walk backwards and stop at the first (i.e. last-in-list) match
while pos >= 1 && flg == 0
  item = Argumentlist{pos};
  if ischar(item) & strcmp(item, Argument)
    flg = 1;
    if nargout == 2
      if pos >= n_items
        error(sprintf('No value for argument ''%s'' in the argument list',item));
      end
      value = Argumentlist{pos+1};
    end
  end
  pos = pos - 1;
end

% value must always be defined when it was requested
if nargout == 2 & flg == 0
  value = [];
end


function result = isvector_( x )

% Stand-in for isvector() on older Matlab versions: true when x is
% two-dimensional and at least one of its two dimensions has length 1.

dims = size(x);
result = (length(dims) == 2) && any(dims == 1);


function dataset = format_data (datamatrix)
     % Wrap a plain data matrix into the dataset struct used by the
     % mixture modelling code.
     %
     % Every column is given type 1, so X1 (columns of type 1) receives
     % all columns and X2 (columns of type 2) has none. S is a zero row
     % with one entry per column and realS holds the S entries of the
     % type-2 columns.
     column_types = ones(1, size(datamatrix, 2));
     type1 = (column_types == 1);
     type2 = (column_types == 2);
     zero_S = zeros(1, length(column_types));

     dataset = [];
     dataset.data  = datamatrix;
     dataset.types = column_types;
     dataset.X1    = datamatrix(:, type1);
     dataset.X2    = datamatrix(:, type2);
     dataset.S     = zero_S;
     dataset.realS = zero_S(type2);


function [cost, modelpack] = cost_a(D, vars, params)
  % Fit a mixture model to the variables 'vars' of dataset D.
  %
  %   D       dataset struct (fields data, types, S, ...)
  %   vars    indices of the variables (columns) to model jointly
  %   params  struct with fields update_hyperparams, implicit_noisevar
  %           and usemex, which are passed on to vdp_mixt
  %
  %   cost       free energy of the fitted model; its negative is used as
  %              a (variational) lower bound / approximation for P(D|H)
  %   modelpack  result struct of vdp_mixt with the added field Nparams

  % restrict the dataset to the selected variables
  sel_data  = D.data(:, vars);
  sel_types = D.types(vars);
  sel_S     = D.S(vars);
  is_type1  = (sel_types == 1);
  is_type2  = (sel_types == 2);

  subset = D;
  subset.data  = sel_data;
  subset.types = sel_types;
  subset.S     = sel_S;
  subset.realS = sel_S(is_type2);
  subset.X1    = sel_data(:, is_type1);
  subset.X2    = sel_data(:, is_type2);

  mixt_opts = struct('update_hyperparams', params.update_hyperparams, 'implicit_noisevar', params.implicit_noisevar, 'usemex', params.usemex);
  modelpack = vdp_mixt(subset, mixt_opts);

  cost = modelpack.free_energy;

  % Number of free parameters, used to penalize model complexity with a
  % BIC-inspired term: for each nonempty mixture component a
  % d-dimensional centroid, a diagonal variance and one component weight.
  n_dims       = length(vars);
  n_components = sum(sum(modelpack.hp_posterior.q_of_z)>0);
  modelpack.Nparams = n_components*(2*n_dims + 1);





function [results] = vdp_mixt(given_data, opts)
% Fit a variational mixture model to given_data.
%
%   given_data  dataset struct (stored as data.given_data for the helpers)
%   opts        optional struct of options; missing fields get the
%               defaults listed below
%
%   results     struct with the fields algorithm, elapsed_time,
%               free_energy, hp_prior, hp_posterior, K and opts (plus
%               seed when no seed was given, and q_of_z on request).
%               When opts.hp_posterior is given, no training is done and
%               only the requested q_of_z / log_likelihood are returned.

start_time = clock;
if nargin == 1
  opts = struct();
end

% Defaults that are filled in before the random seed is handled.
% algorithm can be one of 'vdp', 'bj', 'cdp', 'csb' and 'non_dp'
%   vdp : variational DP
%   bj : Blei and Jordan
%   cdp : collapsed Dirichlet prior
%   csb : collapsed stick-breaking
%   non_dp : variational Bayes for Gaussian mixtures
% init_of_split can be 'pc' (PCA initialization for splitting), 'rnd',
% 'rnd_close' or 'close_f'.
% NOTE(review): the default of do_sort is the character '0', not the
% number 0; it is tested with && elsewhere, where a nonzero character
% counts as true. Confirm that this is intended.
early_defaults = { ...
    'algorithm',          'vdp'; ...
    'do_sort',            '0'; ...
    'get_q_of_z',         0; ...
    'get_log_likelihood', 0; ...
    'threshold',          1.0e-5; ...
    'initial_depth',      3; ...
    'initial_K',          1; ...
    'ite',                inf; ...
    'do_split',           0; ...
    'do_merge',           0; ...
    'do_greedy',          1; ...
    'max_target_ratio',   0.5; ...
    'init_of_split',      'pc'};
for row = 1:size(early_defaults, 1)
  key = early_defaults{row, 1};
  if ~isfield(opts, key)
    opts.(key) = early_defaults{row, 2};
  end
end

% Use the given seed, or record the current generator state
if isfield(opts, 'seed')
  rand('state', opts.seed);
else
  results.seed = rand('state');
end

% Defaults that are filled in after the seed handling
late_defaults = { ...
    'implicit_noisevar',  0; ...
    'update_hyperparams', 0; ...
    'quiet',              1};
for row = 1:size(late_defaults, 1)
  key = late_defaults{row, 1};
  if ~isfield(opts, key)
    opts.(key) = late_defaults{row, 2};
  end
end

data.given_data = given_data;

% the hyperparameters of priors
hp_prior = mk_hp_prior(data, opts);

% Evaluation-only mode: a posterior was supplied by the caller
if isfield(opts, 'hp_posterior')
  if opts.get_q_of_z
    results.q_of_z = mk_q_of_z(data, opts.hp_posterior, hp_prior, opts);
  end
  if opts.get_log_likelihood
    results.log_likelihood = mk_log_likelihood(data, opts.hp_posterior, hp_prior, opts);
  end
  return
end

% Initial responsibilities: given by the caller or drawn at random
if ~isfield(opts, 'q_of_z')
  q_of_z = rand_q_of_z(data, opts.initial_K, opts);
else
  q_of_z = opts.q_of_z;
end

[hp_posterior, hp_prior] = mk_hp_posterior(data, q_of_z, hp_prior, opts, 1);
[free_energy, hp_posterior, hp_prior, data] = greedy(data, hp_posterior, hp_prior, opts);

results.algorithm    = opts.algorithm;
results.elapsed_time = etime(clock, start_time);
results.free_energy  = free_energy;
results.hp_prior     = hp_prior;
results.hp_posterior = hp_posterior;
results.K            = size(hp_posterior.Mu_bar, 1);
results.opts         = opts;
if opts.get_q_of_z
  results.q_of_z = mk_q_of_z(data, hp_posterior, hp_prior, opts);
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [free_energy, hp_posterior, hp_prior, data] = greedy(data, hp_posterior, hp_prior, opts);
% Greedy model search: repeatedly split the best candidate component,
% re-train the posterior and accept the result for as long as the free
% energy improves.
%
% Returns the free energy, posterior, prior and data of the last
% accepted model.
verbose = ~opts.quiet;
free_energy = mk_free_energy(data, hp_posterior, hp_prior, opts);
disp_status(free_energy, hp_posterior, opts);

searching = true;
while searching
  if verbose,
    disp('finding the best one....')
  end
  [cand_free_energy, cand_hp_posterior, cand_data, c] = ...
      find_best_splitting(data, hp_posterior, hp_prior, opts);

  % c == -1 signals that no split was found
  if c == -1
    searching = false;
    continue
  end

  if verbose,
    disp(['finding the best one.... done.  component ' num2str(c) ' was split.'])
    disp_status(cand_free_energy, cand_hp_posterior, opts);
  end

  % full re-training of the candidate model
  [cand_free_energy, cand_hp_posterior, cand_hp_prior, cand_data] = ...
      update_posterior2(cand_data, cand_hp_posterior, hp_prior, ...
                        opts, opts.ite, 1, 1);
  if ~isfinite(cand_free_energy)
    error('Free energy is not finite, please consider adding implicit noise or not updating the hyperparameters')
  end

  if free_energy_improved(free_energy, cand_free_energy, 0, opts) == 0
    % no improvement: keep the previously accepted model
    searching = false;
  else
    free_energy  = cand_free_energy;
    hp_posterior = cand_hp_posterior;
    hp_prior     = cand_hp_prior;
    data         = cand_data;
  end
end
disp_status(free_energy, hp_posterior, opts);


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [free_energy, hp_posterior, data, c] = find_best_splitting(data, ...
                                                  hp_posterior, ...
                                                  hp_prior, opts);
% Try to split each candidate component and pick the split that gives
% the lowest free energy.
%
%   data, hp_posterior, hp_prior, opts  current model state and options
%
%   free_energy   free energy of the best split (Inf when none was found)
%   hp_posterior  posterior of the best split (input value when none)
%   data          data of the best split (input value when none)
%   c             index of the split component, or -1 when no split
%                 could be made
%
% Only components with Nc > 2 are candidates, and at most c_max of them
% are tried.
c_max = 10;
candidates = find(hp_posterior.Nc>2);
if isempty(candidates)
  % Nothing to split: report "no split" in the same way as when every
  % attempted split fails, instead of falling through to min() of an
  % empty vector and indexing an undefined variable.
  free_energy = inf;
  c = -1;
  return
end
q_of_z = mk_q_of_z(data, hp_posterior, hp_prior, opts);
new_free_energy = ones(1,max(candidates))*inf;
%%%%%%%%%%%%%%%%%%%%%
% quantities of the current model; the entries of the split components
% are replaced below
fc = mk_E_log_q_p_eta(data, hp_posterior, hp_prior, opts);
log_lambda = mk_log_lambda(data, hp_posterior, hp_prior, opts);
%%%%%%%%%%%%%%%%%%%%%
for c = candidates(1:min(c_max, length(candidates)))
  if ~opts.quiet,
    disp(['splitting ' num2str(c) '...'])
  end
  [new_data(c), new_q_of_z, info] = split(c, data, q_of_z, hp_posterior, hp_prior, opts);
  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
  new_c = info.new_c;
  % samples that mostly belong to the two halves of the split component
  relating_n = find(sum(new_q_of_z(:,[c new_c]),2) > 0.5);
  if isempty(relating_n)
    continue
  end
  new_K = size(new_q_of_z, 2);
  sub_q_of_z = new_q_of_z(relating_n, [c new_c new_K]);
  %%% IVGA-specific
  sub_data.given_data = new_data(c).given_data;
  sub_data.given_data.data = new_data(c).given_data.data(relating_n, :);
  sub_data.given_data.X1 = new_data(c).given_data.X1(relating_n, :);
  sub_data.given_data.X2 = new_data(c).given_data.X2(relating_n, :);

  %%% /IVGA-specific
  % short (10 iterations) local training on the related samples only
  sub_hp_posterior = mk_hp_posterior(sub_data, sub_q_of_z, hp_prior, opts, 0);
  [sub_f, sub_hp_posterior, dummy1, dummy2, sub_q_of_z] = ...
      update_posterior2(sub_data, ...
                        sub_hp_posterior, ...
                        hp_prior, opts, 10, 0, 0);
  if size(sub_q_of_z,2) < 3
    continue
  end
  if length(find(sum(sub_q_of_z,1)<1.0e-10)) > 1
    continue
  end
  % put the locally trained components back into the full model
  new_log_lambda = log_lambda;
  sub_log_lambda = mk_log_lambda(new_data(c), sub_hp_posterior, hp_prior, opts);
  insert_indices = [c new_c new_K:(new_K+size(sub_q_of_z,2)-3)];
  new_log_lambda(:,insert_indices) = sub_log_lambda;
  new_fc = fc;
  new_fc(insert_indices) = mk_E_log_q_p_eta(sub_data, sub_hp_posterior, hp_prior, opts);
  new_free_energy(c) = mk_free_energy(new_data(c), sub_hp_posterior, hp_prior, opts, new_fc, new_log_lambda);
  new_q_of_z(relating_n,:) = 0;
  new_q_of_z(relating_n,insert_indices) = sub_q_of_z;
  new_q_of_z_cell{c} = new_q_of_z;
end
[free_energy, c] = min(new_free_energy);
if isinf(free_energy)
  % every attempted split was rejected
  c = -1;
  return
end
data = new_data(c);
[hp_posterior, hp_prior] = mk_hp_posterior(data, new_q_of_z_cell{c}, hp_prior, opts, 1);


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [new_data, new_q_of_z, info] = split(c, data, q_of_z, ...
                                              hp_posterior, hp_prior, opts)
% Split component c of q_of_z (N*K) into two components.
%
% The first half stays in column c; the second half is inserted as a new
% column just before the last column, and its index is returned in
% info.new_c. The way the responsibilities of component c are divided
% depends on opts.init_of_split:
%   'pc'         by the sign of the projection of the centered data on
%                its first principal component
%   'rnd'        by uniform random proportions
%   'rnd_close'  by random proportions close to 0.5
%   'close_f'    by random proportions between 0.98 and 0.99
% new_data is returned as an unchanged copy of data.
%
% NOTE(review): princomp is not defined in this file; presumably the
% Statistics Toolbox function - confirm that it is available.
new_data = data;
n_samples = size(q_of_z, 1);
n_comps   = size(q_of_z, 2);

if isequal(opts.init_of_split, 'pc')  % principal eigenvector
  X = new_data.given_data.data;

  % center the data and project it on the first principal component
  centered = (X - repmat(mean(X), size(X,1), 1));
  principal_component = princomp(centered);
  dir = centered*principal_component(:,1);

  % samples on the nonnegative side go to the first half
  first_half  = zeros(n_samples,1);
  second_half = q_of_z(:,c);
  nonneg = find(dir>=0);
  first_half(nonneg)  = q_of_z(nonneg,c);
  second_half(nonneg) = 0;
else
  parent = q_of_z(:,c);
  if isequal(opts.init_of_split, 'rnd')  % random
    r = rand(n_samples,1);
  elseif isequal(opts.init_of_split, 'rnd_close')  % make close clusters
    r = 0.5 + (rand(n_samples,1)-0.5)*0.01;
  elseif isequal(opts.init_of_split, 'close_f')  % one is almost zero.
    r = 0.98 + rand(n_samples,1)*0.01;
  else
    init_of_split = opts.init_of_split
    error('unknown option')
  end
  first_half  = parent.*r;
  second_half = parent.*(1-r);
end

% old columns keep their places, except that the last one moves one step
% to the right to make room for the new component
new_c = n_comps;
kept_columns = [1:(n_comps-1) (n_comps+1)];
new_q_of_z = zeros(n_samples, n_comps+1);
new_q_of_z(:,kept_columns) = q_of_z;
new_q_of_z(:,c) = first_half;
new_q_of_z(:,new_c) = second_half;
info.new_c = new_c;
%[new_data(c), new_q_of_z, info] = split(c, data, q_of_z, hp_posterior, hp_prior, opts);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [free_energy, hp_posterior, hp_prior, data, q_of_z] = update_posterior2(data, hp_posterior, hp_prior, opts, ite, do_sort, do_hyperparams);
% Iteratively update the responsibilities q_of_z (N*K) and the posterior.
%
%   ite             number of iterations; Inf (the default) means
%                   "iterate until the free energy stops improving"
%   do_sort         allow one extra phase in which the components are
%                   sorted (default 1); it is only entered when
%                   opts.do_sort is also true
%   do_hyperparams  passed on to mk_hp_posterior. It has no default here,
%                   so callers need to supply it.
%
% Returns the last computed free energy together with the updated
% posterior, prior, data and responsibilities.
% NOTE(review): q_of_z is only assigned inside the loop, after the
% stopping test; if the loop stops during its first pass, q_of_z is left
% unassigned.
if ~opts.quiet,
  disp(['### updating posterior ...'])
end
free_energy = inf;
if nargin < 5
  ite = inf;
end
if nargin < 6
  do_sort = 1;
end
i = 0;           % iteration counter
last_Nc = 0;     % written below but not read in this function
start_sort = 0;  % set to 1 once the sorting phase has started
while 1
  i = i+1;
  [new_free_energy, log_lambda] = mk_free_energy(data, hp_posterior, hp_prior, opts);
  if ~isfinite(new_free_energy)
    error('Free energy is not finite, please consider adding implicit noise or not updating the hyperparameters')
  end
  disp_status(new_free_energy, hp_posterior, opts);
  % Stopping test: iteration limit reached (finite ite), or no further
  % improvement of the free energy (ite == Inf)
  if (~isinf(ite) && i>=ite) || ...
        (isinf(ite) && free_energy_improved(free_energy, new_free_energy, 0, opts) == 0)
    free_energy = new_free_energy;
    % On the first stop, switch to the sorting phase (if enabled) and
    % keep iterating; on the next stop, leave the loop.
    if do_sort && opts.do_sort && ~ start_sort
      start_sort = 1;
    else
      break
    end
  end
  last_Nc = hp_posterior.Nc;
  free_energy = new_free_energy;
  [q_of_z, data] = mk_q_of_z(data, hp_posterior, hp_prior, opts, log_lambda);
  % check if the last component is small enough
  % (if not, append an empty column as the new last component)
  if isequal(opts.algorithm, 'vdp') & sum(q_of_z(:,end)) > 1.0e-20
    q_of_z(:,end+1) = 0;
  end
  if start_sort
    q_of_z = sort_q_of_z(data, q_of_z, opts);
  end
  % drop the second-to-last component when it is (almost) empty
  if isequal(opts.algorithm, 'vdp') & sum(q_of_z(:,end-1)) < 1.0e-10
    q_of_z(:,end-1) = [];
  end
  [hp_posterior, hp_prior] = mk_hp_posterior(data, q_of_z, hp_prior, opts, do_hyperparams);
end
% disp_status(free_energy, hp_posterior, opts);
if ~opts.quiet,
  disp(['### updating posterior ... done.'])
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [q_of_z, I] = sort_q_of_z(data, q_of_z, opts);
% Reorder the columns (components) of q_of_z by decreasing total
% responsibility. For the 'vdp' algorithm the last column is not part of
% the sort and stays last. I is the applied column permutation.
verbose = ~opts.quiet;
if verbose,
  disp('sorting...')
end
mass = sum(q_of_z, 1); % 1*K
if ~isequal(opts.algorithm, 'vdp')
  [unused, I] = sort(mass, 2, 'descend');
else
  [unused, I] = sort(mass(1:end-1), 2, 'descend');
  I = [I length(mass)];
end
q_of_z = q_of_z(:,I);
if verbose,
  disp('sorting... done.')
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function disp_status(free_energy, hp_posterior, opts);
% Print the free energy and the component sizes on one line.
% Nothing is printed when opts.quiet is set. For the 'vdp' algorithm the
% sizes are read from hp_posterior.true_Nc, otherwise from
% hp_posterior.Nc.
if ~opts.quiet
  if isequal(opts.algorithm, 'vdp')
    sizes = hp_posterior.true_Nc;
  else
    sizes = hp_posterior.Nc;
  end
  energy_text = ['F=' num2str(free_energy)];
  sizes_text  = [';    Nc=[' num2str(sizes, ' %0.5g ') '];'];
  disp([energy_text sizes_text])
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function bool = free_energy_improved(free_energy, new_free_energy, warn_when_increasing, opts);
% Decide whether new_free_energy is an improvement over free_energy.
%
%   free_energy           previous free energy
%   new_free_energy       candidate free energy
%   warn_when_increasing  when true, an increase of the free energy
%                         raises a warning, or an error if the relative
%                         increase exceeds 1.0e-3
%   opts.threshold        relative changes smaller than this are not
%                         counted as an improvement
%
%   bool  1 when the free energy decreased by more than the relative
%         threshold, 0 otherwise
change = new_free_energy - free_energy;
relative_change = abs(change/free_energy);
if relative_change < opts.threshold
  bool = 0;
elseif change > 0
  if warn_when_increasing
    if relative_change > 1.0e-3
      error(['the free energy increased.  the diff is ' num2str(new_free_energy-free_energy)])
    else
      warning(['the free energy increased.  the diff is ' num2str(new_free_energy-free_energy)])
    end
  end
  bool = 0;
elseif change == 0
  % reached e.g. when both energies are 0 (relative change is NaN);
  % terminated with a semicolon so that nothing is echoed
  bool = 0;
else
  bool = 1;
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function direction = divide_by_principal_component(data, covariance, mean);
% Project the columns of data on the principal eigenvector of covariance.
%
%   data        one sample per column
%   covariance  covariance matrix
%   mean        column vector that is subtracted from every sample
%               (note: this parameter name hides the builtin mean here)
%
%   direction   1*N row of projections, one per sample
%
% With at most 16 rows in data a full eigendecomposition is used,
% otherwise power_method is called.
n_samples = size(data, 2);
if size(data,1) > 16
  [principal_component, eig_val] = power_method(covariance);
else
  [eig_vectors, eig_values] = eig(covariance);
  [eig_val, largest] = max(diag(eig_values));
  principal_component = eig_vectors(:,largest);
end
centered  = data - repmat(mean, 1, n_samples);
weights   = repmat(principal_component, 1, n_samples);
direction = sum(centered.*weights, 1);


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function q_of_z = rand_q_of_z(data, K, opts);
% Random initial responsibilities q_of_z (N*K) for K components.
% For the 'vdp' algorithm one extra, initially zero, column is appended.
% The result is passed through normalize(q_of_z, 2) before it is
% returned.
N = size(data.given_data.data, 1);
if isequal(opts.algorithm, 'vdp')
  n_columns = K+1;
else
  n_columns = K;
end
q_of_z = zeros(N, n_columns);
q_of_z(:,1:K) = rand(N, K);
q_of_z = normalize(q_of_z, 2);


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [hp_posterior, hp_prior] = mk_hp_posterior(data, q_of_z, hp_prior, opts, do_hyperparams);
% Compute the posterior hyperparameters from the responsibilities.
%
% the last component of q_of_z represents the infinite rest of components
% the last component is the prior.
% q_of_z: N*K
% q_of_z(:,end) is the rest of responsibilities.
%
%   data            struct with the field given_data (data, types, X1,
%                   X2, realS)
%   hp_prior        prior hyperparameters
%   opts            options; algorithm, usemex, update_hyperparams and
%                   implicit_noisevar are read here
%   do_hyperparams  when true and opts.update_hyperparams is set, the
%                   prior hyperparameters are re-estimated as well
%
% Returns the posterior hyperparameters and the (possibly updated) prior.

% Fast path: the 'vdp' algorithm with MEX support delegates the posterior
% computation to vdp_mk_hp_posterior and only does the optional prior
% update here.
if isequal(opts.algorithm, 'vdp') && opts.usemex
  hp_posterior = vdp_mk_hp_posterior(data, q_of_z, hp_prior, opts);
  if opts.update_hyperparams && do_hyperparams, 
    K  = size(q_of_z, 2);
    M1 = size(data.given_data.X1, 2);
    
    % Update some priors .. (and calculate Ksi_Log) :
    hp_prior.Mu_mu=(1/K)*sum(hp_posterior.Mu_bar,1);
    hp_prior.S2_mu=(1/K)*sum(hp_posterior.Mu_tilde + ...
			     ((hp_posterior.Mu_bar-repmat(hp_prior.Mu_mu,K,1)).^2),1);
    T=repmat(K,1,M1)./sum(hp_posterior.Ksi_alpha./hp_posterior.Ksi_beta,1);
    Ksi_log=psi(hp_posterior.Ksi_alpha)-log(hp_posterior.Ksi_beta);
    hp_prior.Alpha_ksi=hp_prior.Alpha_ksi ...
	+repmat(0.5,1,M1)./(psi(hp_prior.Alpha_ksi)-log(hp_prior.Alpha_ksi)) ...
	+repmat(0.5,1,M1)./(repmat(1/K,1,M1).*sum(-Ksi_log,1) - log(T));
  
    hp_prior.Beta_ksi=hp_prior.Alpha_ksi.*T;
  end
  return;
end

% Plain Matlab implementation from here on.
threshold_for_N = 1.0e-200; % not used in the rest of this function
K = size(q_of_z, 2);
[N,D] = size(data.given_data.data);
if isequal(opts.algorithm, 'vdp')
  % true_Nc keeps the mass of the last component; for everything else
  % below the last column is zeroed
  true_Nc = sum(q_of_z, 1); % 1*K
  q_of_z(:,end) = 0;
end
Nc = sum(q_of_z, 1); % 1*K

%%% IVGA-specific
types    = data.given_data.types;
X1       = data.given_data.X1; % real-valued
X2       = data.given_data.X2; % nominal

[N,M1]=size(X1);
M2=size(X2,2);


% Nominal variables: prior count plus the responsibility-weighted count
% of each category value k
if M2
  S = data.given_data.realS;

  for j=1:M2,
    for k=1:S(j),
      hp_posterior.Uhat{j}(:,k)= hp_prior.U_p(j) + q_of_z'*(X2(:,j)==k); % matrix prod. :) 
    end
  end
end
% Real-valued variables
if M1
  S2_muM=repmat(hp_prior.S2_mu,K,1);
  FM=repmat(Nc',1,M1);

  % first pass: means computed with the prior variance estimate
  S2_x = repmat(hp_prior.Beta_ksi ./ hp_prior.Alpha_ksi, K, 1);
  hp_posterior.Mu_bar = (S2_x .* repmat(hp_prior.Mu_mu,K,1) + S2_muM .* ...
			 (q_of_z'*X1))./(S2_x + S2_muM.*FM);
  hp_posterior.Mu_tilde = (S2_x.*S2_muM)./(S2_x + S2_muM.*FM);

  hp_posterior.Ksi_alpha = repmat(hp_prior.Alpha_ksi,K,1)+0.5*FM;
  for j=1:M1
    for k=1:K,
      hp_posterior.Ksi_beta(k,j)=hp_prior.Beta_ksi(j) + 0.5 * ...
	  sum(q_of_z(:, k).*(hp_posterior.Mu_tilde(k,j) ...
			     +(X1(:,j)-hp_posterior.Mu_bar(k,j)).^2 ...
			     + opts.implicit_noisevar));
    end

  end
  % second pass: means recomputed with the posterior variance estimate
  S2_x=hp_posterior.Ksi_beta ./ hp_posterior.Ksi_alpha;
  hp_posterior.Mu_bar = (S2_x .* repmat(hp_prior.Mu_mu,K,1) + S2_muM .* ...
			 (q_of_z'*X1))./(S2_x + S2_muM.*FM);
  hp_posterior.Mu_tilde = (S2_x.*S2_muM)./(S2_x + S2_muM.*FM);

  if opts.update_hyperparams && do_hyperparams,
    % Update some priors .. (and calculate Ksi_Log) :
    % (same update as in the MEX branch above)
    hp_prior.Mu_mu=(1/K)*sum(hp_posterior.Mu_bar,1);
    hp_prior.S2_mu=(1/K)*sum(hp_posterior.Mu_tilde + ...
			     ((hp_posterior.Mu_bar-repmat(hp_prior.Mu_mu,K,1)).^2),1);
    T=repmat(K,1,M1)./sum(hp_posterior.Ksi_alpha./hp_posterior.Ksi_beta,1);
    Ksi_log=psi(hp_posterior.Ksi_alpha)-log(hp_posterior.Ksi_beta);
    hp_prior.Alpha_ksi=hp_prior.Alpha_ksi ...
	+repmat(0.5,1,M1)./(psi(hp_prior.Alpha_ksi)-log(hp_prior.Alpha_ksi)) ...
	+repmat(0.5,1,M1)./(repmat(1/K,1,M1).*sum(-Ksi_log,1) - log(T));
  
    hp_prior.Beta_ksi=hp_prior.Alpha_ksi.*T;
  end
else
  % no real-valued variables: keep Mu_bar defined with K rows, since
  % other functions read the number of components from size(Mu_bar, 1)
  hp_posterior.Mu_bar = zeros(K, 0);
end

%%% /IVGA-specific

% Algorithm-specific parameters of the mixing proportions
if isequal(opts.algorithm, 'vdp')
  % gamma: 2*K
  hp_posterior.gamma = zeros(2,K);
  hp_posterior.gamma(1,:) = 1 + true_Nc;
  hp_posterior.gamma(2,:) = hp_prior.alpha + sum(true_Nc) - cumsum(true_Nc,2);
elseif isequal(opts.algorithm, 'bj')
  hp_posterior.gamma = zeros(2,K-1);
  hp_posterior.gamma(1,:) = 1 + Nc(1:K-1);
  hp_posterior.gamma(2,:) = hp_prior.alpha + sum(Nc) - cumsum(Nc(1:K-1),2);
elseif isequal(opts.algorithm, 'non_dp') | isequal(opts.algorithm, 'cdp')
  hp_posterior.tilde_alpha = hp_prior.alpha/K + Nc;
elseif isequal(opts.algorithm, 'csb')
  1; % nothing to compute for 'csb'
else
  error('unknown algorithm')
end

hp_posterior.Nc = Nc; 
if isequal(opts.algorithm, 'vdp')
  hp_posterior.true_Nc = true_Nc;
end
% store the responsibilities (for 'vdp' with the last column zeroed)
hp_posterior.q_of_z = q_of_z;
%keyboard


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function fc = mk_E_log_q_p_eta(data, hp_posterior, hp_prior, opts);
% MK_E_LOG_Q_P_ETA  Per-component value of E[ log q(eta)/p(eta) ]_q.
%
%   fc = mk_E_log_q_p_eta(data, hp_posterior, hp_prior, opts)
%
%   Reads:
%     data.given_data.X1, .X2, .realS (and .types, which is not used further)
%     hp_prior.U_p, .Mu_mu, .S2_mu, .Alpha_ksi, .Beta_ksi
%     hp_posterior.Uhat, .Mu_bar, .Mu_tilde, .Ksi_alpha, .Ksi_beta
%   opts is accepted but not used here.
%
%   Returns fc, a 1 by K row vector, K = size(hp_posterior.Mu_bar, 1).
%   Each nominal column j contributes a term built from hp_posterior.Uhat{j}
%   versus hp_prior.U_p(j); each real-valued column j contributes a term
%   built from (Mu_bar, Mu_tilde) versus (Mu_mu, S2_mu) plus a term built
%   from (Ksi_alpha, Ksi_beta) versus (Alpha_ksi, Beta_ksi).

% IVGA-specific
given_data = data.given_data;
types      = given_data.types;   % not used below; still read, as before
real_cols  = given_data.X1;      % real-valued
nom_cols   = given_data.X2;      % nominal
n_values   = given_data.realS;   % n_values(j) = number of columns of Uhat{j}

[n_samples, n_real] = size(real_cols);
n_nominal = size(nom_cols, 2);
n_comps   = size(hp_posterior.Mu_bar, 1);

% Start from the constant -1/2 per real-valued column.
kl = - n_real/2 * ones(n_comps, 1);

% Nominal columns.
for col = 1:n_nominal
  u_prior = hp_prior.U_p(col);
  s_col   = n_values(col);
  u_post  = hp_posterior.Uhat{col};
  u_total = sum(u_post, 2);

  prior_norm  = gammaln(u_prior*s_col);
  prior_parts = s_col*gammaln(u_prior);
  post_norm   = gammaln(u_total);
  post_parts  = sum(gammaln(u_post), 2);
  % Terms are accumulated strictly left to right.
  kl = kl - prior_norm + prior_parts + post_norm - post_parts;

  psi_gap = psi(u_post) - repmat(psi(u_total), 1, s_col);
  kl = kl + sum((u_post - u_prior) .* psi_gap, 2);
end

% Real-valued columns.
if n_real
  log_prec = (psi(hp_posterior.Ksi_alpha) - log(hp_posterior.Ksi_beta));

  for col = 1:n_real
    mu_prior = hp_prior.Mu_mu(col);
    s2_prior = hp_prior.S2_mu(col);
    a_prior  = hp_prior.Alpha_ksi(col);
    b_prior  = hp_prior.Beta_ksi(col);

    mu_mean = hp_posterior.Mu_bar(:, col);
    mu_var  = hp_posterior.Mu_tilde(:, col);
    a_post  = hp_posterior.Ksi_alpha(:, col);
    b_post  = hp_posterior.Ksi_beta(:, col);

    gauss_term = .5 * (log(s2_prior ./ mu_var) + ...
                       ((mu_mean - mu_prior).^2 + mu_var) ./ s2_prior);

    % Same left-to-right order of additions as the single long expression.
    kl = kl + gauss_term ...
         + gammaln(a_prior) ...
         - gammaln(a_post) ...
         + a_post.*log(b_post) ...
         - a_prior .* log(b_prior) ...
         + (a_post - a_prior) .* log_prec(:, col) ...
         + (b_prior - b_post) .* a_post ./ b_post;
  end
end

fc = kl';

%%% /IVGA-specific


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [free_energy, log_lambda] = mk_free_energy(data, hp_posterior, ...
                                                  hp_prior, opts, ...
                                                  fc, log_lambda);
% MK_FREE_ENERGY  Free energy of the current variational posterior.
%
%   [free_energy, log_lambda] = mk_free_energy(data, hp_posterior, hp_prior, opts)
%   [free_energy, log_lambda] = mk_free_energy(data, hp_posterior, hp_prior, opts, fc, log_lambda)
%
%   With exactly four arguments, fc (1*K) and log_lambda (N*K) are computed
%   here with mk_E_log_q_p_eta and mk_log_lambda; otherwise the supplied
%   values are used as given. log_lambda is returned unchanged.
%
%   free_energy = extra_term + sum(fc) - sum_n log(sum_c exp(log_lambda(n,c)))
%
%   extra_term depends on opts.algorithm:
%     'vdp', 'bj' : computed from hp_posterior.gamma (2 rows) and hp_prior.alpha
%     'non_dp'    : computed from hp_posterior.tilde_alpha, hp_posterior.Nc
%                   and hp_prior.alpha
%   Any other value raises error('unknown algorithm'); for 'bj' a gamma
%   whose column count is not K-1 raises error('invalid length').
if nargin == 4
  fc = mk_E_log_q_p_eta(data, hp_posterior, hp_prior, opts); % 1*K
  log_lambda = mk_log_lambda(data, hp_posterior, hp_prior, opts); % N*K
end
[N,K] = size(log_lambda);
if isequal(opts.algorithm, 'vdp') || isequal(opts.algorithm, 'bj')
  % note when bj,  if full hp_posterior is given, len_gamma = K - 1.
  len_gamma = size(hp_posterior.gamma, 2);
  if isequal(opts.algorithm, 'bj') && len_gamma ~= K - 1
    error('invalid length')
  end
  gamma_sum = sum(hp_posterior.gamma, 1);
  E_log_p_of_V = ...
      gammaln(gamma_sum) ...
      - gammaln(1+hp_prior.alpha) ...
      - sum(gammaln(hp_posterior.gamma), 1) ...
      + gammaln(hp_prior.alpha) ...
      + ((hp_posterior.gamma(1,:)-1) ...
         .*(psi(hp_posterior.gamma(1,:))-psi(gamma_sum))) ...
      + ((hp_posterior.gamma(2,:)-hp_prior.alpha) ...
         .*(psi(hp_posterior.gamma(2,:))-psi(gamma_sum)));
  extra_term = sum(E_log_p_of_V);
elseif isequal(opts.algorithm, 'non_dp')
  % The last-but-one term needs the number of samples. 'data' is a struct
  % in this file (its given_data field is read by mk_E_log_q_p_eta and
  % mk_log_lambda), so size(data,2) would be 1; take N from the N*K
  % log_lambda instead.
  E_log_p_of_pi = ...
      sum(gammaln(hp_prior.alpha/K) ...
          - gammaln(hp_posterior.tilde_alpha) ...
          + hp_posterior.Nc.*(psi(hp_posterior.tilde_alpha) ...
                              - psi(sum(hp_posterior.tilde_alpha)))) ...
      + gammaln(hp_prior.alpha + N) - gammaln(hp_prior.alpha);
  extra_term = E_log_p_of_pi;
else
  error('unknown algorithm')
end
if opts.usemex
  % NOTE(review): vdp_sumlogsumexp is defined outside this file; presumably
  % it matches sum(log_sum_exp(log_lambda, 2), 1) -- confirm.
  free_energy = extra_term + sum(fc) - vdp_sumlogsumexp(log_lambda);
else
  free_energy = extra_term + sum(fc) - sum(log_sum_exp(log_lambda, 2), 1);
end



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function log_lambda = mk_log_lambda(data, hp_posterior, hp_prior, opts);
% MK_LOG_LAMBDA  Unnormalized log responsibilities of every sample/component.
%
%   log_lambda = mk_log_lambda(data, hp_posterior, hp_prior, opts)
%
% log_lambda: N*K
% q(z_n=c|x_n) = lambda_n_c / sum_c lambda_n_c
%
%   N is the number of rows of data.given_data.data and K the number of
%   rows of hp_posterior.Mu_bar. The result is the sum of
%     - a per-component mixing-weight term selected by opts.algorithm,
%     - a term for each nominal column of data.given_data.X2 (from
%       hp_posterior.Uhat),
%     - a term for each real-valued column of data.given_data.X1 (from
%       Mu_bar, Mu_tilde, Ksi_alpha, Ksi_beta),
%     - the constant -M1*log(2*pi)/2, M1 = number of real-valued columns.
%   An unrecognized opts.algorithm raises error('unknown algorithm').

% For 'vdp' with MEX enabled the whole computation is delegated.
if isequal(opts.algorithm, 'vdp') && opts.usemex
  log_lambda = vdp_mk_log_lambda(data, hp_posterior, hp_prior, opts);
  return;
end

% Sanity check for 'vdp': gamma(2,end) must equal alpha (tolerance 1e-5).
% On failure the three values are printed on purpose (no semicolons)
% before the error is raised. Note that 'diff' shadows the builtin here.
if isequal(opts.algorithm, 'vdp')
  if abs(hp_posterior.gamma(2,end) - hp_prior.alpha) > 1.0e-5
    hp_posterior.gamma(2,end)
    hp_prior.alpha
    diff = hp_prior.alpha - hp_posterior.gamma(2,end)
    error('must be alpha')
  end
end

% D is not used below; N is reassigned later from size(X1).
[N,D] = size(data.given_data.data);
K = size(hp_posterior.Mu_bar, 1);

%psi_sum = sum(psi( repmat(hp_posterior.eta+1,D,1) - repmat([1:D]',1,K)*0.5 ), 1); % 1*K
log_lambda = zeros(N,K);
% Mixing-weight term: one scalar per component c, written to all N rows
% of column c.
for c=1:K
  if isequal(opts.algorithm, 'vdp')
    E_log_p_of_z_given_other_z_c = ...
        psi(hp_posterior.gamma(1,c)) ...
        - psi(sum(hp_posterior.gamma(:,c),1)) ...
        + sum(psi(hp_posterior.gamma(2,[1:c-1])) - psi(sum(hp_posterior.gamma(:,[1:c-1]),1)), 2);
  elseif isequal(opts.algorithm, 'bj')
    % gamma has no column for the last component, so c == K only gets
    % the sum over the preceding columns.
    if c < K
      E_log_p_of_z_given_other_z_c = ...
          psi(hp_posterior.gamma(1,c)) ...
          - psi(sum(hp_posterior.gamma(:,c),1)) ...
          + sum(psi(hp_posterior.gamma(2,[1:c-1])) - psi(sum(hp_posterior.gamma(:,[1:c-1]),1)), 2);
    else
      E_log_p_of_z_given_other_z_c = sum(psi(hp_posterior.gamma(2,[1:c-1])) ...
                                         - psi(sum(hp_posterior.gamma(:,[1:c-1]),1)), 2);
    end
  elseif isequal(opts.algorithm, 'non_dp')
    % E[log pi] ; pi is the weight of mixtures.
    E_log_p_of_z_given_other_z_c = psi(hp_posterior.tilde_alpha(c)) ...
        - psi(sum(hp_posterior.tilde_alpha));
  elseif isequal(opts.algorithm, 'csb') | isequal(opts.algorithm, 'cdp')
    % NOTE(review): E_log_p_of_z_given_other_z is never assigned in this
    % function, so unless it resolves to something defined elsewhere this
    % branch fails at run time -- confirm whether 'csb'/'cdp' are supported.
    E_log_p_of_z_given_other_z_c = E_log_p_of_z_given_other_z(:,c)';
  else
    error('unknown algorithm')
  end
  log_lambda(:,c) = E_log_p_of_z_given_other_z_c;
end
%%% IVGA-specific
% types and S are read but not used in this function.
types    = data.given_data.types;
X1       = data.given_data.X1; % real-valued
X2       = data.given_data.X2; % nominal
S        = data.given_data.realS;

[N,M1]=size(X1);
M2=size(X2,2);

% Only needed (and only defined) when there are real-valued columns; the
% j=1:M1 loop below is the sole consumer.
if M1,
  Ksi_log   = (psi(hp_posterior.Ksi_alpha)-log(hp_posterior.Ksi_beta));
  S2_x      = hp_posterior.Ksi_beta ./ hp_posterior.Ksi_alpha;
end

% temp is K*N (transposed relative to log_lambda) and accumulates the
% data-dependent terms.
temp=zeros(K,N);
% Nominal columns: X2(:, j) is used as a column index into Uhat{j}.
for j=1:M2,
  temp=temp + psi(hp_posterior.Uhat{j}(:, X2(:, j))) ...
       - repmat(psi(sum(hp_posterior.Uhat{j},2)),1,N);
end
% Real-valued columns: one row of temp (one component) at a time.
for j=1:M1
  for k=1:K
    temp(k, :) = temp(k, :) - 0.5*((hp_posterior.Mu_tilde(k,j)+ ...
				    (X1(:,j)'- hp_posterior.Mu_bar(k,j)).^2) ...
				   ./ S2_x(k, j) ...
				   - Ksi_log(k, j));
  end

end

log_lambda = log_lambda - M1*log(2*pi)/2 + temp';
  

%%% /IVGA-specific

% For 'vdp' the last column gets an additional correction that depends
% only on hp_prior.alpha.
if isequal(opts.algorithm, 'vdp')
  log_lambda(:,end) = log_lambda(:,end) - log(1- exp(psi(hp_prior.alpha) - psi(1+hp_prior.alpha)));
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [q_of_z, data, log_lambda] = mk_q_of_z(data, hp_posterior, hp_prior, opts, log_lambda);
% MK_Q_OF_Z  Turn log_lambda into responsibilities q_of_z (N*K).
%
%   [q_of_z, data, log_lambda] = mk_q_of_z(data, hp_posterior, hp_prior, opts)
%   [q_of_z, data, log_lambda] = mk_q_of_z(data, hp_posterior, hp_prior, opts, log_lambda)
%
%   When called with exactly four arguments, log_lambda is first obtained
%   from mk_log_lambda. data is handed back untouched, and log_lambda is
%   returned as used.
lambda_not_given = (nargin == 4);
if lambda_not_given
  log_lambda = mk_log_lambda(data, hp_posterior, hp_prior, opts);
end

if opts.usemex
  % NOTE(review): vdp_softmax is defined outside this file; presumably it
  % matches the non-MEX branch below -- confirm.
  q_of_z = vdp_softmax(log_lambda);
else
  % Normalize every row in the log domain, then leave the log domain.
  log_q = normalizeln(log_lambda, 1);
  q_of_z = exp(log_q);
end


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function hp_prior = mk_hp_prior(data, opts)
% MK_HP_PRIOR  Build the prior hyperparameter struct from the data.
%
%   hp_prior = mk_hp_prior(data, opts)
%
%   Reads data.given_data.X1 (real-valued) and .X2 (nominal); the fields
%   .types and .realS are read as well but not used here.
%
%   Fields of the result:
%     Mu_mu     : column means of X1
%     S2_mu     : column variances of X1 (divided by n, not n-1)
%     U_p       : 0.5 for every nominal column, or Inf if there are none
%     Alpha_ksi : 0.01 for every real-valued column
%     Beta_ksi  : 0.01 for every real-valued column
%     alpha     : opts.alpha when that field exists, otherwise 1

given_data = data.given_data;
types      = given_data.types;   % not used below; still read, as before
real_cols  = given_data.X1;      % real-valued
nom_cols   = given_data.X2;      % nominal
S          = given_data.realS;   % not used below; still read, as before

[n_samples, n_real] = size(real_cols);
n_nominal = size(nom_cols, 2);

% Empirical moments of the real-valued columns.
col_mean = sum(real_cols, 1)/n_samples;
centered = real_cols - repmat(col_mean, n_samples, 1);
col_var  = sum(centered.^2, 1)/n_samples;

% Priors for the distribution of codebook vectors Mu ~ N(Mu_mu, S2_mu).
hp_prior.Mu_mu = col_mean;
hp_prior.S2_mu = col_var;

if n_nominal
  hp_prior.U_p = repmat(.5, 1, n_nominal);
else
  hp_prior.U_p = Inf;
end

% Priors for the data variance Ksi ~ Gamma(Alpha_ksi, Beta_ksi).
hp_prior.Alpha_ksi = repmat(.01, 1, n_real);
hp_prior.Beta_ksi  = repmat(.01, 1, n_real);

% Default for alpha, overridden by the caller's option when present.
if isfield(opts, 'alpha')
  hp_prior.alpha = opts.alpha;
else
  hp_prior.alpha = 1;
end


function m=normalize(m,dim)
% NORMALIZE  Scale m so that it sums to one along dimension 'dim'.
%
%   m = normalize(m, dim)
%
%   Example, for m of size i by j by k by ...:
%     sum(normalize(m,2), 2) -> ones(i, 1, k, ...)
%
%   Slices whose sum is zero are divided by zero, exactly as written.

% Divide every element by the sum of its slice; bsxfun expands the
% singleton dimension of the sums to the size of m.
m = bsxfun(@rdivide, m, sum(m, dim));



function val = log_sum_exp(x,dim,y);
% LOG_SUM_EXP  Numerically stable log(sum(exp(x), dim)).
%
%   val = log_sum_exp(x, dim)
%
%   x   : numeric array
%   dim : dimension to reduce
%   y   : accepted for compatibility with existing callers; ignored
%
%   val has the size of x with dimension 'dim' reduced to 1.
%
%   The maximum along 'dim' is subtracted before exponentiating so that
%   large entries do not overflow. Where that maximum is not finite (for
%   example a slice that is entirely -Inf) nothing is subtracted:
%   subtracting it would evaluate Inf - Inf = NaN, whereas the correct
%   result for such a slice is -Inf (or +Inf when the maximum is +Inf).
x_max = max(x, [], dim);

shift = x_max;
shift(~isfinite(shift)) = 0;

reps = ones(1, ndims(x));
reps(dim) = size(x, dim);

val = shift + log(sum(exp(x - repmat(shift, reps)), dim));




function M =normalizeln(M ,dimension);
% NORMALIZELN  Log-domain normalization of a matrix.
%
%   M = normalizeln(M, dimension)
%
%   Forwards both arguments to lpt2lcpt and returns its result unchanged.
normalized = lpt2lcpt(M, dimension);
M = normalized;


function lcpt=lpt2lcpt(lpt,dimension);
% LPT2LCPT  Normalize a 2-D table of log values in the log domain.
%
%   lcpt = lpt2lcpt(lpt, dimension)
%
%   dimension is 1 or 2. With dimension = 1 the log-sum-exp of every row is
%   subtracted from that row; with dimension = 2 the same is done for every
%   column. The result has the size of lpt.

% For dimension in {1, 2} this is the other one of the two.
other_dim = 3 - dimension;
order = [dimension, other_dim];

% Bring the table into "one distribution per row" orientation
% (identity for dimension = 1, transpose for dimension = 2).
by_rows = permute(lpt, order);

row_norm = log_sum_exp(by_rows, 2); % Mx1
by_rows = bsxfun(@minus, by_rows, row_norm);

% The same permutation restores the original orientation.
lcpt = permute(by_rows, order);

