#include "Group.h"
#include "SamplingUtils.h"
#include "utils.h"
#include <cmath>
#include <iostream>
#include <utility>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_sf_gamma.h>

using namespace std;

// Builds a group over the clients in *includedClients for a single feature.
// For every client c, Y[getIndex(feature, c, nFeatures)] is stored in
// clientToExpr_. Cannot be used to build a group with 0 clients: in that case
// an error is written to cerr and no values are stored.
Group::Group(const set<int> * includedClients, int feature, int nFeatures, double* Y, double k0, double a, double b)
: includedClients_(*includedClients), feature_(feature), nFeatures_(nFeatures), k0_(k0), a_(a), b_(b)
{
	if(includedClients->empty()){
		cerr << "Group: Cannot build group with 0 clients." << endl;
	}else{
		set<int>::const_iterator client = includedClients->begin();
		while(client != includedClients->end()){
			clientToExpr_[*client] = Y[getIndex(feature, *client, nFeatures)];
			++client;
		}
	}
	// Mark both derived statistics as needing recomputation.
	recomputeMeanIncluded_ = true;
	recomputeQuadFormIncluded_ = true;
}

// Copy-like constructor taking a pointer to the source group. Copies the
// client set, the per-client values, feature_, nFeatures_, k0_, a_ and b_.
// The recompute flags are not copied; both are set to true.
Group::Group(Group *group):
includedClients_(group->includedClients_), clientToExpr_(group->clientToExpr_),
feature_(group->feature_), nFeatures_(group->nFeatures_), k0_(group->k0_), a_(group->a_), b_(group->b_)
{
	recomputeMeanIncluded_ = true;
	recomputeQuadFormIncluded_ = true;
}

// Destructor: nothing to release explicitly here.
Group::~Group()
{
}

// Number of clients currently associated with the group
// (the size of includedClients_, converted to int).
int Group::nClients()
{
	const int count = static_cast<int>(includedClients_.size());
	return count;
}

// Adds a client and its value to the group and marks the mean and quadratic
// form as needing recomputation. If the client is already in the group,
// nothing changes (its stored value is left as it was).
void Group::addClient(int client, double value)
{
	if(includedClients_.count(client) != 0)
		return; // Already a member: leave everything untouched.
	includedClients_.insert(client);
	clientToExpr_.insert(pair<int, double>(client, value));
	recomputeMeanIncluded_ = true;
	recomputeQuadFormIncluded_ = true;
}

// Removes a client from the group.
// Returns true if the client was in the group, false otherwise (in which case
// nothing is changed).
// NOTE: when the last client is removed this executes `delete this`; the
// caller must not use the group again after such a call returns true.
bool Group::removeClient(int client)
{
	if(includedClients_.erase(client) == 0)
		return false; // Client was not associated with the group.
	clientToExpr_.erase(client);
	recomputeMeanIncluded_ = true;
	recomputeQuadFormIncluded_ = true;
	if(includedClients_.empty())
		delete this; // No member may be accessed past this point.
	return true;
}

// mean(Y): the average of the stored values of the included clients.
// Returns 0.0 when the group has no clients.
//
// The value is recomputed on every call. The previous implementation cached
// the result in a function-local `static`, which is shared by ALL Group
// objects: after one group cached its mean, another group whose flag was
// already cleared would get that other group's value back. A per-instance
// cache would need a new data member in the class declaration.
double Group::meanIncluded()
{
	double mean = 0.0;
	if(!includedClients_.empty()){
		double sum = 0.0;
		for(set<int>::iterator cit = includedClients_.begin(); cit != includedClients_.end(); ++cit)
			sum += clientToExpr_[*cit];
		mean = sum / nClients();
	}
	// Still cleared, as before, in case other code inspects the flag.
	recomputeMeanIncluded_ = false;
	return mean;
}

// Y' \Sigma^{-1} Y . Works even when the number of included clients is 0.
double Group::quadFormIncluded()
{
	static double res = 0.0;
	double out = 0.0, n = (double) nClients();
	set<int>::iterator i,j; // Client iterators.
	if(recomputeQuadFormIncluded_){
		for(i = includedClients_.begin(); i != includedClients_.end(); i++){
			out += pow(clientToExpr_[*i], 2.0);
			for(j = includedClients_.begin(); j != includedClients_.end(); j++)				
				out -= clientToExpr_[*i] * clientToExpr_[*j] / (1/k0_ + n);
		}
		res = out;
		recomputeQuadFormIncluded_ = false;
	}
	return res;
}

// Log joint probability of the values of the included clients, with n the
// number of clients and Q = quadFormIncluded():
//     a log b + lnGamma(a + n/2) - lnGamma(a)
//   - [ n/2 log(2 pi) + 1/2 log(1 + n k0) ]
//   - (n/2 + a) log(b + Q/2)
// When the number of included clients is 0, returns 0.0.
double Group::logJointProb()
{
	if(nClients() == 0) // No included clients.
		return 0.0;
	const double n = (double) nClients();
	const double halfN = n/2;
	const double gammaTerm = a_ * log(b_) + gsl_sf_lngamma(a_ + halfN) - gsl_sf_lngamma(a_);
	const double normTerm = halfN * log(2*M_PI) + 0.5 * log(1 + n * k0_);
	const double dataTerm = (-halfN - a_) * log(b_ + quadFormIncluded() / 2);
	return gammaTerm - normTerm + dataTerm;
}

// Works even when the number of included clients is 0.
double Group::logPredProb(double value)
{
	double out = 0.0, n = (double) nClients();	
	double r = a_ + n/2, s = b_ + quadFormIncluded() / 2;
	double m = n / (1 / k0_ + n) * meanIncluded();
	double lambda = (n + 1 / k0_ + 1) / (n + 1 / k0_);
	out += r * log(s) + gsl_sf_lngamma(r + 0.5) - gsl_sf_lngamma(r);
	out -= 0.5 * (log(2 * M_PI) + log(lambda));
	out += (-r - 0.5) * log(s + pow(value - m, 2.0) / (2 * lambda));
	return out;
}
