/*

  This file is part of NetResponse algorithm. (C) Leo Lahti 2008-2010.

  This file is based on the Agglomerative Independent Variable Group
  Analysis package, Copyright (C) 2001-2007 Esa Alhoniemi, Antti
  Honkela, Krista Lagus, Jeremias Seppa, Harri Valpola, and Paul
  Wagner
 
  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2, or (at your option)
  any later version.
 
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License (included in file License.txt in the
  program package) for more details.
*/

#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <math.h>
#include "mex.h"
#include "matrix.h"
 
/* Constants used by digamma():
 *   DIGAMMA_S   arguments at or below this use the small-argument formula
 *   DIGAMMA_C   arguments are shifted upwards until they reach this value
 *   DIGAMMA_S3, DIGAMMA_S4, DIGAMMA_S5
 *               coefficients of the series applied to the shifted argument
 *   DIGAMMA_D1  constant term of the small-argument formula
 * Every expression-valued macro is parenthesized so that it expands to a
 * single operand regardless of the operators surrounding it. */
#define DIGAMMA_S 1e-5
#define DIGAMMA_C 8.5
#define DIGAMMA_S3 (1.0/12)
#define DIGAMMA_S4 (1.0/120)
#define DIGAMMA_S5 (1.0/252)
#define DIGAMMA_D1 (-0.5772156649)

#define PI 3.1415926535
/* Square of x.  The argument is evaluated twice, so it must not have
   side effects. */
#define POW2(x) ((x) * (x))

/************************************************************/
/* digamma function (by Antti Honkela)                      */

/*
 * Digamma function for positive arguments.
 *
 * Returns NaN (via mxGetNaN) when x <= 0.  Arguments no larger than
 * DIGAMMA_S use the closed form DIGAMMA_D1 - 1/x.  Larger arguments are
 * stepped upwards by one, subtracting 1/x at each step, until they reach
 * DIGAMMA_C; a logarithm plus a short series in 1/x^2 is then added.
 */
double 
digamma(double x) {
  double acc = 0.0;
  double arg = x;
  double inv, inv_sq;

  if (arg <= 0)
    return mxGetNaN();

  if (arg <= DIGAMMA_S)
    return DIGAMMA_D1 - 1.0 / arg;

  /* Shift the argument up until the series below is applied. */
  for (; arg < DIGAMMA_C; arg += 1.0)
    acc -= 1.0 / arg;

  inv     = 1.0 / arg;
  acc    += log(arg) - .5 * inv;
  inv_sq  = POW2(inv);
  acc    -= inv_sq * (DIGAMMA_S3 - inv_sq * (DIGAMMA_S4 - inv_sq*DIGAMMA_S5));

  return acc;
}



/*
 * Accumulate per-cluster responsibility totals.
 *
 * q_of_z is read as ncentroids consecutive runs of datalen values (one run
 * per cluster).  For every cluster i the run is summed into true_Nc[i] and
 * copied to Nc[i].  Afterwards the last cluster is emptied: Nc[ncentroids-1]
 * is set to 0 and the last run of q_of_z is overwritten with zeros, while
 * true_Nc keeps the unmodified total.
 *
 * Parameters:
 *   ncentroids  number of clusters; nothing is touched when it is <= 0
 *   datalen     number of samples per cluster run
 *   true_Nc     output, ncentroids entries
 *   q_of_z      input/output, ncentroids * datalen entries
 *   Nc          output, ncentroids entries
 */
void compute_nc(int ncentroids, long datalen, double *true_Nc,
		double *q_of_z, double *Nc)
{
  int i;
  long j;
  double *row;

  /* Without any cluster the "last cluster" below would be index -1 and
     the writes would land in front of the arrays. */
  if (ncentroids <= 0)
    return;

  for (i = 0; i < ncentroids; i++) {
    row = q_of_z + (long)i * datalen;
    true_Nc[i] = 0.0;
    for (j = 0; j < datalen; j++) {
      true_Nc[i] += row[j];
    }
    Nc[i] = true_Nc[i];
  }

  /* Empty the last cluster. */
  Nc[ncentroids - 1] = 0.0;
  row = q_of_z + (long)(ncentroids - 1) * datalen;
  for (j = 0; j < datalen; j++) {
    row[j] = 0;
  }

  return;
}


/*
 * Recompute the per-cluster posterior parameters.
 *
 * Real-valued part (dim1 variables, data1 read as dim1 runs of datalen
 * values): for every variable/cluster pair, slot = variable * ncentroids +
 * cluster receives
 *   - Mu_bar / Mu_tilde computed from the prior ratio
 *     Beta_ksi / Alpha_ksi,
 *   - Ksi_alpha = Alpha_ksi + Nc / 2,
 *   - Ksi_beta  = Beta_ksi + half the responsibility-weighted sum of
 *     (Mu_tilde + squared deviation from Mu_bar + implicit_noisevar),
 *   - Mu_bar / Mu_tilde recomputed with the ratio Ksi_beta / Ksi_alpha.
 *
 * Categorical part (dim2 variables): U_hat_table[v][c][k] starts at U_p[v]
 * for each of the (int)Ns[v] categories, then q_of_z of every sample is
 * added to the entry selected by data2_int[v][sample].
 * NOTE(review): data2_int values are used as indices without a range
 * check here; callers must guarantee 0 <= value < (int)Ns[v].
 */
void
update_centroids(long datalen, int ncentroids, int dim1, int dim2,
		 double *data1, int **data2_int,
                 double *Nc, double *q_of_z, double *Mu_mu, double *S2_mu,
                 double *Mu_bar, double *Mu_tilde,
		 double *Ksi_alpha, double *Ksi_beta, double *Alpha_ksi,
		 double *Beta_ksi, double implicit_noisevar,
		 double *U_p, double ***U_hat_table, double *Ns) {
  int var, comp, cat;
  long slot, feat, n;
  double weighted_sum, denom, spread, ratio_prior, ratio_post;
  double *resp, *column, *counts;

  /* ---- real-valued variables ---- */
  for (var = 0; var < dim1; var++) {
    column      = data1 + var * datalen;
    ratio_prior = Beta_ksi[var] / Alpha_ksi[var];

    for (comp = 0; comp < ncentroids; comp++) {
      resp = q_of_z + comp * datalen;
      slot = var * ncentroids + comp;

      weighted_sum = 0.0;
      for (n = 0; n < datalen; n++)
        weighted_sum += resp[n] * column[n];

      /* first pass: prior ratio */
      denom          = ratio_prior + S2_mu[var] * Nc[comp];
      Mu_bar[slot]   = ((ratio_prior * Mu_mu[var]) + (S2_mu[var] * weighted_sum))/denom;
      Mu_tilde[slot] = (ratio_prior * S2_mu[var])/denom;

      Ksi_alpha[slot] = Alpha_ksi[var] + 0.5 * Nc[comp];

      spread = 0.0;
      for (n = 0; n < datalen; n++)
        spread += resp[n] *
          (Mu_tilde[slot] + POW2(column[n] - Mu_bar[slot]) +
           implicit_noisevar);
      Ksi_beta[slot] = Beta_ksi[var] + 0.5 * spread;

      /* second pass: ratio of the values just computed */
      ratio_post     = Ksi_beta[slot] / Ksi_alpha[slot];
      denom          = ratio_post + S2_mu[var] * Nc[comp];
      Mu_bar[slot]   = ((ratio_post * Mu_mu[var]) + (S2_mu[var] * weighted_sum))/denom;
      Mu_tilde[slot] = (ratio_post * S2_mu[var])/denom;
    }
  }

  /* ---- categorical variables ---- */
  for (feat = 0; feat < dim2; feat++) {
    for (comp = 0; comp < ncentroids; comp++) {
      counts = U_hat_table[feat][comp];
      resp   = q_of_z + comp * datalen;

      for (cat = 0; cat < (int)(Ns[feat]); cat++)
        counts[cat] = U_p[feat];

      for (n = 0; n < datalen; n++)
        counts[data2_int[feat][n]] += resp[n];
    }
  }
  return;
}


/*
 * Fill post_gamma with one pair of values per cluster.
 *
 * For cluster i:
 *   post_gamma[2*i]     = 1 + true_Nc[i]
 *   post_gamma[2*i + 1] = *prior_alpha + (sum of all true_Nc)
 *                                      - (sum of true_Nc[0..i])
 */
void update_gamma(int ncentroids, double *true_Nc, double *prior_alpha,
		  double *post_gamma)
{
  int c;
  double total = 0.0;
  double running = 0.0;
  double *pair = post_gamma;

  for (c = 0; c < ncentroids; c++)
    total += true_Nc[c];

  for (c = 0; c < ncentroids; c++, pair += 2) {
    running += true_Nc[c];
    pair[0] = 1 + true_Nc[c];
    pair[1] = *prior_alpha + total - running;
  }

  return;
}



/*
 * Allocate the scratch tables used for the categorical variables.
 *
 * On success:
 *   (*data2_int)[j]      holds datalen ints            (j < dim2)
 *   (*U_hat_table)[j][i] holds (int)Ns[j] doubles      (j < dim2,
 *                                                       i < ncentroids)
 * Both outputs are set to NULL when dim2 <= 0 (nothing to allocate) or
 * when any allocation fails; in the failure case everything allocated so
 * far is released first, so nothing leaks.  Callers with dim2 > 0 should
 * treat NULL outputs as an out-of-memory condition.
 *
 * Zero or negative element counts are rounded up to one element so that
 * a NULL result always means failure rather than an empty request.
 */
void
allocate_memory(long datalen,int ncentroids,int dim2,
                double ****U_hat_table,int ***data2_int,double *Ns) {
  double ***table = NULL;
  int    **labels = NULL;
  size_t nsamples, nclusters, ncats;
  int i, j, count;

  /* Give the caller well-defined values on every path. */
  *U_hat_table = NULL;
  *data2_int   = NULL;

  if (dim2 <= 0)
    return;

  nsamples  = (datalen > 0)    ? (size_t)datalen    : 1;
  nclusters = (ncentroids > 0) ? (size_t)ncentroids : 1;

  /* calloc leaves unfilled slots NULL, which keeps the cleanup simple. */
  table  = (double ***)calloc((size_t)dim2, sizeof *table);
  labels = (int **)calloc((size_t)dim2, sizeof *labels);
  if (table == NULL || labels == NULL)
    goto fail;

  for (j = 0; j < dim2; j++) {
    count = (int)(Ns[j]);
    ncats = (count > 0) ? (size_t)count : 1;

    labels[j] = (int *)malloc(nsamples * sizeof *labels[j]);
    table[j]  = (double **)calloc(nclusters, sizeof *table[j]);
    if (labels[j] == NULL || table[j] == NULL)
      goto fail;

    for (i = 0; i < ncentroids; i++) {
      table[j][i] = (double *)malloc(ncats * sizeof *table[j][i]);
      if (table[j][i] == NULL)
        goto fail;
    }
  }

  *U_hat_table = table;
  *data2_int   = labels;
  return;

 fail:
  if (table != NULL) {
    for (j = 0; j < dim2; j++) {
      if (table[j] != NULL) {
        for (i = 0; i < ncentroids; i++)
          free(table[j][i]);
        free(table[j]);
      }
    }
    free(table);
  }
  if (labels != NULL) {
    for (j = 0; j < dim2; j++)
      free(labels[j]);
    free(labels);
  }
  return;
}
 
/************************************************************/

/*
 * Release the tables built by allocate_memory().
 *
 * Nothing is touched when dim2 == 0, because in that case no tables
 * were allocated.  Otherwise every row of both tables is freed, then the
 * tables themselves, and the caller's pointers are reset to NULL so a
 * repeated call or a later accidental use cannot hit freed memory.
 * NULL tables and NULL rows are tolerated (free(NULL) is a no-op), which
 * makes the function safe to call after a failed allocation.
 */
void
free_memory(int ncentroids, int dim2, 
	    double ****U_hat_table, int ***data2_int) {
  double ***table;
  int    **labels;
  int i, j;

  if (dim2 == 0)
    return;

  table  = *U_hat_table;
  labels = *data2_int;

  for (j = 0; j < dim2; j++) {
    if (table != NULL && table[j] != NULL) {
      for (i = 0; i < ncentroids; i++) {
        free(table[j][i]);
      }
      free(table[j]);
    }
    if (labels != NULL) {
      free(labels[j]);
    }
  }

  free(table);
  free(labels);
  *U_hat_table = NULL;
  *data2_int   = NULL;
  return;
}


/*
 * Run one posterior update and store the categorical counts in U_hat.
 *
 * Steps, in order:
 *   1. allocate the scratch tables (allocate_memory);
 *   2. convert data2 (dim2 runs of datalen doubles) to zero-based int
 *      labels by truncating and subtracting one;
 *   3. compute_nc, update_centroids and update_gamma on the caller's
 *      arrays;
 *   4. for every categorical variable j create an
 *      ncentroids x (int)Ns[j] double matrix, store it in cell j of U_hat
 *      and fill it from the scratch table (entry (i,k) is written at
 *      offset k*ncentroids + i);
 *   5. release the scratch tables (free_memory).
 */
void
vdp_mk_hp_posterior(double *Mu_mu, double *S2_mu, double *Mu_bar, double *Mu_tilde, 
		    double *Alpha_ksi, double *Beta_ksi, 
		    double *Ksi_alpha, double *Ksi_beta, 
		    double *post_gamma, double *prior_alpha,
		    double *U_p, mxArray *U_hat,
		    long datalen, int dim1, int dim2, double *data1, double *data2, 
		    double *Ns, int ncentroids, 
		    double implicit_noisevar, double *q_of_z,
		    double *Nc, double *true_Nc) {
  long feat, comp, n;
  int cat;
  double *cell_values;
  mxArray *cell;
  double ***U_hat_table;
  int **data2_int;

  allocate_memory(datalen, ncentroids, dim2,
                  &U_hat_table, &data2_int, Ns);

  /* MATLAB labels start at one; the tables are indexed from zero. */
  for (feat = 0; feat < dim2; feat++) {
    for (n = 0; n < datalen; n++) {
      data2_int[feat][n] = ((int)(data2[feat*datalen + n])) - 1;
    }
  }

  compute_nc(ncentroids, datalen, true_Nc, q_of_z, Nc);

  update_centroids(datalen, ncentroids, dim1, dim2,
                   data1, data2_int,
                   Nc, q_of_z, Mu_mu, S2_mu,
                   Mu_bar, Mu_tilde,
                   Ksi_alpha, Ksi_beta, Alpha_ksi,
                   Beta_ksi, implicit_noisevar,
                   U_p, U_hat_table, Ns);

  update_gamma(ncentroids, true_Nc, prior_alpha, post_gamma);

  /* Export the scratch tables as one matrix per categorical variable. */
  for (feat = 0; feat < dim2; feat++) {
    cell        = mxCreateDoubleMatrix(ncentroids, (int)(Ns[feat]), mxREAL);
    cell_values = mxGetPr(cell);
    mxSetCell(U_hat, (int)feat, cell);
    for (cat = 0; cat < (int)(Ns[feat]); cat++) {
      for (comp = 0; comp < ncentroids; comp++) {
        cell_values[cat*ncentroids + comp] = U_hat_table[feat][comp][cat];
      }
    }
  }

  free_memory(ncentroids, dim2, &U_hat_table, &data2_int);
  return;
}





/************************************************************/
/* bridge function                                          */

/* Return field `name` of element 0 of the struct `container`.  Raises a
   MATLAB error (which ends the MEX call) when `container` is not a struct
   or does not carry the field. */
static mxArray *
vdp_required_field(const mxArray *container, const char *name) {
  mxArray *field = NULL;

  if (container != NULL && mxIsStruct(container))
    field = mxGetField(container, 0, name);
  if (field == NULL)
    mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:missingField",
                      "Required field '%s' is missing.", name);
  return field;
}

/* Return the data pointer of `arr` after checking that it is a real,
   full, double array with at least `min_elems` elements.  Raises a MATLAB
   error otherwise.  `min_elems` is a double so that products of
   dimensions cannot overflow an integer type. */
static double *
vdp_double_data(const mxArray *arr, double min_elems, const char *what) {
  if (arr == NULL || !mxIsDouble(arr) || mxIsComplex(arr) || mxIsSparse(arr))
    mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:badType",
                      "'%s' must be a real, full, double array.", what);
  if ((double)mxGetNumberOfElements(arr) < min_elems)
    mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:badSize",
                      "'%s' has too few elements.", what);
  return mxGetPr(arr);
}

/*
 * MEX entry point.
 *
 * Inputs (prhs):
 *   [0] struct with field given_data, itself a struct with fields
 *       X1 (datalen x dim1), X2 (one column per categorical variable)
 *       and realS (one category count per column of X2)
 *   [1] q_of_z, ncentroids columns of datalen values
 *   [2] struct with fields Mu_mu, S2_mu, Alpha_ksi, Beta_ksi, U_p, alpha
 *   [3] struct with field implicit_noisevar
 *
 * Output (plhs[0]): 1x1 struct with fields Mu_bar, Mu_tilde, Ksi_alpha,
 * Ksi_beta (ncentroids x dim1), gamma (2 x ncentroids), Nc and true_Nc
 * (1 x ncentroids), q_of_z (datalen x ncentroids) and Uhat (1 x dim2
 * cell).
 *
 * All inputs are validated before anything is computed: a missing field,
 * a wrongly typed or undersized array, or a value of X2 that is not a
 * valid category index raises a MATLAB error instead of reading or
 * writing outside the arrays.  Only conditions that would cause such
 * out-of-range accesses are rejected.
 */
void
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  long datalen, n, total;
  int  j, dim1, dim2, no_of_post_fields = 9, ncentroids;
  double *Mu_mu, *S2_mu, *Mu_bar, *Mu_tilde, 
    *Alpha_ksi, *Beta_ksi, *Ksi_alpha, *Ksi_beta, *U_p, *prior_alpha,
    *post_gamma;
  double *data1;
  double *data2;
  mxArray *U_hat;
  mxArray *given_data, *X1, *X2, *noise_field;
  double *Ns;
  double *q_of_z_in, *q_of_z, *Nc, *true_Nc;
  double implicit_noisevar, upper, value;
  const char *posterior_fields[]={"Mu_bar","Mu_tilde",
				  "Ksi_alpha","Ksi_beta",
				  "gamma","Nc","true_Nc","q_of_z","Uhat"};

  (void)nlhs;

  if (nrhs < 4)
    mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:nrhs",
                      "Four input arguments are required.");

  /******************** input variables ********************/
  
  /* data */

  given_data = vdp_required_field(prhs[0], "given_data");
  X1         = vdp_required_field(given_data, "X1");
  data1      = vdp_double_data(X1, 0.0, "given_data.X1");
  dim1       = (int)mxGetN(X1);
  datalen    = (long)mxGetM(X1);
  X2         = vdp_required_field(given_data, "X2");
  dim2       = (int)mxGetN(X2);
  data2      = vdp_double_data(X2, (double)dim2 * (double)datalen,
                               "given_data.X2");
  Ns         = vdp_double_data(vdp_required_field(given_data, "realS"),
                               (double)dim2, "given_data.realS");

  noise_field = vdp_required_field(prhs[3], "implicit_noisevar");
  if (!mxIsNumeric(noise_field) || mxIsEmpty(noise_field))
    mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:badType",
                      "'implicit_noisevar' must be a numeric scalar.");
  implicit_noisevar = mxGetScalar(noise_field);

  /* initial values of model parameters */

  Mu_mu     = vdp_double_data(vdp_required_field(prhs[2], "Mu_mu"),
                              (double)dim1, "Mu_mu");
  S2_mu     = vdp_double_data(vdp_required_field(prhs[2], "S2_mu"),
                              (double)dim1, "S2_mu");
  Alpha_ksi = vdp_double_data(vdp_required_field(prhs[2], "Alpha_ksi"),
                              (double)dim1, "Alpha_ksi");
  Beta_ksi  = vdp_double_data(vdp_required_field(prhs[2], "Beta_ksi"),
                              (double)dim1, "Beta_ksi");
  U_p       = vdp_double_data(vdp_required_field(prhs[2], "U_p"),
                              (double)dim2, "U_p");
  prior_alpha = vdp_double_data(vdp_required_field(prhs[2], "alpha"),
                                1.0, "alpha");

  ncentroids = (int)mxGetN(prhs[1]);
  /* The update always addresses the last cluster, so one must exist. */
  if (ncentroids < 1)
    mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:badSize",
                      "q_of_z must have at least one column.");
  q_of_z_in  = vdp_double_data(prhs[1],
                               (double)datalen * (double)ncentroids,
                               "q_of_z");

  /* The values of X2 are later used (minus one, after truncation) as
     indices into tables of (int)realS(j) entries; reject anything that
     would index outside them.  The upper limit on realS keeps its
     conversion to int well defined (assumes int is at least 32 bits). */
  for (j = 0; j < dim2; j++) {
    if (!(Ns[j] >= 0.0 && Ns[j] <= 2147483647.0))
      mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:badCategoryCount",
                        "given_data.realS(%d) is not a valid category count.",
                        j + 1);
    upper = floor(Ns[j]) + 1.0;
    for (n = 0; n < datalen; n++) {
      value = data2[j * datalen + n];
      if (!(value >= 1.0 && value < upper))
        mexErrMsgIdAndTxt("NetResponse:vdp_mk_hp_posterior:badCategory",
                          "given_data.X2 column %d holds a value outside "
                          "1..realS(%d).", j + 1, j + 1);
    }
  }

  /******************** output variables ********************/
  /* values of model parameters after iteration */

  plhs[0]    = mxCreateStructMatrix(1, 1, no_of_post_fields, posterior_fields);
  mxSetField(plhs[0], 0, posterior_fields[0], mxCreateDoubleMatrix(ncentroids, dim1, mxREAL)); 
  mxSetField(plhs[0], 0, posterior_fields[1], mxCreateDoubleMatrix(ncentroids, dim1, mxREAL));
  mxSetField(plhs[0], 0, posterior_fields[2], mxCreateDoubleMatrix(ncentroids, dim1, mxREAL));
  mxSetField(plhs[0], 0, posterior_fields[3], mxCreateDoubleMatrix(ncentroids, dim1, mxREAL));
  mxSetField(plhs[0], 0, posterior_fields[4], mxCreateDoubleMatrix(2, ncentroids, mxREAL));    
  mxSetField(plhs[0], 0, posterior_fields[5], mxCreateDoubleMatrix(1, ncentroids, mxREAL));    
  mxSetField(plhs[0], 0, posterior_fields[6], mxCreateDoubleMatrix(1, ncentroids, mxREAL));  
  mxSetField(plhs[0], 0, posterior_fields[7], mxCreateDoubleMatrix(datalen, ncentroids, mxREAL));
  mxSetField(plhs[0], 0, posterior_fields[8], mxCreateCellMatrix(1, dim2));
  Mu_bar     = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[0]));      
  Mu_tilde   = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[1]));      
  Ksi_alpha  = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[2]));      
  Ksi_beta   = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[3]));      
  post_gamma = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[4]));
  Nc         = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[5]));      
  true_Nc    = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[6]));      
  q_of_z     = mxGetPr(mxGetField(plhs[0], 0, posterior_fields[7]));
  U_hat      =   mxGetField(plhs[0], 0, posterior_fields[8]);

  /* Work on a copy: the update overwrites part of q_of_z and the input
     array must stay untouched. */
  total = datalen * (long)ncentroids;
  for (n = 0; n < total; n++) {
    q_of_z[n] = q_of_z_in[n];
  }

  vdp_mk_hp_posterior(Mu_mu, S2_mu, Mu_bar, Mu_tilde, 
		      Alpha_ksi, Beta_ksi, Ksi_alpha, Ksi_beta, 
		      post_gamma, prior_alpha,
		      U_p, U_hat,
		      datalen, dim1, dim2, data1, data2, 
		      Ns, ncentroids, implicit_noisevar, q_of_z, Nc, true_Nc);
  
  return;
}
 
/************************************************************/




