/*

  This file is part of NetResponse algorithm. (C) Leo Lahti 2008-2010.

  This file is based on the Agglomerative Independent Variable Group
  Analysis package, Copyright (C) 2001-2007 Esa Alhoniemi, Antti
  Honkela, Krista Lagus, Jeremias Seppa, Harri Valpola, and Paul
  Wagner
  
  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2, or (at your option)
  any later version.
 
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License (included in file License.txt in the
  program package) for more details.
*/

#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <math.h>
#include "mex.h"
#include "matrix.h"
 
/* Constants used by digamma(): arguments at or below DIGAMMA_S take the
   small-argument approximation DIGAMMA_D1 - 1/x, smaller arguments than
   DIGAMMA_C are shifted upward with the recurrence, and S3..S5 are the
   coefficients of the asymptotic series in 1/x^2.  Every expansion is
   fully parenthesized so the macros are safe in any expression context
   (e.g. `y / DIGAMMA_S3` previously expanded to `y / 1.0 / 12`). */
#define DIGAMMA_S 1e-5
#define DIGAMMA_C 8.5
#define DIGAMMA_S3 (1.0/12)
#define DIGAMMA_S4 (1.0/120)
#define DIGAMMA_S5 (1.0/252)
#define DIGAMMA_D1 (-0.57721566490153286)   /* minus the Euler-Mascheroni constant */

#define PI 3.14159265358979323846
#define POW2(x) ((x) * (x))

/************************************************************/
/* digamma function (by Antti Honkela)                      */

/*
 * digamma -- psi(x), the derivative of log Gamma(x).
 *
 * Non-positive arguments yield mxGetNaN().  For 0 < x <= DIGAMMA_S the
 * small-argument form DIGAMMA_D1 - 1/x is returned directly.  Otherwise x
 * is pushed up to at least DIGAMMA_C with psi(x) = psi(x + 1) - 1/x, and
 * the asymptotic expansion
 *   log(z) - 1/(2z) - s*(S3 - s*(S4 - s*S5)),   s = 1/z^2
 * is evaluated at the shifted argument z.
 */
double 
digamma(double x) {
  double acc = 0.0;
  double z = x;
  double inv, inv_sq;

  if (z <= 0)
    return mxGetNaN();

  if (z <= DIGAMMA_S)
    return DIGAMMA_D1 - 1.0 / z;

  /* Recurrence shift: each step removes one 1/z term. */
  for (; z < DIGAMMA_C; z += 1.0)
    acc -= 1.0 / z;

  inv = 1.0 / z;
  acc += log(z) - .5 * inv;
  inv_sq = POW2(inv);
  acc -= inv_sq * (DIGAMMA_S3 - inv_sq * (DIGAMMA_S4 - inv_sq*DIGAMMA_S5));
  return acc;
}


/************************************************************/

/*
 * compute_variance -- for every centroid c and real-valued dimension d
 * (Ksi_alpha / Ksi_beta are ncentroids x dim1, stored column-major):
 *   S2_x[c][d]    = Ksi_beta / Ksi_alpha, floored at 1e-100
 *   Ksi_log[c][d] = psi(Ksi_alpha) - log(Ksi_beta)
 */
void
compute_variance(int ncentroids, int dim1, double *Ksi_alpha, 
                 double *Ksi_beta, double **S2_x, double **Ksi_log) {
  int c, d;

  for (c = 0; c < ncentroids; c++) {
    double *var_row = S2_x[c];
    double *log_row = Ksi_log[c];

    for (d = 0; d < dim1; d++) {
      const int idx = d * ncentroids + c;
      const double shape = Ksi_alpha[idx];
      const double rate = Ksi_beta[idx];
      const double var = rate / shape;

      /* Keep the variance strictly positive; NaN passes through unchanged. */
      var_row[d] = (var < 1e-100) ? 1e-100 : var;
      log_row[d] = digamma(shape) - log(rate);
    }
  }
}

/************************************************************/
/*
 * compute_tempmat -- add the data-dependent terms to log_lambda.
 *
 * Real-valued part: for each centroid i and sample j,
 *   Temp[i][j] = 0.5 * sum_k [ (Mu_tilde + (x_jk - Mu_bar)^2 + implicit_noisevar)
 *                              / S2_x[i][k] - Ksi_log[i][k] ]
 * where Mu_bar / Mu_tilde are ncentroids x dim1 and data1 is datalen x dim1,
 * both column-major.
 *
 * Nominal part: for each nominal column j2 the value
 *   psi(sum_k U_hat_table[j2][i][k]) - psi(U_hat_table[j2][i][data2_int[j2][t]])
 * is added to Temp[i][t].
 *
 * Finally log_lambda[i*datalen + j] += -dim1*log(2*pi)/2 - Temp[i][j].
 *
 * Side effect: every U_hat_table[j][i][k] is overwritten with its digamma.
 * A category index outside [0, (int)Ns[j]) poisons the affected Temp entry
 * with NaN instead of reading outside the table.
 */
void
compute_tempmat(long datalen, int dim1, int dim2, int ncentroids,
		double **Temp, double *data1, int **data2_int,
		double *Mu_bar, double *Mu_tilde, double **S2_x,
                double **Ksi_log, double ***U_hat_table, double *Ns,
		double implicit_noisevar, double *log_lambda) {
  /* log(2*pi) written out: M_PI is not part of ISO C and is missing under
     strict -std=c* compilation. */
  const double log_2pi = 1.8378770664093454835606594728112;
  const double gauss_const = -dim1 * log_2pi / 2;
  int i, k;
  long j, t;

  /* Gaussian (real-valued) contribution. */
  for (i = 0; i < ncentroids; i++) {
    for (j = 0; j < datalen; j++) {
      double acc = 0.0;
      for (k = 0; k < dim1; k++) {
        const long ind = (long)k * ncentroids + i;
        const double diff = data1[(long)k * datalen + j] - Mu_bar[ind];
        acc += ((Mu_tilde[ind] + POW2(diff) + implicit_noisevar) /
                S2_x[i][k]) - Ksi_log[i][k];
      }
      Temp[i][j] = acc / 2.0;
    }
  }

  /* Multinomial (nominal) contribution. */
  for (j = 0; j < dim2; j++) {
    const int nlevels = (int)(Ns[j]);
    for (i = 0; i < ncentroids; i++) {
      double *u = U_hat_table[j][i];
      double total = 0.0, psi_total;

      for (k = 0; k < nlevels; k++) {
        total += u[k];
        u[k] = digamma(u[k]);   /* table is scratch: reused as psi(u) below */
      }
      psi_total = digamma(total);

      for (t = 0; t < datalen; t++) {
        const int level = data2_int[j][t];
        if (level < 0 || level >= nlevels)
          Temp[i][t] = NAN;     /* invalid/missing category: no valid table entry */
        else
          Temp[i][t] += psi_total - u[level];
      }
    }
  }

  for (i = 0; i < ncentroids; i++) {
    for (j = 0; j < datalen; j++) {
      log_lambda[(long)i * datalen + j] += gauss_const - Temp[i][j];
    }
  }
}


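/*** Reference MATLAB implementation; log_p_of_z_given_other_z_c() below
     computes the psi(hp_posterior.gamma) terms of this loop.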
function log_lambda = mk_log_lambda(data, hp_posterior, hp_prior, opts);
% log_lambda: N*K
% q(z_n=c|x_n) = lambda_n_c / sum_c lambda_n_c

[N,D] = size(data.given_data.data);
K = size(hp_posterior.Mu_bar, 1);

log_lambda = zeros(N,K);
for c=1:K
  E_log_p_of_z_given_other_z_c = ...
      psi(hp_posterior.gamma(1,c)) ...
      - psi(sum(hp_posterior.gamma(:,c),1)) ...
      + sum(psi(hp_posterior.gamma(2,[1:c-1])) - psi(sum(hp_posterior.gamma(:,[1:c-1]),1)), 2);
  log_lambda(:,c) = E_log_p_of_z_given_other_z_c;
end
***/
/*
 * log_p_of_z_given_other_z_c -- initialise every entry of column c of
 * log_lambda (column c starts at log_lambda + c*datalen) with
 *   psi(g1_c) - psi(g1_c + g2_c) + sum_{i<c} [ psi(g2_i) - psi(g1_i + g2_i) ]
 * where g1_c = post_gamma[2c] and g2_c = post_gamma[2c+1] (gamma is
 * 2 x ncentroids, column-major).
 *
 * The sum over i < c is carried forward from one column to the next, so
 * digamma is evaluated O(ncentroids) times rather than O(ncentroids^2).
 */
void log_p_of_z_given_other_z_c(int datalen, long ncentroids,
				double *post_gamma, double *log_lambda)
{
  long c;
  int i;
  double prefix = 0.0;   /* sum_{i<c} psi(g2_i) - psi(g1_i + g2_i) */

  for (c = 0; c < ncentroids; c++) {
    const double g1 = post_gamma[2*c];
    const double g2 = post_gamma[2*c+1];
    const double psi_sum = digamma(g1 + g2);
    const double E_log_p = digamma(g1) - psi_sum + prefix;

    for (i = 0; i < datalen; i++) {
      log_lambda[c*datalen+i] = E_log_p;
    }
    prefix += digamma(g2) - psi_sum;
  }
}



void fix_lambda(int ncentroids, long datalen, double *prior_alpha, double *log_lambda)
{
  register int i;
  double correction;

  correction = log(1 - exp(digamma(*prior_alpha) - digamma(1 + *prior_alpha)));
  for (i=0; i<datalen; i++) {
    log_lambda[(ncentroids-1)*datalen + i] -= correction;
  }

  return;
}



/*
 * allocate_memory -- allocate the scratch buffers used by vdp_mk_log_lambda():
 *   *Temp          ncentroids rows of datalen doubles          (always)
 *   *S2_x,*Ksi_log ncentroids rows of dim1 doubles             (dim1 != 0)
 *   *data2_int     dim2 rows of datalen ints                   (dim2 != 0)
 *   *U_hat_table   dim2 x ncentroids rows of (int)Ns[j] doubles (dim2 != 0)
 * Outputs that are not allocated are set to NULL instead of being left
 * indeterminate.  Release everything with free_memory().
 * NOTE(review): malloc failures are not detected; this void interface has
 * no channel to report them.
 */
void
allocate_memory(long datalen,int ncentroids,int dim1,int dim2,double ***S2_x,
                double ***Ksi_log, double ***Temp,
		double ****U_hat_table,int ***data2_int,double *Ns) {
  int i, j;

  *S2_x = NULL;
  *Ksi_log = NULL;
  *U_hat_table = NULL;
  *data2_int = NULL;

  *Temp = malloc((size_t)ncentroids * sizeof **Temp);
  if (dim1) {
    *S2_x    = malloc((size_t)ncentroids * sizeof **S2_x);
    *Ksi_log = malloc((size_t)ncentroids * sizeof **Ksi_log);
  }
  if (dim2) {
    /* Elements are double** (previously sized as double*). */
    *U_hat_table = malloc((size_t)dim2 * sizeof **U_hat_table);
    *data2_int   = malloc((size_t)dim2 * sizeof **data2_int);
  }
  for (i = 0; i < ncentroids; i++) {
    (*Temp)[i] = malloc((size_t)datalen * sizeof ***Temp);
    if (dim1) {
      (*S2_x)[i]    = malloc((size_t)dim1 * sizeof ***S2_x);
      (*Ksi_log)[i] = malloc((size_t)dim1 * sizeof ***Ksi_log);
    }
  }
  for (j = 0; j < dim2; j++) {
    const int nlevels = (int)(Ns[j]);

    (*data2_int)[j]   = malloc((size_t)datalen * sizeof ***data2_int);
    (*U_hat_table)[j] = malloc((size_t)ncentroids * sizeof ***U_hat_table);
    for (i = 0; i < ncentroids; i++) {
      (*U_hat_table)[j][i] = malloc((size_t)nlevels * sizeof ****U_hat_table);
    }
  }
}
 
/************************************************************/

/*
 * free_memory -- release the buffers created by allocate_memory() and reset
 * the caller's pointers to NULL so they cannot be used or freed again.
 * *S2_x / *Ksi_log are only touched when dim1 != 0, and *U_hat_table /
 * *data2_int only when dim2 != 0, mirroring the conditional allocation.
 * W is unused and kept only for interface compatibility.
 */
void
free_memory(int ncentroids,int dim1, int dim2, double ***Temp, double ***W, 
            double ***S2_x, double ***Ksi_log,
		double ****U_hat_table, int ***data2_int) {
  int i, j;

  (void)W;

  for (j = 0; j < dim2; j++) {
    for (i = 0; i < ncentroids; i++) {
      free((*U_hat_table)[j][i]);
    }
    free((*data2_int)[j]);
    free((*U_hat_table)[j]);
  }

  for (i = 0; i < ncentroids; i++) {
    free((*Temp)[i]);
    if (dim1) {
      free((*S2_x)[i]);
      free((*Ksi_log)[i]);
    }
  }

  free(*Temp);
  *Temp = NULL;
  if (dim1) {
    free(*S2_x);
    free(*Ksi_log);
    *S2_x = NULL;
    *Ksi_log = NULL;
  }
  if (dim2) {
    free(*U_hat_table);
    free(*data2_int);
    *U_hat_table = NULL;
    *data2_int = NULL;
  }
}



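/*** Reference MATLAB implementation of what vdp_mk_log_lambda() below
     computes. The C version differs in two ways: implicit_noisevar is added
     to every squared deviation in the Gaussian term, and the last-column
     correction (fix_lambda) is always applied, with no opts.algorithm check.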
function log_lambda = mk_log_lambda(data, hp_posterior, hp_prior, opts);
% log_lambda: N*K
% q(z_n=c|x_n) = lambda_n_c / sum_c lambda_n_c

[N,D] = size(data.given_data.data);
K = size(hp_posterior.Mu_bar, 1);

log_lambda = zeros(N,K);
for c=1:K
  E_log_p_of_z_given_other_z_c = ...
      psi(hp_posterior.gamma(1,c)) ...
      - psi(sum(hp_posterior.gamma(:,c),1)) ...
      + sum(psi(hp_posterior.gamma(2,[1:c-1])) - psi(sum(hp_posterior.gamma(:,[1:c-1]),1)), 2);
  log_lambda(:,c) = E_log_p_of_z_given_other_z_c;
end

%%% IVGA-specific
types    = data.given_data.types;
X1       = data.given_data.data(:, (types == 1)); % real-valued
X2       = data.given_data.data(:, (types == 2)); % nominal
S        = data.given_data.S(types == 2);

[N,M1]=size(X1);
M2=size(X2,2);

Ksi_log   = (psi(hp_posterior.Ksi_alpha)-log(hp_posterior.Ksi_beta));
S2_x      = hp_posterior.Ksi_beta ./ hp_posterior.Ksi_alpha;

temp=zeros(K,N);
for j=1:M2,
  temp=temp + psi(hp_posterior.Uhat{j}(:, X2(:, j))) ...
       - repmat(psi(sum(hp_posterior.Uhat{j},2)),1,N);
end
for j=1:M1
  temp = temp - 0.5 * ...
	 gminus(...
	     grdivide(...
		 gplus(hp_posterior.Mu_tilde(:,j), ...
		       gminus(X1(:,j)', hp_posterior.Mu_bar(:,j)).^2), ...
		 S2_x(:,j)), ...
	     Ksi_log(:,j));
end

log_lambda = log_lambda - M1*log(2*pi)/2 + temp';
  
if isequal(opts.algorithm, 'vdp')
  log_lambda(:,end) = log_lambda(:,end) - log(1- exp(psi(hp_prior.alpha) - psi(1+hp_prior.alpha)));
end
***/


/*
 * vdp_mk_log_lambda -- fill log_lambda (datalen x ncentroids; column c starts
 * at log_lambda + c*datalen) with unnormalised log responsibilities.
 *
 * data1 holds dim1 real-valued columns and data2 holds dim2 nominal columns
 * of 1-based category codes, both column-major with datalen rows.  U_hat is
 * a cell array whose j-th element is read as an ncentroids x Ns[j]
 * column-major matrix.  Mu_mu, S2_mu, Alpha_ksi, Beta_ksi and U_p are
 * accepted for interface compatibility but not used here.
 */
void
vdp_mk_log_lambda(double *Mu_mu, double *S2_mu, double *Mu_bar, double *Mu_tilde, 
		  double *Alpha_ksi, double *Beta_ksi, 
		  double *Ksi_alpha, double *Ksi_beta, 
		  double *post_gamma, double *log_lambda, double *prior_alpha,
		  double *U_p, mxArray *U_hat,
		  long datalen, int dim1, int dim2, double *data1, double *data2, 
		  double *Ns, int ncentroids, 
		  double implicit_noisevar) {
  long i, j, t;
  int k, nlevels;
  double **W = NULL, **Temp = NULL, **S2_x = NULL, **Ksi_log = NULL;
  double ***U_hat_table = NULL;
  int **data2_int = NULL;

  (void)Mu_mu; (void)S2_mu; (void)Alpha_ksi; (void)Beta_ksi; (void)U_p;

  allocate_memory(datalen, ncentroids, dim1, dim2, &S2_x, &Ksi_log, &Temp,
		  &U_hat_table, &data2_int, Ns);

  for (j = 0; j < dim2; j++) {
    const mxArray *U_hat_j_mxarray = mxGetCell(U_hat, (int)j);
    const double *U_hat_j =
      (U_hat_j_mxarray != NULL) ? mxGetPr(U_hat_j_mxarray) : NULL;

    /* Same bound allocate_memory() used; comparing k with the double Ns[j]
       directly ran one element past the buffer for non-integer Ns[j]. */
    nlevels = (int)(Ns[j]);

    for (t = 0; t < datalen; t++) {
      const double code = data2[j*datalen + t];
      /* 1-based code -> 0-based index.  NaN or out-of-range codes become -1
         rather than going through an undefined double-to-int conversion. */
      data2_int[j][t] = (code >= 1.0 && code < nlevels + 1.0) ? (int)code - 1 : -1;
    }

    /* Transpose Uhat{j} into per-centroid rows; a missing cell element gives
       NaN entries instead of a NULL dereference. */
    for (i = 0; i < ncentroids; i++) {
      for (k = 0; k < nlevels; k++) {
	U_hat_table[j][i][k] = (U_hat_j != NULL) ? U_hat_j[k*ncentroids + i] : NAN;
      }
    }
  }
  
  if (dim1) 
    compute_variance(ncentroids, dim1, Ksi_alpha, Ksi_beta, S2_x, Ksi_log);

  log_p_of_z_given_other_z_c(datalen, ncentroids, post_gamma, log_lambda);

  compute_tempmat(datalen, dim1, dim2, ncentroids, Temp, data1, data2_int,
		  Mu_bar, Mu_tilde, S2_x, Ksi_log, U_hat_table, Ns,
		  implicit_noisevar, log_lambda);
    
  fix_lambda(ncentroids, datalen, prior_alpha, log_lambda);
    
  free_memory(ncentroids, dim1, dim2, &Temp, &W, &S2_x, &Ksi_log,
	      &U_hat_table, &data2_int);
}





/************************************************************/
/* bridge function                                          */

/* Return field `name` of element 0 of struct `s`; prints an error and
   returns NULL when `s` is NULL or has no such field. */
static mxArray *
get_field_or_report(const mxArray *s, const char *name) {
  mxArray *f = (s != NULL) ? mxGetField(s, 0, name) : NULL;

  if (f == NULL)
    printf("Error: %s.\n", name);
  return f;
}

/* Store mxGetPr() of field `name` of `s` in *out.  Returns 0 (after
   reporting) when the field is missing, 1 otherwise. */
static int
get_pr_or_report(const mxArray *s, const char *name, double **out) {
  mxArray *f = get_field_or_report(s, name);

  if (f == NULL)
    return 0;
  *out = mxGetPr(f);
  return 1;
}

/*
 * MEX gateway.  prhs[0]: struct with field given_data (X1, X2, realS);
 * prhs[1]: posterior hyperparameters (Mu_bar, Mu_tilde, Ksi_alpha, Ksi_beta,
 * Uhat, gamma); prhs[2]: prior hyperparameters (Mu_mu, S2_mu, Alpha_ksi,
 * Beta_ksi, U_p, alpha); prhs[3]: options (implicit_noisevar).
 * plhs[0]: datalen x ncentroids log_lambda matrix.
 *
 * A missing argument or field is reported and the function returns without
 * assigning plhs[0], instead of continuing into a NULL dereference.
 */
void
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  mxArray *given_data, *X1, *X2, *field, *U_hat = NULL;
  /* Initialised so nothing indeterminate is passed on when dim1/dim2 == 0. */
  double *Mu_mu = NULL, *S2_mu = NULL, *Mu_bar = NULL, *Mu_tilde = NULL,
    *Alpha_ksi = NULL, *Beta_ksi = NULL, *Ksi_alpha = NULL, *Ksi_beta = NULL,
    *U_p = NULL, *prior_alpha = NULL, *post_gamma = NULL, *log_lambda;
  double *data1, *data2, *Ns = NULL;
  double implicit_noisevar;
  long datalen;
  int dim1, dim2, ncentroids;

  (void)nlhs;

  if (nrhs < 4) {
    printf("Error: expected 4 input arguments.\n");
    return;
  }

  /******************** input variables ********************/

  /* data */
  if ((given_data = get_field_or_report(prhs[0], "given_data")) == NULL ||
      (X1 = get_field_or_report(given_data, "X1")) == NULL ||
      (X2 = get_field_or_report(given_data, "X2")) == NULL)
    return;
  data1   = mxGetPr(X1);
  dim1    = (int)mxGetN(X1);
  datalen = (long)mxGetM(X1);
  data2   = mxGetPr(X2);
  dim2    = (int)mxGetN(X2);

  /* realS (number of levels per nominal column) is only read when there
     are nominal columns. */
  if (dim2 && !get_pr_or_report(given_data, "realS", &Ns))
    return;

  if ((field = get_field_or_report(prhs[3], "implicit_noisevar")) == NULL)
    return;
  implicit_noisevar = mxGetScalar(field);

  /* model parameters */
  if (dim1 &&
      !(get_pr_or_report(prhs[2], "Mu_mu", &Mu_mu) &&
        get_pr_or_report(prhs[2], "S2_mu", &S2_mu) &&
        get_pr_or_report(prhs[1], "Mu_bar", &Mu_bar) &&
        get_pr_or_report(prhs[1], "Mu_tilde", &Mu_tilde) &&
        get_pr_or_report(prhs[2], "Alpha_ksi", &Alpha_ksi) &&
        get_pr_or_report(prhs[2], "Beta_ksi", &Beta_ksi) &&
        get_pr_or_report(prhs[1], "Ksi_alpha", &Ksi_alpha) &&
        get_pr_or_report(prhs[1], "Ksi_beta", &Ksi_beta)))
    return;
  if (dim2 &&
      (!get_pr_or_report(prhs[2], "U_p", &U_p) ||
       (U_hat = get_field_or_report(prhs[1], "Uhat")) == NULL))
    return;
  if (!get_pr_or_report(prhs[2], "alpha", &prior_alpha) ||
      !get_pr_or_report(prhs[1], "gamma", &post_gamma))
    return;

  if ((field = get_field_or_report(prhs[1], "Mu_bar")) == NULL)
    return;
  ncentroids = (int)mxGetM(field);

  /******************** output variables ********************/

  plhs[0]    = mxCreateDoubleMatrix(datalen, ncentroids, mxREAL);
  log_lambda = mxGetPr(plhs[0]);

  vdp_mk_log_lambda(Mu_mu, S2_mu, Mu_bar, Mu_tilde, 
		    Alpha_ksi, Beta_ksi, Ksi_alpha, Ksi_beta, 
		    post_gamma, log_lambda, prior_alpha,
		    U_p, U_hat,
		    datalen, dim1, dim2, data1, data2, 
		    Ns, ncentroids, implicit_noisevar);
}
 
/************************************************************/




