/*

  This file is part of NetResponse algorithm. (C) Leo Lahti 2008-2010.

  This file is based on the Agglomerative Independent Variable Group
  Analysis package, Copyright (C) 2001-2007 Esa Alhoniemi, Antti
  Honkela, Krista Lagus, Jeremias Seppa, Harri Valpola, and Paul
  Wagner
 
  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2, or (at your option)
  any later version.
 
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License (included in file License.txt in the
  program package) for more details.
*/

#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <math.h>
#include "mex.h"
#include "matrix.h"
 
/*
 * Constants used by digamma().  Every expansion is parenthesized so the
 * macros keep their value inside any surrounding expression (e.g.
 * `x / DIGAMMA_S3` divides by 1/12 rather than computing x/1.0/12).
 */
#define DIGAMMA_S  (1e-5)          /* at or below this, digamma() uses D1 - 1/x  */
#define DIGAMMA_C  (8.5)           /* recurrence is applied until x reaches this */
#define DIGAMMA_S3 (1.0/12)        /* coefficient of 1/x^2 in the series         */
#define DIGAMMA_S4 (1.0/120)       /* coefficient of 1/x^4 in the series         */
#define DIGAMMA_S5 (1.0/252)       /* coefficient of 1/x^6 in the series         */
#define DIGAMMA_D1 (-0.5772156649) /* minus the Euler-Mascheroni constant        */

#define PI 3.1415926535
/* Square of x.  NB: the argument is evaluated twice. */
#define POW2(x) ((x) * (x))

/************************************************************/
/* digamma function (by Antti Honkela)                      */

/*
 * digamma -- digamma function for positive arguments.
 *
 *   x <= 0          : returns NaN (mxGetNaN()).
 *   0 < x <= S      : returns DIGAMMA_D1 - 1/x.
 *   otherwise       : subtracts 1/x while stepping x upwards by 1 until it
 *                     reaches DIGAMMA_C, then adds log(x) - 1/(2x) and the
 *                     three-term correction in powers of 1/x^2.
 */
double
digamma(double x) {
  double shifted = x;
  double result  = 0.0;
  double inv, inv_sq;

  if (shifted <= 0)
    return mxGetNaN();

  /* tiny arguments: closed-form approximation, no recurrence needed */
  if (shifted <= DIGAMMA_S)
    return DIGAMMA_D1 - 1.0 / shifted;

  /* move the argument up to where the series below is accurate */
  for (; shifted < DIGAMMA_C; shifted += 1.0)
    result -= 1.0 / shifted;

  inv     = 1.0 / shifted;
  result += log(shifted) - .5 * inv;
  inv_sq  = POW2(inv);
  /* operand order is kept as in the macro-based original expression */
  result -= inv_sq * (DIGAMMA_S3 - inv_sq * (DIGAMMA_S4 - inv_sq*DIGAMMA_S5));

  return result;
}
/************************************************************/
/*
 * compute_tempmat -- fill Temp[centroid][sample] for every centroid and
 * sample.
 *
 * Step 1 (continuous columns, dim1 of them): each cell becomes half the sum
 * over columns d of
 *     (Mu_tilde + (data1 - Mu_bar)^2 + implicit_noisevar) / S2_x - Ksi_log.
 * data1 is read as data1[sample*dim1 + d]; Mu_bar / Mu_tilde as
 * [centroid*dim1 + d].
 *
 * Step 2 (discrete variables, dim2 of them): for each variable and centroid
 * the first (int)Ns[var] entries of U_hat_table[var][centroid] are summed,
 * and digamma(sum) - digamma(entry selected by data2_int[var][sample]) is
 * added to the cell.
 *
 * Side effect: the entries of U_hat_table are overwritten with their
 * digamma values.
 */
void
compute_tempmat(long datalen, int dim1, int dim2, int ncentroids,
		double **Temp, double *data1, int **data2_int,
		double *Mu_bar, double *Mu_tilde, double **S2_x,
                double **Ksi_log, double ***U_hat_table, double *Ns,
		double implicit_noisevar) {
  int    c, d, cat, ncat;
  long   slot, n, var;
  double psi_total;
  double *cell, *counts;

  /* step 1: continuous columns */
  for (c = 0; c < ncentroids; c++) {
    for (n = 0; n < datalen; n++) {
      cell  = &Temp[c][n];
      *cell = 0.0;
      for (d = 0; d < dim1; d++) {
        slot   = c * dim1 + d;
        *cell += ((Mu_tilde[slot]+POW2(data1[n*dim1+d]-Mu_bar[slot]) + implicit_noisevar)/
                  S2_x[c][d]) - Ksi_log[c][d];
      }
      *cell /= 2.0;
    }
  }

  /* step 2: discrete variables */
  for (var = 0; var < dim2; var++) {
    ncat = (int)(Ns[var]);
    for (c = 0; c < ncentroids; c++) {
      counts    = U_hat_table[var][c];
      psi_total = 0.0;
      for (cat = 0; cat < ncat; cat++) {
        psi_total  += counts[cat];
        counts[cat] = digamma(counts[cat]);   /* table now holds digamma values */
      }
      psi_total = digamma(psi_total);
      for (n = 0; n < datalen; n++)
        Temp[c][n] += (psi_total - counts[data2_int[var][n]]);
    }
  }
  return;
}

/************************************************************/

void
find_winners(long datalen, int ncentroids, double *logC, double **W) {
  double min_w, sum;
  register int i;
  register long j;
  
  for (j = 0; j < datalen; j++) {
    min_w = DBL_MAX;
    
    for (i = 0; i < ncentroids; i++) {  
      W[i][j] -= logC[i];              /*..*/
      if (W[i][j] < min_w) min_w = W[i][j]; 
    }
    
    sum = 0.0;
    for (i = 0; i < ncentroids; i++) {
      W[i][j] = exp(min_w-W[i][j]);
      sum     += W[i][j];
    }
    for (i = 0; i < ncentroids; i++) W[i][j] /= sum;
  }
  
  return;
}

/************************************************************/


/*
 * update_frequencies -- recompute F and logC from the weight table W.
 *
 *   F[i]    = sum over samples of W[i][sample]
 *   logC[i] = digamma(F[i] + C_prior)
 *             - digamma(sum_i F[i] + ncentroids * C_prior)
 *
 * The second digamma term does not depend on i, so it is evaluated once
 * before the loop instead of once per centroid.
 */
void
update_frequencies(long datalen, int ncentroids, double C_prior, 
                   double *F, double **W, double *logC) {
  int    i;
  long   j;
  double total = 0.0;
  double psi_total;

  for (i = 0; i < ncentroids; i++) {
    F[i] = 0.0;
    for (j = 0; j < datalen; j++)
      F[i] += W[i][j];
    total += F[i];
  }
  total += ncentroids * C_prior;

  psi_total = digamma(total);   /* loop-invariant */
  for (i = 0; i < ncentroids; i++)
    logC[i] = -psi_total + digamma(F[i] + C_prior);

  return;
}

/************************************************************/

/*
 * update_centroids -- recompute Mu_bar / Mu_tilde and the U_hat_table
 * entries from the weight table W.
 *
 * Continuous columns (d < dim1), for each centroid c with
 * slot = c*dim1 + d and s = sum_n W[c][n] * data1[n*dim1 + d]:
 *     denom          = S2_x[c][d] + S2_mu[d] * F[c]
 *     Mu_bar[slot]   = (S2_x[c][d]*Mu_mu[d] + S2_mu[d]*s) / denom
 *     Mu_tilde[slot] = (S2_x[c][d]*S2_mu[d]) / denom
 *
 * Discrete variables (var < dim2): the first (int)Ns[var] entries of
 * U_hat_table[var][c] are reset to U_p[var], then W[c][n] is added to the
 * entry selected by data2_int[var][n] for every sample n.
 *
 * Nothing in the continuous part runs when dim1 is zero (loop bound).
 */
void
update_centroids(long datalen, int ncentroids, int dim1, int dim2,
		 double *data1, int **data2_int,
                 double *F, double **W, double *Mu_mu, double *S2_mu,
                 double *Mu_bar, double *Mu_tilde, double **S2_x,
		 double *U_p, double ***U_hat_table, double *Ns) {
  int    c, d, cat, ncat;
  long   slot, n, var;
  double weighted_sum, denom;
  double *counts;

  for (c = 0; c < ncentroids; c++) {
    for (d = 0; d < dim1; d++) {
      slot         = c * dim1 + d;
      weighted_sum = 0.0;
      for (n = 0; n < datalen; n++)
        weighted_sum += W[c][n] * data1[n * dim1 + d];

      denom          = S2_x[c][d] + S2_mu[d] * F[c];
      Mu_bar[slot]   = ((S2_x[c][d] * Mu_mu[d]) + (S2_mu[d] * weighted_sum))/denom;
      Mu_tilde[slot] = (S2_x[c][d] * S2_mu[d])/denom;
    }
  }

  for (var = 0; var < dim2; var++) {
    ncat = (int)(Ns[var]);
    for (c = 0; c < ncentroids; c++) {
      counts = U_hat_table[var][c];
      for (cat = 0; cat < ncat; cat++)
        counts[cat] = U_p[var];
      for (n = 0; n < datalen; n++)
        counts[data2_int[var][n]] += W[c][n];
    }
  }
  return;
}

/************************************************************/

/*
 * compute_variance -- derive S2_x and Ksi_log from Ksi_alpha / Ksi_beta.
 *
 * For every centroid c and column d (slot = c*dim1 + d):
 *     S2_x[c][d]    = Ksi_beta[slot] / Ksi_alpha[slot], raised to 1e-100
 *                     if it is smaller than that
 *     Ksi_log[c][d] = digamma(Ksi_alpha[slot]) - log(Ksi_beta[slot])
 */
void
compute_variance(int ncentroids, int dim1, double *Ksi_alpha, 
                 double *Ksi_beta, double **S2_x, double **Ksi_log) {
  const double smallest_allowed = 1e-100;
  int    c, d, slot;
  double ratio;

  for (c = 0; c < ncentroids; c++) {
    for (d = 0; d < dim1; d++) {
      slot  = c * dim1 + d;
      ratio = Ksi_beta[slot]/Ksi_alpha[slot];

      Ksi_log[c][d] = digamma(Ksi_alpha[slot])-log(Ksi_beta[slot]);
      /* keep S2_x away from zero; it is used as a divisor elsewhere */
      S2_x[c][d]    = (ratio < smallest_allowed) ? smallest_allowed : ratio;
    }
  }

  return;
}

/************************************************************/

/*
 * update_variances -- recompute Ksi_alpha and Ksi_beta.
 *
 * For every column d and centroid c (slot = c*dim1 + d):
 *     Ksi_alpha[slot] = Alpha_ksi[d] + 0.5 * F[c]
 *     Ksi_beta[slot]  = Beta_ksi[d] + 0.5 * sum_n W[c][n] *
 *                       (Mu_tilde[slot] + implicit_noisevar
 *                        + (Mu_bar[slot] - data1[n*dim1 + d])^2)
 */
void
update_variances(long datalen, int ncentroids, int dim1, double *data1, 
                 double *F, double **W, double *Mu_bar, double *Mu_tilde,
                 double *Alpha_ksi, double *Beta_ksi, 
                 double *Ksi_alpha, double *Ksi_beta, double implicit_noisevar) {
  int    c, d;
  long   slot, n;
  double *beta;

  for (d = 0; d < dim1; d++) {
    for (c = 0; c < ncentroids; c++) {
      slot  = c * dim1 + d;
      beta  = &Ksi_beta[slot];

      *beta = 0.0;
      for (n = 0; n < datalen; n++)
        *beta += W[c][n] * (Mu_tilde[slot] + implicit_noisevar +
                            POW2(Mu_bar[slot] - data1[n * dim1 + d]));
      *beta *= 0.5;
      *beta += Beta_ksi[d];

      Ksi_alpha[slot] = Alpha_ksi[d] + 0.5 * F[c];
    }
  }

  return;
}

/************************************************************/
/*
 * update_priors -- recompute the per-column quantities Mu_mu, S2_mu,
 * Alpha_ksi and Beta_ksi from the per-centroid ones.
 *
 * For every column d (slot = c*dim + d, averages taken over centroids c):
 *     Mu_mu[d]  = mean of Mu_bar[slot]
 *     S2_mu[d]  = mean of Mu_tilde[slot] + (Mu_bar[slot] - Mu_mu[d])^2
 *     scale     = ncentroids / sum_c (Ksi_alpha[slot] / Ksi_beta[slot])
 *     log_term  = mean of log(Ksi_beta[slot]) - digamma(Ksi_alpha[slot])
 *     Alpha_ksi[d] += 0.5 / (digamma(Alpha_ksi[d]) - log(Alpha_ksi[d]))
 *                   + 0.5 / (log_term - log(scale))
 *     Beta_ksi[d]   = Alpha_ksi[d] * scale       (uses the updated alpha)
 *
 * Alpha_ksi is incremented once per call; it is not iterated to
 * convergence here.
 */
void
update_priors(int ncentroids, int dim, double *Mu_mu, double *S2_mu, 
              double *Mu_bar, double *Mu_tilde, double *Alpha_ksi, 
              double *Beta_ksi, double *Ksi_alpha, double *Ksi_beta) {

  int    c, d, slot;
  double mean, spread, ratio_sum, scale, log_term;

  for (d = 0; d < dim; d++) {
    /* the mean must be final before the spread around it is accumulated */
    mean = 0.0;
    for (c = 0; c < ncentroids; c++)
      mean += Mu_bar[c * dim + d];
    mean /= ncentroids;
    Mu_mu[d] = mean;

    spread    = 0.0;
    ratio_sum = 0.0;
    log_term  = 0.0;
    for (c = 0; c < ncentroids; c++) {
      slot       = c * dim + d;
      spread    += Mu_tilde[slot] + POW2(Mu_bar[slot] - mean);
      ratio_sum += Ksi_alpha[slot]/Ksi_beta[slot];
      log_term  += log(Ksi_beta[slot]) - digamma(Ksi_alpha[slot]);
    }
    spread /= ncentroids;
    S2_mu[d] = spread;

    scale     = ncentroids/ratio_sum;
    log_term /= ncentroids;

    Alpha_ksi[d] += 0.5/(digamma(Alpha_ksi[d])-log(Alpha_ksi[d])) +
      0.5/(log_term - log(scale));

    Beta_ksi[d] = Alpha_ksi[d] * scale;
  }

  return;
}
/************************************************************/

/*
 * allocate_memory -- allocate the work tables used by vectquant().
 *
 *   *W, *Temp        ncentroids rows of datalen doubles, zero-filled
 *   *S2_x, *Ksi_log  ncentroids rows of dim1 doubles     (only if dim1 != 0)
 *   *U_hat_table     dim2 x ncentroids rows of (int)Ns[j] doubles
 *   *data2_int       dim2 rows of datalen ints           (only if dim2 != 0)
 *
 * Tables that are not allocated are set to NULL rather than being left
 * untouched, so the caller never holds an indeterminate pointer.  W and
 * Temp are zero-filled because the caller swaps them and may read W before
 * anything has been written to it.
 *
 * Everything allocated here is released by free_memory().
 *
 * NOTE(review): allocation failures are not reported; a failed allocation
 * leaves a NULL entry behind.  The void interface gives no way to signal
 * the error to the caller.
 */
void
allocate_memory(long datalen,int ncentroids,int dim1,int dim2,double ***S2_x,
                double ***Ksi_log, double ***W, double ***Temp,
		double ****U_hat_table,int ***data2_int,double *Ns) {
  int i, j;

  *S2_x        = NULL;
  *Ksi_log     = NULL;
  *U_hat_table = NULL;
  *data2_int   = NULL;

  *W       = (double **)malloc(ncentroids * sizeof(double *));
  *Temp    = (double **)malloc(ncentroids * sizeof(double *));
  if (dim1) {
    *S2_x    = (double **)malloc(ncentroids * sizeof(double *));
    *Ksi_log = (double **)malloc(ncentroids * sizeof(double *));
  }
  if (dim2) {
    /* elements of this table are double**, not double* */
    *U_hat_table = (double ***)malloc(dim2 * sizeof(double **));
    *data2_int   = (int **)malloc(dim2 * sizeof(int *));
  }

  for (i = 0; i < ncentroids; i++) {
    (*W)[i]    = (double *)calloc((size_t)datalen, sizeof(double));
    (*Temp)[i] = (double *)calloc((size_t)datalen, sizeof(double));
    if (dim1) {
      (*S2_x)[i]    = (double *)malloc(dim1 * sizeof(double));
      (*Ksi_log)[i] = (double *)malloc(dim1 * sizeof(double));
    }
  }

  for (j = 0; j < dim2; j++) {
    (*data2_int)[j]   = (int *)malloc(datalen * sizeof(int));
    (*U_hat_table)[j] = (double **)malloc(ncentroids * sizeof(double *));
    for (i = 0; i < ncentroids; i++) {
      (*U_hat_table)[j][i] = (double *)malloc(((int)(Ns[j])) * sizeof(double));
    }
  }

  return;
}
 
/************************************************************/

/* Free `nrows` row buffers and then the row table itself; NULL is a no-op. */
static void
free_double_rows(double **rows, int nrows) {
  int r;

  if (rows == NULL)
    return;
  for (r = 0; r < nrows; r++)
    free(rows[r]);
  free(rows);
}

/*
 * free_memory -- release the tables created by allocate_memory().
 *
 * ncentroids, dim1 and dim2 must be the values that were used for the
 * allocation.  *S2_x / *Ksi_log are only read when dim1 != 0 and
 * *U_hat_table / *data2_int only when dim2 != 0, matching what
 * allocate_memory() allocates.
 *
 * After the call every table pointer is NULL, so the caller is not left
 * with dangling pointers and a repeated call is harmless.  Table pointers
 * that are already NULL are skipped.
 */
void
free_memory(int ncentroids,int dim1, int dim2, double ***Temp, double ***W, 
            double ***S2_x, double ***Ksi_log,
		double ****U_hat_table, int ***data2_int) {
  int j;

  for (j = 0; j < dim2; j++) {
    if (*U_hat_table != NULL)
      free_double_rows((*U_hat_table)[j], ncentroids);
    if (*data2_int != NULL)
      free((*data2_int)[j]);
  }
  if (dim2) {
    free(*U_hat_table);
    free(*data2_int);
  }
  *U_hat_table = NULL;
  *data2_int   = NULL;

  free_double_rows(*W, ncentroids);
  free_double_rows(*Temp, ncentroids);
  *W    = NULL;
  *Temp = NULL;

  if (dim1) {
    free_double_rows(*S2_x, ncentroids);
    free_double_rows(*Ksi_log, ncentroids);
  }
  *S2_x    = NULL;
  *Ksi_log = NULL;

  return;
}

/************************************************************/

/*
 * vectquant -- run `iterations` rounds of the update cycle on the model
 * parameters (all updated in place) and export the final tables.
 *
 * Each round: compute_variance (if dim1) -> compute_tempmat -> swap Temp
 * and W -> find_winners -> update_frequencies -> update_centroids ->
 * update_variances and update_priors (if dim1).
 *
 * Inputs read here:
 *   data2   datalen x dim2 values, column by column; each value is
 *           truncated to an int and must lie in 1..(int)Ns[j].  It is
 *           stored zero-based and used as an index into the per-variable
 *           tables.
 *   U_hat   cell array; cell j holds at least ncentroids*(int)Ns[j]
 *           doubles, read as [k*ncentroids + i].
 * Outputs written here:
 *   W_final   W_final[j*ncentroids + i] = W[i][j].  Left untouched when no
 *             round was run (iterations <= 0), because W has not been
 *             computed in that case.
 *   U_hat_out cell j is set to a new ncentroids x (int)Ns[j] matrix.
 *
 * Invalid category codes or a missing/short U_hat cell raise a MATLAB
 * error (after the work tables have been freed) instead of indexing out
 * of bounds.
 *
 * `verbose` is accepted for interface compatibility but has no effect.
 *
 * Returns the number of rounds run plus one.
 * NOTE(review): the "+ 1" looks like a leftover from a version of the loop
 * that could exit early; it is kept so callers see the same value as
 * before.
 */
int
vectquant(double *Mu_mu, double *S2_mu, double *Mu_bar, double *Mu_tilde, 
          double *Alpha_ksi, double *Beta_ksi, 
          double *Ksi_alpha, double *Ksi_beta, 
          double *C_prior, double *logC, double *F, double *W_final,
	  double *U_p, mxArray *U_hat, mxArray *U_hat_out,
          long datalen, int dim1, int dim2, double *data1, double *data2, 
	  double *Ns, int ncentroids, 
          int iterations, int verbose, double implicit_noisevar) {
  long     i, j, t;
  int      k, ncat;
  int      iter = 0;
  double   value;
  double   *U_hat_j;
  mxArray  *U_hat_j_in;
  mxArray  *U_hat_j_mxarray;
  /* NULL-initialized so that tables skipped by allocate_memory() are never
     passed on as indeterminate pointers */
  double   **W = NULL, **Temp = NULL, **S, **S2_x = NULL, **Ksi_log = NULL;
  double   ***U_hat_table = NULL;
  int      **data2_int = NULL;

  (void)verbose;

  allocate_memory(datalen, ncentroids, dim1,dim2, &S2_x, &Ksi_log, &W, &Temp,
		  &U_hat_table,&data2_int,Ns );

  for (j = 0; j < dim2; j++) {
    ncat = (int)(Ns[j]);

    /* category codes: 1-based doubles -> validated 0-based ints */
    for (t = 0; t < datalen; t++) {
      value = data2[j*datalen+t];
      /* written so that NaN fails the test as well */
      if (!(value >= 1.0 && value < (double)ncat + 1.0)) {
        free_memory(ncentroids, dim1, dim2, &Temp, &W, &S2_x, &Ksi_log,
                    &U_hat_table, &data2_int);
        mexErrMsgTxt("vectquant: discrete data value outside 1..Ns(j).");
      }
      data2_int[j][t] = ((int)value) - 1;
    }

    U_hat_j_in = (U_hat != NULL) ? mxGetCell(U_hat, (int)j) : NULL;
    if (U_hat_j_in == NULL ||
        mxGetNumberOfElements(U_hat_j_in) < (size_t)ncentroids * (size_t)ncat) {
      free_memory(ncentroids, dim1, dim2, &Temp, &W, &S2_x, &Ksi_log,
                  &U_hat_table, &data2_int);
      mexErrMsgTxt("vectquant: U_hat cell missing or smaller than ncentroids x Ns(j).");
    }
    U_hat_j = mxGetPr(U_hat_j_in);

    /* same bound as the allocation size, (int)Ns[j] */
    for (i = 0; i < ncentroids; i++) {
      for (k = 0; k < ncat; k++) {
	U_hat_table[j][i][k] = U_hat_j[k*ncentroids+i];
      }
    }
  }

  for (iter = 0; iter < iterations; ++iter) {

    if (dim1)
      compute_variance(ncentroids, dim1, Ksi_alpha, Ksi_beta, S2_x, Ksi_log);

    compute_tempmat(datalen,dim1,dim2,ncentroids,Temp,data1,data2_int,
                    Mu_bar,Mu_tilde,S2_x,Ksi_log,U_hat_table,Ns, implicit_noisevar);

    /* the freshly computed table becomes W; the old W is reused as Temp */
    S = Temp; Temp = W; W = S;

    find_winners(datalen, ncentroids, logC, W);

    update_frequencies(datalen, ncentroids, *C_prior, F, W, logC);
    update_centroids(datalen, ncentroids, dim1, dim2, data1, data2_int,
		     F, W, Mu_mu, S2_mu,
                     Mu_bar, Mu_tilde, S2_x, U_p, U_hat_table,Ns);

    if (dim1) {
      update_variances(datalen, ncentroids, dim1, data1, F,
		       W, Mu_bar, Mu_tilde,
		       Alpha_ksi, Beta_ksi, Ksi_alpha, Ksi_beta, implicit_noisevar);

      update_priors(ncentroids, dim1, Mu_mu, S2_mu, Mu_bar, Mu_tilde,
		    Alpha_ksi, Beta_ksi, Ksi_alpha, Ksi_beta);
    }
  }

  /* W only holds meaningful values once at least one round has run */
  if (iter > 0) {
    for (i = 0; i < ncentroids; i++)
      for (j = 0; j < datalen; j++)
        W_final[j * ncentroids + i] = W[i][j];
  }

  for (j = 0; j < dim2; j++) {
    ncat = (int)(Ns[j]);
    U_hat_j_mxarray = mxCreateDoubleMatrix(ncentroids, ncat, mxREAL);
    U_hat_j = mxGetPr(U_hat_j_mxarray);
    mxSetCell(U_hat_out, (int)j, U_hat_j_mxarray);
    for (i = 0; i < ncentroids; i++)
      for (k = 0; k < ncat; k++)
	U_hat_j[k*ncentroids+i] = U_hat_table[j][i][k];
  }

  free_memory(ncentroids, dim1,dim2, &Temp, &W, &S2_x, &Ksi_log, &U_hat_table, &data2_int);
  return iter+1;
}


/************************************************************/
/* bridge function                                          */

/* Return field `name` of the parameter struct, or raise a MATLAB error if
   the field does not exist. */
static mxArray *
required_field(const mxArray *params, const char *name) {
  mxArray *field = mxGetField(params, 0, name);

  if (field == NULL)
    mexErrMsgIdAndTxt("NetResponse:missingField",
                      "Parameter struct has no field '%s'.", name);
  return field;
}

/* Return the double data of field `name`, which must hold at least
   `needed` elements.  When nothing is needed (needed == 0) a missing field
   is tolerated and NULL is returned. */
static double *
param_data(const mxArray *params, const char *name, size_t needed) {
  const mxArray *field = mxGetField(params, 0, name);

  if (field == NULL) {
    if (needed == 0)
      return NULL;
    mexErrMsgIdAndTxt("NetResponse:missingField",
                      "Parameter struct has no field '%s'.", name);
  }
  if (needed > 0 &&
      (!mxIsDouble(field) || mxGetNumberOfElements(field) < needed))
    mexErrMsgIdAndTxt("NetResponse:badField",
                      "Field '%s' must be a double array with at least %lu elements.",
                      name, (unsigned long)needed);
  return mxGetPr(field);
}

/*
 * mexFunction -- MATLAB entry point.
 *
 * Inputs (prhs):
 *   [0] data1              dim1 x datalen double matrix
 *   [1] parameter struct   fields Mu_mu, S2_mu, Mu_bar, Mu_tilde, Alpha_ksi,
 *                          Beta_ksi, Ksi_alpha, Ksi_beta, C_prior, logC,
 *                          U_p, U_hat; the number of centroids is the
 *                          column count of Mu_bar
 *   [2] iterations         scalar
 *   [3] verbose            scalar
 *   [4] data2              matrix with dim2 columns and at least datalen rows
 *   [5] Ns                 at least dim2 elements
 *   [6] implicit_noisevar  scalar
 *
 * Outputs (plhs):
 *   [0] struct with the 14 fields listed in `fields`, holding the updated
 *       parameters plus F and W
 *   [1] empty matrix (only created when requested; nothing is computed
 *       for this slot)
 *   [2] value returned by vectquant() (only created when requested)
 *
 * Inputs are validated before use and a MATLAB error is raised for a
 * missing argument, missing field or undersized array.
 */
void
mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  long datalen, ind;
  int  i, j, dim1, dim2, c, no_of_fields = 14, iters, verbose, iters_run;
  size_t per_dim, per_centroid_dim;
  double *Mu_mu_init, *S2_mu_init, *Mu_bar_init, *Mu_tilde_init, 
    *Alpha_ksi_init, *Beta_ksi_init, *Ksi_alpha_init, *Ksi_beta_init, 
    C_prior_init, *logC_init, *U_p_init;
  double *Mu_mu, *S2_mu, *Mu_bar, *Mu_tilde, 
    *Alpha_ksi, *Beta_ksi, *Ksi_alpha, *Ksi_beta, *C_prior, *logC, *F, *W, *U_p;
  double *data1;
  double *data2;
  mxArray *U_hat, *U_hat_out;
  double *Ns;
  double implicit_noisevar;
  const char *fields[]={"Mu_mu","S2_mu","Mu_bar","Mu_tilde",
                        "Alpha_ksi","Beta_ksi","Ksi_alpha","Ksi_beta",
                        "C_prior","logC","F","W","U_p","U_hat"};

  /******************** argument checks ********************/

  if (nrhs < 7)
    mexErrMsgTxt("Seven inputs required: data1, parameter struct, iterations, "
                 "verbose, data2, Ns, implicit_noisevar.");
  if (nlhs > 3)
    mexErrMsgTxt("At most three outputs are produced.");
  if (!mxIsStruct(prhs[1]))
    mexErrMsgTxt("Second input must be the parameter struct.");

  /******************** input variables ********************/

  /* data */

  data1          = mxGetPr(prhs[0]);
  dim1           = (int)mxGetM(prhs[0]);
  datalen        = (long)mxGetN(prhs[0]);
  data2          = mxGetPr(prhs[4]);
  dim2           = (int)mxGetN(prhs[4]);
  Ns             = mxGetPr(prhs[5]); 
  implicit_noisevar = mxGetScalar(prhs[6]);

  if (dim1 > 0 && !mxIsDouble(prhs[0]))
    mexErrMsgTxt("data1 must be a double matrix.");
  if (dim2 > 0) {
    /* data2 is indexed as data2[j*datalen + t], Ns as Ns[j] */
    if (!mxIsDouble(prhs[4]) || (long)mxGetM(prhs[4]) < datalen)
      mexErrMsgTxt("data2 must be a double matrix with one row per sample.");
    if (!mxIsDouble(prhs[5]) || mxGetNumberOfElements(prhs[5]) < (size_t)dim2)
      mexErrMsgTxt("Ns must be a double array with one entry per column of data2.");
  }

  /* initial values of model parameters */

  c                = (int)mxGetN(required_field(prhs[1], fields[2]));
  per_dim          = (size_t)dim1;
  per_centroid_dim = (size_t)dim1 * (size_t)c;

  Mu_mu_init     = param_data(prhs[1], fields[0], per_dim);
  S2_mu_init     = param_data(prhs[1], fields[1], per_dim);
  Mu_bar_init    = param_data(prhs[1], fields[2], per_centroid_dim);
  Mu_tilde_init  = param_data(prhs[1], fields[3], per_centroid_dim);
  Alpha_ksi_init = param_data(prhs[1], fields[4], per_dim);
  Beta_ksi_init  = param_data(prhs[1], fields[5], per_dim);
  Ksi_alpha_init = param_data(prhs[1], fields[6], per_centroid_dim);
  Ksi_beta_init  = param_data(prhs[1], fields[7], per_centroid_dim);
  C_prior_init   = mxGetScalar(required_field(prhs[1], fields[8]));
  logC_init      = param_data(prhs[1], fields[9], (size_t)c);
  U_p_init       = param_data(prhs[1], fields[12], (size_t)dim2);
  U_hat          = mxGetField(prhs[1], 0, fields[13]);
  if (dim2 > 0 && (U_hat == NULL || !mxIsCell(U_hat) ||
                   mxGetNumberOfElements(U_hat) < (size_t)dim2))
    mexErrMsgTxt("Field 'U_hat' must be a cell array with one cell per column of data2.");
  iters          = (int)mxGetScalar(prhs[2]); /* max no of iterations */
  verbose        = (int)mxGetScalar(prhs[3]); /* verbose on/off */

  /******************** output variables ********************/
  /* values of model parameters after iteration */

  plhs[0]         = mxCreateStructMatrix(1, 1, no_of_fields, fields);
  mxSetField(plhs[0], 0, fields[0], mxCreateDoubleMatrix(1, dim1, mxREAL));
  mxSetField(plhs[0], 0, fields[1], mxCreateDoubleMatrix(1, dim1, mxREAL));
  mxSetField(plhs[0], 0, fields[2], mxCreateDoubleMatrix(dim1, c, mxREAL)); 
  mxSetField(plhs[0], 0, fields[3], mxCreateDoubleMatrix(dim1, c, mxREAL));
  mxSetField(plhs[0], 0, fields[4], mxCreateDoubleMatrix(1, dim1,  mxREAL));
  mxSetField(plhs[0], 0, fields[5], mxCreateDoubleMatrix(1, dim1,  mxREAL));
  mxSetField(plhs[0], 0, fields[6], mxCreateDoubleMatrix(dim1, c, mxREAL));
  mxSetField(plhs[0], 0, fields[7], mxCreateDoubleMatrix(dim1, c, mxREAL));
  mxSetField(plhs[0], 0, fields[8], mxCreateDoubleMatrix(1, 1, mxREAL));    
  mxSetField(plhs[0], 0, fields[9], mxCreateDoubleMatrix(c, 1, mxREAL));    
  mxSetField(plhs[0], 0, fields[10], mxCreateDoubleMatrix(c, 1, mxREAL));  
  mxSetField(plhs[0], 0, fields[11], mxCreateDoubleMatrix(c, datalen, mxREAL));
  mxSetField(plhs[0], 0, fields[12], mxCreateDoubleMatrix(1,dim2,mxREAL));
  mxSetField(plhs[0], 0, fields[13], mxCreateCellMatrix(1, dim2));
  Mu_mu      = mxGetPr(mxGetField(plhs[0], 0, fields[0]));      
  S2_mu      = mxGetPr(mxGetField(plhs[0], 0, fields[1]));      
  Mu_bar     = mxGetPr(mxGetField(plhs[0], 0, fields[2]));      
  Mu_tilde   = mxGetPr(mxGetField(plhs[0], 0, fields[3]));      
  Alpha_ksi  = mxGetPr(mxGetField(plhs[0], 0, fields[4]));      
  Beta_ksi   = mxGetPr(mxGetField(plhs[0], 0, fields[5]));      
  Ksi_alpha  = mxGetPr(mxGetField(plhs[0], 0, fields[6]));      
  Ksi_beta   = mxGetPr(mxGetField(plhs[0], 0, fields[7]));      
  C_prior    = mxGetPr(mxGetField(plhs[0], 0, fields[8]));      
  logC       = mxGetPr(mxGetField(plhs[0], 0, fields[9]));      
  F          = mxGetPr(mxGetField(plhs[0], 0, fields[10]));      
  W          = mxGetPr(mxGetField(plhs[0], 0, fields[11]));
  U_p        = mxGetPr(mxGetField(plhs[0], 0, fields[12]));
  U_hat_out  =   mxGetField(plhs[0], 0, fields[13]);

  /*** copy initial values to output variables ***/

  for (i = 0; i < c; i++) {
    logC[i] = logC_init[i];
    for (j = 0; j < dim1; j++) { /** not run if dim1=0 **/
      ind = i * dim1 + j;
      Mu_bar[ind]    = Mu_bar_init[ind];
      Mu_tilde[ind]  = Mu_tilde_init[ind];
      Ksi_alpha[ind] = Ksi_alpha_init[ind];
      Ksi_beta[ind]  = Ksi_beta_init[ind];
    }
  }
  for (j = 0; j < dim1; j++) {
    Mu_mu[j]     = Mu_mu_init[j];
    S2_mu[j]     = S2_mu_init[j];
    Alpha_ksi[j] = Alpha_ksi_init[j];
    Beta_ksi[j]  = Beta_ksi_init[j];    
  }
  for (j = 0; j < dim2; j++) 
    U_p[j] = U_p_init[j];  /** U_p is copied through unchanged **/
  *C_prior = C_prior_init;

  /*** run the algorithm ***/

  iters_run =
    vectquant(Mu_mu, S2_mu, Mu_bar, Mu_tilde, 
              Alpha_ksi, Beta_ksi, Ksi_alpha, Ksi_beta, 
              C_prior, logC, F, W, U_p, U_hat, U_hat_out,
              datalen, dim1, dim2, data1, data2, Ns, c, iters, verbose, implicit_noisevar);

  /* optional outputs: only touch plhs slots the caller asked for */
  if (nlhs > 1)
    plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
  if (nlhs > 2) {
    plhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
    *mxGetPr(plhs[2]) = (double)iters_run;
  }

  return;
}
 
/************************************************************/




