#include "DataSet.h"
#include "Group.h"
#include "Tree.h"
#include "utils.h"
#include "SamplingUtils.h"
#include <ctime>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>

using namespace std;

// Builds a data set: stores the sampler/model settings, allocates the expression matrix
// (nFeatures * nClients doubles) and fills it, together with the client/feature labels,
// from the three input files via parseData().
//
// The sampler settings (burnIn, nIter, nIterGamma, nFeatureScans, clientInitType) and the
// hyper-parameters (k0, alpha, beta, aGamma, bGamma, aV, bV) are only stored here; they are
// forwarded to the Tree constructor when the sampler is run.
// 'resultsDir' is used verbatim as the prefix of every output file name.
DataSet::DataSet(int nClients, int nFeatures, int treeDepth, int burnIn, int nIter, int nIterGamma, ClientInitType clientInitType,
				double k0, double alpha, double beta, double aGamma, double bGamma, double aV, double bV,
				string exprFn, string clientsFn, string featuresFn, string resultsDir, int nFeatureScans)
:nClients_(nClients), nFeatures_(nFeatures), treeDepth_(treeDepth), burnIn_(burnIn), nIter_(nIter), nIterGamma_(nIterGamma), 
nFeatureScans_(nFeatureScans), clientInitType_(clientInitType), 
k0_(k0), alpha_(alpha), beta_(beta), aGamma_(aGamma), bGamma_(bGamma), aV_(aV), bV_(bV), resultsDir_(resultsDir)
{
	Y_ = new double[nFeatures * nClients]; // Expression matrix.
	try{
		parseData(exprFn, clientsFn, featuresFn); // Import expression data.
	}catch(...){
		// The destructor does not run for an object whose constructor throws,
		// so the matrix has to be released here before propagating the error.
		delete [] Y_;
		throw;
	}
}

// Frees the expression matrix that the constructor allocated with new[].
DataSet::~DataSet()
{
	delete[] Y_;
}

// Parse expression data and client/feature labels.
//
// exprFn     : whitespace-separated numbers, read feature by feature (outer loop) and
//              client by client (inner loop); exactly nFeatures_ * nClients_ values are consumed.
// clientsFn  : nClients_ whitespace-separated client labels.
// featuresFn : nFeatures_ whitespace-separated feature labels.
//
// Throws runtime_error if a file cannot be opened or ends (or fails to parse) before the
// expected number of entries has been read, so that the sampler never runs on a partially
// filled matrix or label table.
void DataSet::parseData(string exprFn, string clientsFn, string featuresFn)
{	
	int i, j;
	// Import expression data.
	ifstream exprFile(exprFn.c_str(), ios::in);
	if(!exprFile)
		throw runtime_error("DataSet::parseData: cannot open expression file '" + exprFn + "'");
	for(i = 0; i < nFeatures_; i++)
		for(j = 0; j < nClients_; j++)
			if(!(exprFile >> Y_[getIndex(i, j, nFeatures_)]))
				throw runtime_error("DataSet::parseData: too few or malformed values in expression file '" + exprFn + "'");
	// Parse client labels.
	ifstream clientFile(clientsFn.c_str(), ios::in);
	if(!clientFile)
		throw runtime_error("DataSet::parseData: cannot open client label file '" + clientsFn + "'");
	for(i = 0; i < nClients_; i++)
		if(!(clientFile >> indexToClientLabel_[i]))
			throw runtime_error("DataSet::parseData: too few labels in client label file '" + clientsFn + "'");
	// Parse feature labels.
	ifstream featureFile(featuresFn.c_str(), ios::in);
	if(!featureFile)
		throw runtime_error("DataSet::parseData: cannot open feature label file '" + featuresFn + "'");
	for(i = 0; i < nFeatures_; i++)
		if(!(featureFile >> indexToFeatureLabel_[i]))
			throw runtime_error("DataSet::parseData: too few labels in feature label file '" + featuresFn + "'");
}

void DataSet::writeResults(double *logProbGamma, double *logProbNcrp, double *logProbFeatures, double *logProbExpr, int *numNodes, 
		double *gammaSamples, double *accRateClients, double *accRateFeatures, Tree *mode)
{
	int i;
	ostringstream stem, probFn, gammaFn;
	stem << resultsDir_; // Stem for results filenames.
	// Write matrix of gamma values. Each column corresponds to a sequence of gamma values sampled during a single Gibbs sampler run, 
	// in the context of an "outer" Gibbs sampler iteration.
	gammaFn << stem.str() + "gamma.txt";	
	printMatrix<double>(gammaSamples, nIterGamma_, nIter_, gammaFn.str());
	// Write log-probability information as a tab-delimited file.
	probFn << stem.str() << "log_prob.txt";
	ofstream probFile(probFn.str().c_str());
	probFile << "Iter" << "\t";
	probFile << "log-prob. gamma" << "\t";	
	probFile << "log-prob. Ncrp" << "\t"; 
	probFile << "log-prob. features" << "\t";
	probFile << "log-prob. expr." << "\t";	
	probFile << "Number of nodes" << "\t";		
	probFile << "Acc. rate for clients" << "\t";
	probFile << "Acc. rate for features" << endl;
	for(i = 0; i < nIter_; i++){
		probFile << i << "\t";
		probFile << logProbGamma[i] << "\t";
		probFile << logProbNcrp[i] << "\t" << logProbFeatures[i] << "\t" << logProbExpr[i] << "\t";
		probFile << numNodes[i] << "\t";
		probFile << accRateClients[i] << "\t" << accRateFeatures[i] << endl;
	}
	// Write mode-related files.
	mode->writeToFile(stem.str(), &indexToClientLabel_, &indexToFeatureLabel_);
}

// Runs burnIn_ + nIter_ Gibbs scans over a freshly initialized tree and records a trace of each scan.
//
// Every scan samples each client once and each feature nFeatureScans_ times, in a random
// interleaved order, and then samples the concentration parameter gamma.
//
// Output arrays (caller-allocated, filled for every scan i in [0, burnIn_ + nIter_)):
//   logProbGamma, logProbNcrp, logProbFeatures, logProbExpr - log-probability terms after scan i.
//   numNodes                                                - number of tree nodes after scan i.
//   accRateClients, accRateFeatures                         - fraction of sampled client/feature variables that changed.
//   gammaSamples - column-major matrix with nIterGamma_ rows; column i is handed to Tree::sampleGamma() in scan i.
//
// Returns a heap-allocated copy of the tree with the highest total log-probability observed
// after burn-in, or a copy of the initial tree if no post-burn-in state exceeds the initial
// log-probability. The caller owns the returned tree and must delete it.
Tree * DataSet::runSamplerAux(double *logProbGamma, double *logProbNcrp, double *logProbFeatures, double *logProbExpr, int *numNodes, 
		double *gammaSamples, double *accRateClients, double *accRateFeatures)
{
	bool varHasChanged; // Used for computing acceptance rates.
	const int nFeatureVars = nFeatures_ * nFeatureScans_; // Feature variables sampled per scan.
	const int nVars = nClients_ + nFeatureVars; // All variables sampled per scan.
	// The order in which variables are sampled varies randomly across Gibbs scans. The variables below store that order.
	int *clients = new int[nClients_];
	int *features = new int[nFeatureVars];
	int *varType = new int[nVars]; // Type of variable (client or feature) being sampled at a given time.
	int clientVarInd, featureVarInd, currClient, currFeature;
	double maxLogProb, lpGamma, lpNcrp, lpFeatures, lpExpr;
	// Initialize tree.
	Tree *tree = new Tree(nClients_, nFeatures_, treeDepth_, Y_, clientInitType_, nIterGamma_, k0_, alpha_, beta_, aGamma_, bGamma_, aV_, bV_);
	Tree *mode = new Tree(tree); // Initial mode as copy of initial tree.
	lpGamma = tree->logProbGamma();
	lpNcrp = tree->logProbNcrp(tree->root_);
	lpFeatures = tree->logProbFeatures();
	lpExpr = tree->logProbExpr(Y_);
	maxLogProb = lpGamma + lpNcrp + lpFeatures + lpExpr;
	int i, j;
	// Initialize variable sampling order. This order is permuted at the start of every Gibbs scan.
	for(i = 0; i < nClients_; i++){
		clients[i] = i; // The i-th client variable to be sampled.
		varType[i] = VAR_CLIENT;
	}
	for(i = 0; i < nFeatureScans_; i++)
		for(j = 0; j < nFeatures_; j++){
			features[i * nFeatures_ + j] = j;
			// Feature tags go after the nClients_ client tags. Without the offset they would
			// overwrite the client tags and leave the tail of 'varType' uninitialized, so the
			// numbers of client/feature tags would no longer match the sizes of 'clients' and 'features'.
			varType[nClients_ + i * nFeatures_ + j] = VAR_FEATURE;
		}
	// Actual sampling procedure.
	for(i = 0; i < burnIn_ + nIter_; i++){
		clientVarInd = featureVarInd = 0; // Auxiliary variables used for accessing the current client/feature.
		accRateClients[i] = accRateFeatures[i] = 0.0;
		// Random scan order.
		gsl_ran_shuffle(SamplingUtils::r_, clients, nClients_, sizeof(int)); // Clients sampled in random order.
		gsl_ran_shuffle(SamplingUtils::r_, features, nFeatureVars, sizeof(int)); // Features sampled in random order.
		gsl_ran_shuffle(SamplingUtils::r_, varType, nVars, sizeof(int)); // Var. type choice is random.
		for(j = 0; j < nVars; j++){ // For each variable in the random scan...
			if(varType[j] == VAR_CLIENT){
				currClient = clients[clientVarInd];
				clientVarInd++;
				varHasChanged = tree->sampleClient(currClient, Y_);
				if(varHasChanged)
					accRateClients[i]++;
			}else{ // varType[j] == VAR_FEATURE.
				currFeature = features[featureVarInd];
				featureVarInd++;
				varHasChanged = tree->sampleFeature(currFeature, Y_);
				if(varHasChanged)
					accRateFeatures[i]++;
			}
		}
		accRateClients[i] /= nClients_;
		accRateFeatures[i] /= nFeatureVars;
		// At the end of the scan, sample the concentration parameter.
		// Provide current column of 'gammaSamples' as input ('gammaSamples' is in column-major order).
		tree->sampleGamma(gammaSamples + i * nIterGamma_);
		// Get current log-probabilities.
		logProbGamma[i] = tree->logProbGamma();
		logProbNcrp[i] = tree->logProbNcrp(tree->root_);
		logProbFeatures[i] = tree->logProbFeatures();
		logProbExpr[i] = tree->logProbExpr(Y_);
		numNodes[i] = tree->nNodes();
		const double currLogProb = logProbGamma[i] + logProbNcrp[i] + logProbFeatures[i] + logProbExpr[i];
		// If we're past the burn-in stage and a new mode has been found.
		// Scans 0 .. burnIn_ - 1 are burn-in, so scan burnIn_ is the first eligible one (hence '>=').
		if(i >= burnIn_ && currLogProb > maxLogProb){
			maxLogProb = currLogProb;
			delete mode;
			mode = new Tree(tree); // Current mode is a copy of current tree.
		}
		// Print information on current iteration.
		cout << "----------------------------" << "Iteration " << i << "----------------------------" << endl;
		cout << "Log-prob. gamma: " << logProbGamma[i] << endl;
		cout << "Log-prob. Ncrp: " << logProbNcrp[i] << endl;
		cout << "Log-prob. features: " << logProbFeatures[i] << endl;
		cout << "Log-prob. expr: " << logProbExpr[i] << endl;
		cout << "Current gamma: " << tree->gamma_ << endl;
		cout << "Acceptance rate for clients: " << accRateClients[i] << endl;
		cout << "Acceptance rate for features: " << accRateFeatures[i] << endl;
		cout << "Number of nodes: " << numNodes[i] << endl;
		cout << "----------------------------" << endl;
	}
	delete tree;
	delete [] clients;
	delete [] features;
	delete [] varType;
	return mode;
}

void DataSet::runSampler()
{
	int totalNSamples = burnIn_ + nIter_;	 
	double *logProbGamma = new double[totalNSamples];	
	double *logProbNcrp = new double[totalNSamples];
	double *logProbFeatures = new double[totalNSamples];
	double *logProbExpr = new double[totalNSamples];			
	int *numNodes = new int[totalNSamples];
	double *gammaSamples = new double[totalNSamples * nIterGamma_];	 
	double *accRateClients = new double[totalNSamples];
	double *accRateFeatures = new double[totalNSamples];
	Tree *mode = runSamplerAux(logProbGamma, logProbNcrp, logProbFeatures, logProbExpr, numNodes, gammaSamples, accRateClients, accRateFeatures);							
	writeResults(logProbGamma, logProbNcrp, logProbFeatures, logProbExpr, numNodes, gammaSamples, accRateClients, accRateFeatures, mode);
	delete mode; // The function 'runSamplerAux' allocates memory for a Tree object, which must be manually freed.	
	delete [] logProbGamma;	
	delete [] logProbNcrp;
	delete [] logProbFeatures;
	delete [] logProbExpr;	
	delete [] numNodes;
	delete [] gammaSamples;	
	delete [] accRateClients;
	delete [] accRateFeatures;
}
