#include "Tree.h"
#include "FeatureNode.h"
#include "Group.h"
#include "SamplingUtils.h"
#include "utils.h"
#include <iostream>
#include <fstream>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_sf_gamma.h>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <vector>
#include <cmath>
#include <ctime>

using namespace std;

// Default constructor: an empty tree with no nodes. Every other member is left as-is; only root_ is
// set, so that the destructor can recognize the empty tree.
Tree::Tree()
: root_(NULL)
{
}
						  
// Build a tree of depth 'treeDepth' holding clients 0..nClients-1 over 'nFeatures' features.
// 'expr' is the expression matrix (indexed through getIndex(feature, client, nFeatures)).
// 'clientInitType' selects how clients are initially distributed (see Tree::addClients).
// A depth below 2 is rejected with an error message and leaves the tree empty.
Tree::Tree(int nClients, int nFeatures, int treeDepth, double* expr, ClientInitType clientInitType, int nIterGamma,  
		double k0, double alpha, double beta, double aGamma, double bGamma, double aV, double bV)
:nClients_(nClients), nFeatures_(nFeatures), treeDepth_(treeDepth), nIterGamma_(nIterGamma), k0_(k0), alpha_(alpha), beta_(beta), 
aGamma_(aGamma), bGamma_(bGamma), aV_(aV), bV_(bV)
{	
	// Start from an empty tree. If construction bails out below, root_ would otherwise stay uninitialized
	// and the destructor (which tests root_ against NULL) would read an indeterminate pointer.
	root_ = NULL;
	if(treeDepth_ < 2)
		cerr << "Tree: Invalid tree depth." << endl;
	else		
		addClients(expr, clientInitType); // Build tree; randomly instantiate concentration parameter.	
}

// Deep-copy constructor. Every FeatureNode of 'tree' is duplicated through FeatureNode's copy
// constructor, and all node-keyed bookkeeping maps are rebuilt so that they refer to the new nodes.
// This copy constructor does NOT copy the 'nodeToTempPath_' structure.
Tree::Tree(Tree* tree):
nClients_(tree->nClients_), nFeatures_(tree->nFeatures_), treeDepth_(tree->treeDepth_), nIterGamma_(tree->nIterGamma_),  
k0_(tree->k0_), alpha_(tree->alpha_), beta_(tree->beta_), aGamma_(tree->aGamma_), bGamma_(tree->bGamma_), 
aV_(tree->aV_), bV_(tree->bV_), gamma_(tree->gamma_)
{
	FeatureNode *node, *newNode;
	set<FeatureNode*>::iterator nit; // Node iterator.
	map<int, FeatureNode*>::iterator cit; // Iterator over the source's client -> leaf map.
	map<FeatureNode*, FeatureNode*> oldToNew;
	queue<FeatureNode*> q;
	root_ = NULL;
	if(tree->root_ == NULL) // Source tree is empty (e.g. default-constructed): nothing to copy.
		return;
	oldToNew[NULL] = NULL; // Used for the "parent" of the root (= null).
	// Breadth-first traversal: a parent is always copied before its children, so oldToNew already
	// knows the new parent when a child is processed.
	q.push(tree->root_);
	while(!q.empty()){
		node = q.front();
		q.pop();
		newNode = new FeatureNode(node); // Copy node.
		oldToNew[node] = newNode;
		nodeToParent_[newNode] = oldToNew[tree->nodeToParent_[node]];
		nodeToChildren_[newNode] = new set<FeatureNode*>;
		nodeIsLeaf_[newNode] = tree->nodeIsLeaf_[node];
		nodeToLevel_[newNode] = tree->nodeToLevel_[node];
		if(nodeToParent_[newNode] != NULL){ // If node != root
			// Edge length doesn't apply to root and its parent (null).
			edgeLength_[pair<FeatureNode*,FeatureNode*>(nodeToParent_[newNode], newNode)] = 
				tree->edgeLength_[pair<FeatureNode*, FeatureNode*>(tree->nodeToParent_[node], node)];		
			// Don't "connect" NULL and the root.
			nodeToChildren_[nodeToParent_[newNode]]->insert(newNode);
		}
		for(nit = tree->nodeToChildren_[node]->begin(); nit != tree->nodeToChildren_[node]->end(); nit++)
			q.push(*nit);
	}
	root_ = oldToNew[tree->root_];
	// Walk the actual entries of the source map. Client keys need not be 0..size()-1 (e.g. after
	// removeClient), and indexing the source with operator[] would silently insert NULL leaves into it.
	for(cit = tree->clientToLeafNode_.begin(); cit != tree->clientToLeafNode_.end(); cit++)
		clientToLeafNode_[cit->first] = oldToNew[cit->second];
}

// Recursive helper for the destructor. Frees the subtree rooted at 'node' in post-order: all
// descendants first, then the node's child set, then the node itself.
// The node-keyed maps are not cleaned up here.
void Tree::destructorAux(FeatureNode* node)
{
	set<FeatureNode*>* children = nodeToChildren_[node];
	for(set<FeatureNode*>::iterator child = children->begin(); child != children->end(); ++child)
		destructorAux(*child);
	delete children;
	delete node;
}

// Frees every node reachable from root_, together with each node's child set.
// Does not deal with the 'nodeToTempPath_' data structure.
Tree::~Tree()
{
	if(root_ == NULL)
		return; // Empty tree: nothing to free.
	destructorAux(root_);
}

////////////////////////////////////
//
// Tree initialization and editing.
//
////////////////////////////////////

// Auxiliary function to Tree::onePerLeafAddClients and Tree::ncrpAddClients. Initializes the root.
void Tree::initRoot(double* expr)
{
	int i;
	set<int> includedClients, newFeatures, nullFeatures;
	for(i = 0; i < nClients_; i++)
		includedClients.insert(i);		
	for(i = 0; i < nFeatures_; i++)
		nullFeatures.insert(i);		
	root_ = new FeatureNode(&includedClients, &newFeatures, &nullFeatures, expr, nFeatures_, k0_, aV_, bV_);
	nodeToParent_[root_] = NULL;
	nodeToChildren_[root_] = new set<FeatureNode*>;	
	nodeIsLeaf_[root_] = false;	
	nodeToLevel_[root_] = 0;
}

// Each client goes to a different leaf node.
void Tree::onePerLeafAddClients(double* expr)
{
	int i;
	initRoot(expr);
	for(i = 0; i < nClients_; i++){
		addTemporaryPath(i, root_, expr);
		commitTemporaryPath(i, root_, expr);
	}
}

// Partition clients according to the nCRP.
void Tree::ncrpAddClients(double* expr)
{	
	int i,j,k, nChildren, outcome;
	double *prob;
	FeatureNode * parent, * child;
	set<int>::iterator fit; // Feature iterator.
	set<FeatureNode*>::iterator nit; // Node iterator.
	map<int,FeatureNode*> indexToChildNode;		
	initRoot(expr);
	for(i = 0; i < nClients_; i++){
		parent = root_;
		for(j = 1; j < treeDepth_; j++){ // Sample a node at the j-th level, given the (j-1)-th level parent.
			indexToChildNode.clear();
			nChildren = nodeToChildren_[parent]->size();		
			prob = new double[nChildren + 1];						
			for(k = 0, nit = nodeToChildren_[parent]->begin(); nit != nodeToChildren_[parent]->end(); nit++, k++){
				prob[k] = (*nit)->nClients();				
				indexToChildNode.insert(pair<int, FeatureNode*> (k, *nit));
			}
			prob[nChildren] = gamma_;
			outcome = SamplingUtils::discreteSample(nChildren + 1, prob); // Prob is unnormalized, discreteSample handles that.
			delete [] prob;
			if(outcome < nChildren){ // Choose an already existing node.
				child = indexToChildNode[outcome];	
				child->addClient(i, expr);				
				parent = child;
				if(nodeIsLeaf_[child])
					clientToLeafNode_[i] = child;
			}else{ // Create a new node.
				addTemporaryPath(i, parent, expr);
				commitTemporaryPath(i, parent, expr);
				break;				
			}
		}
	}
}

// Draw the concentration parameter gamma_ from its prior and distribute the clients according to
// 'clientInitType' (one client per leaf, or nCRP). Then resample gamma_ from its posterior given
// the resulting tree. An unknown 'clientInitType' is reported on cerr and leaves the tree unbuilt.
void Tree::addClients(double* expr, ClientInitType clientInitType)
{
	gamma_ = gsl_ran_gamma(SamplingUtils::r_, aGamma_, 1/bGamma_); // Sample gamma_ from its prior density (GSL takes scale = 1/rate).
	switch(clientInitType){	
		case INIT_ONE_PER_LEAF:			
			onePerLeafAddClients(expr);
			break;
		case INIT_NCRP:			
			ncrpAddClients(expr);
			break;
		default:
			cerr << "addClients: Invalid clientInitType: " << clientInitType << endl;
			// No root was created, so there is no tree to condition gamma_ on. Continuing into
			// sampleGamma would dereference root_.
			return;
	}
	// Adjust gamma_ to the newly formed tree, by sampling from its posterior density.	
	sampleGamma(NULL); // The NULL input parameter means that no intermediary gamma_ values will be recorded.
}

// Auxiliary function for Tree::removeClient. Recursive. Bottom-up exploration. 
void Tree::removeClientAux(int client, FeatureNode *node)
{		
	bool clientExisted;
	int nOldClientsInNode = node->nClients();				
	set<int>::iterator fit; // Feature iterator.
	FeatureNode *parent = nodeToParent_[node];		
	clientExisted = node->removeClient(client);
	if(nOldClientsInNode == 1 && clientExisted){ // The node was deleted during the call to 'node->removeClient'.		
		nodeToParent_.erase(node);
		nodeToChildren_[parent]->erase(node);
		delete nodeToChildren_[node]; // Assumes all child nodes have already been deleted.
		nodeToChildren_.erase(node);
		edgeLength_.erase(pair<FeatureNode*, FeatureNode*> (parent, node));
		nodeIsLeaf_.erase(node);
		nodeToLevel_.erase(node);
		nodeToTempPath_.erase(node);
	}
	if(parent != NULL) // Move upwards until client has been removed from the root.
		removeClientAux(client, parent);
}

// Remove a client from the tree. This function does *not* decrement the value of 'nClients'.
void Tree::removeClient(int client)
{		
	removeClientAux(client, clientToLeafNode_[client]);	
	clientToLeafNode_.erase(client);	
}

// Similar to 'addTemporaryPath', with the difference that the first features in the new path created here are a 
// copy of the features in the old path.
//
// 'oldPath' is the client's path as returned by clientPath() before the client was removed: root on top, leaf at
// the bottom. It is received by value, so popping it here does not affect the caller's copy.
// When 'node' lies on that path, each old-path node directly below 'node' that still exists in the tree (i.e. still
// has an entry in 'nodeToLevel_') is replicated, in order. The replica is a new node containing only 'client', with
// a copy of that old node's 'newFeatures_'/'nullFeatures_'. Any levels left after that (or all levels, when 'node'
// is not on the old path) are filled by drawing features from the prior, exactly as in 'addTemporaryPath'.
//
// The new nodes are queued in 'nodeToTempPath_[node]'. They are released by 'commitTemporaryPath' or
// 'cancelTemporaryPath'.
// Returns the log-prob. of the client's expression data over the groups of the new nodes, plus the null features
// of the new leaf. Returns 0.0 (and allocates nothing) when 'node' is a leaf.
double Tree::addReplicatedTemporaryPath(int client, FeatureNode* node, double *expr, stack<FeatureNode*> oldPath)
{
	int currLevel;
	double out = 0.0, edgeLength;
	set<int> includedClients, newFeatures, nullFeatures;
	set<int>::iterator fit; // Feature iterator.
	FeatureNode *parent, *nodeInOldPath;	
	if(nodeIsLeaf_[node]) // When the node is a leaf, this function has nothing to do.		
		return 0.0;	
	includedClients.insert(client);
	nodeToTempPath_[node] = new queue<FeatureNode*>;		
	parent = node; // Deepest node of the path built so far.
	currLevel = nodeToLevel_[node] + 1; // Level for the child nodes of 'node'.
	// Scan the old path from the root downwards, looking for 'node'.
	while(!oldPath.empty()){
		nodeInOldPath = oldPath.top();		
		oldPath.pop();
		if(nodeInOldPath == node){ 
			// 'node' is in the path from the root to the deepest non-singleton node in the old path.
			// Notice the following: 'nodeInOldPath' cannot be a leaf node, otherwise 'node' would also be a leaf node 
			// and so the function would not reach this point. This implies that there is at least one more node in 'oldPath'.			
			while(!oldPath.empty() && nodeToLevel_.find(oldPath.top()) != nodeToLevel_.end()){
				// Replicate features from nodes in the old path until reaching the last non-singleton node,
				// i.e. stop at the first old-path node that no longer exists in the tree.
				nodeInOldPath = oldPath.top();				
				oldPath.pop();			
				// Replicate features.
				parent = new FeatureNode(&includedClients, &(nodeInOldPath->newFeatures_), &(nodeInOldPath->nullFeatures_),
										 expr, nFeatures_, k0_, aV_, bV_);
				nodeToTempPath_[node]->push(parent); // Push to last place of queue.
				// Compute the log-prob. of the expression data for the current client & new groups.
				for(fit = parent->newFeatures_.begin(); fit != parent->newFeatures_.end(); fit++)
					out += parent->newFeatureToGroup_[*fit]->logJointProb(); // Notice that this group has a single client.
				currLevel++;
			}
			// The remaining (deleted) old-path nodes are simply skipped by the outer loop.
		}
	}
	// For the nodes further down, generate feature values from their prior distribution.	
	while(currLevel < treeDepth_){
		edgeLength = gsl_ran_beta(SamplingUtils::r_, alpha_, beta_); // Generate edge length.
		// Generate feature values: each feature still off at 'parent' switches on with probability 'edgeLength'.
		newFeatures.clear();
		nullFeatures = parent->nullFeatures_;
		for(fit = nullFeatures.begin(); fit != nullFeatures.end(); fit++)
			if(gsl_ran_bernoulli(SamplingUtils::r_, edgeLength) == 1)
				newFeatures.insert(*fit);
		for(fit = newFeatures.begin(); fit != newFeatures.end(); fit++)
			nullFeatures.erase(*fit); // Remove the new features from the set of null features.
		parent = new FeatureNode(&includedClients, &newFeatures, &nullFeatures, expr, nFeatures_, k0_, aV_, bV_);
		nodeToTempPath_[node]->push(parent); // Push to last place of queue.
		// Compute the log-prob. of the expression data for the current client & new groups.
		for(fit = parent->newFeatures_.begin(); fit != parent->newFeatures_.end(); fit++)
			out += parent->newFeatureToGroup_[*fit]->logJointProb(); // Notice that this group has a single client.
		currLevel++;
	}
	// Compute the log-prob. of the expression data for the current client & null features.		
	for(fit = parent->nullFeatures_.begin(); fit != parent->nullFeatures_.end(); fit++) // 'parent' is now a leaf node.
		out += stdNormLogProb(expr[getIndex(*fit, client, nFeatures_)]);			
	return out;
}

// Creates a set of nodes representing a partial path starting from the input variable 'node',
// but does not fill out any of the structures in the tree apart from nodeToTempPath.
// In order to fill out those structures, the function 'commitTemporaryPath' must be called.
//
// This function allocates memory, which is automatically freed when calling either 
// 'commitTemporaryPath' or 'cancelTemporaryPath'. 
//
// Returns the log-probability of the expression data for the client, restricted to 
// the features that are inactivated at 'node' (but that may become activated in the 
// newly created path).
double Tree::addTemporaryPath(int client, FeatureNode *node, double *expr)
{
	int currLevel;
	double out = 0.0, edgeLength;
	set<int> includedClients, newFeatures, nullFeatures;
	set<int>::iterator fit; // Feature iterator.
	FeatureNode *parent;	
	if(nodeIsLeaf_[node]) // When the node is a leaf, this function has nothing to do.		
		return 0.0;
	includedClients.insert(client);
	nodeToTempPath_[node] = new queue<FeatureNode*>;
	parent = node;
	currLevel = nodeToLevel_[node] + 1; // Level for the child nodes of 'node'.
	while(currLevel < treeDepth_){		
		edgeLength = gsl_ran_beta(SamplingUtils::r_, alpha_, beta_); // Generate edge length.
		// Generate feature values.
		newFeatures.clear();
		nullFeatures = parent->nullFeatures_;
		for(fit = nullFeatures.begin(); fit != nullFeatures.end(); fit++)
			if(gsl_ran_bernoulli(SamplingUtils::r_, edgeLength) == 1)
				newFeatures.insert(*fit);
		for(fit = newFeatures.begin(); fit != newFeatures.end(); fit++)
			nullFeatures.erase(*fit); // Remove the new features from the set of null features.
		parent = new FeatureNode(&includedClients, &newFeatures, &nullFeatures, expr, nFeatures_, k0_, aV_, bV_);
		nodeToTempPath_[node]->push(parent); // Push to last place of queue.
		// Compute the log-prob. of the expression data for the current client & **new** groups.
		for(fit = parent->newFeatures_.begin(); fit != parent->newFeatures_.end(); fit++)
			out += parent->newFeatureToGroup_[*fit]->logJointProb(); // Notice that this group has a single client.
		currLevel++;
	}
	// Compute the log-prob. of the expression data for the current client & **null** features.		
	for(fit = parent->nullFeatures_.begin(); fit != parent->nullFeatures_.end(); fit++) // 'parent' is now a leaf node.
		out += stdNormLogProb(expr[getIndex(*fit, client, nFeatures_)]);			
	return out;
}

// Free one of the temporary path structures allocated when sampling a client path: every queued node,
// then the queue itself, then its entry in 'nodeToTempPath_'.
// Does nothing when 'node' is a leaf, or when there is no pending path for it, e.g. because it has
// already been committed (committing removes the entry from 'nodeToTempPath_').
void Tree::cancelTemporaryPath(FeatureNode *node)
{
	if(nodeIsLeaf_[node])
		return;
	if(nodeToTempPath_.find(node) == nodeToTempPath_.end())
		return;
	queue<FeatureNode*>* pending = nodeToTempPath_[node];
	for(; !pending->empty(); pending->pop())
		delete pending->front();
	delete pending;
	nodeToTempPath_.erase(node);
}

// Insert a temporary path in the tree data structures. This function does *not* increment nClients_.
void Tree::commitTemporaryPath(int client, FeatureNode *node, double *expr)
{
	int currLevel = nodeToLevel_[node] + 1;
	set<int> includedClients;	
	set<int>::iterator fit; // Feature iterator.
	FeatureNode *parent = node, *currNode;		
	includedClients.insert(client);
	if(!nodeIsLeaf_[node]){				
		while(!nodeToTempPath_[node]->empty()){
			currNode = nodeToTempPath_[node]->front();
			nodeToTempPath_[node]->pop();
			nodeToParent_[currNode] = parent;
			nodeToChildren_[parent]->insert(currNode);
			nodeToChildren_[currNode] = new set<FeatureNode*>;
			nodeToLevel_[currNode] = currLevel;
			nodeIsLeaf_[currNode] = (currLevel == treeDepth_ - 1);
			if(nodeIsLeaf_[currNode])
				clientToLeafNode_[client] = currNode;
			parent = currNode;
			currLevel++;
		}
	}else // Node is leaf.
		clientToLeafNode_[client] = node;
	// Add client to the nodes that were not temporary.
	// If the client has already been added, the calls to 'addClient' have no effect.
	parent = node;
	while(parent != NULL){
		parent->addClient(client, expr);
		parent = nodeToParent_[parent];
	}	
	delete nodeToTempPath_[node];
	nodeToTempPath_.erase(node);	 
}

///////////////////////////////
//
// Sampling.
//
///////////////////////////////

// Sample from the posterior of gamma_ using an auxiliary variable scheme.
// If 'gammaValues' != NULL, store all intermediate gamma_ samples.
// This function assumes gamma_ contains a pre-existing, positive value. 
void Tree::sampleGamma(double *gammaValues)
{
	int i, sumS;
	int totalNChildNodes = nNodes() - 1; // All nodes but root. Amounts to the total number of clusters created by the Dir. processes.
	double sumLogW;
	FeatureNode* node;
	queue<FeatureNode*> q; 	
	map<FeatureNode*, int> nodeToNClients;
	computeNodeToNClients(&nodeToNClients, root_);	
	set<FeatureNode*>::iterator nit; // Node iterator.
	for(i = 0; i < nIterGamma_; i++){
		sumS = 0;
		sumLogW = 0.0;
		q.push(root_);
		while(!q.empty()){
			node = q.front();
			q.pop();
			if(nodeIsLeaf_[node])
				continue;
			sumLogW += log(gsl_ran_beta(SamplingUtils::r_, gamma_ + 1, nodeToNClients[node]));
			sumS += gsl_ran_bernoulli(SamplingUtils::r_, (nodeToNClients[node] / gamma_) / (1 + nodeToNClients[node] / gamma_));					
	 		for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
	 			q.push(*nit);
		}
		// After sampling all W's and S's, sample a new value for gamma_.
		gamma_ = gsl_ran_gamma(SamplingUtils::r_, aGamma_ + totalNChildNodes - sumS, 1 / (bGamma_ - sumLogW));		
		if(gammaValues != NULL) // If we are to store intermediary gamma_ values.
			gammaValues[i] = gamma_;
	}
}

// Auxiliary function for Tree::sampleClient. 
// Recursive; returns an integer identifier for the current node.
int Tree::samplePathKeptIn(int client, double *expr, double *prob, FeatureNode *node, int nodeIndex,   
							map<int, FeatureNode*>* indexToNode, double accLogPathProb, double accLogExprProb, stack<FeatureNode*> oldPath)
{
	int childNodeIndex;
	double childLogPathProb, nodeLogExprProb;
	set<FeatureNode*>::iterator nit; // Node iterator. 
	set<int>::iterator fit;	// Feature iterator.
	(*indexToNode)[nodeIndex] = node; // Associate node number to actual node.
	// Features that switch to 1 in the current node. 
	for(fit = node->newFeatures_.begin(); fit != node->newFeatures_.end(); fit++)
		accLogExprProb += node->newFeatureToGroup_[*fit]->logPredProb(expr[getIndex(*fit, client, nFeatures_)]);	
	if(nodeIsLeaf_[node]){ // Recursion base case.
		for(fit = node->nullFeatures_.begin(); fit != node->nullFeatures_.end(); fit++)
			accLogExprProb += stdNormLogProb(expr[getIndex(*fit, client, nFeatures_)]);
		prob[nodeIndex] = exp(accLogPathProb + accLogExprProb);
		return nodeIndex;
	}	
	nodeLogExprProb = accLogExprProb + addReplicatedTemporaryPath(client, node, expr, oldPath); 
	prob[nodeIndex] = exp(accLogPathProb + log(gamma_) - log(node->nClients() + gamma_) + nodeLogExprProb);	
	childNodeIndex = nodeIndex;
	for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++){
		childLogPathProb = accLogPathProb + log((*nit)->nClients()) - log(node->nClients() + gamma_);
		childNodeIndex = samplePathKeptIn(client, expr, prob, *nit, childNodeIndex + 1, indexToNode, childLogPathProb, accLogExprProb, oldPath);
	}
	return childNodeIndex;
}

// Sample a client allocation. Return true if the new path is different from the old one; return false otherwise. 
bool Tree::sampleClient(int client, double *expr)
{
	stack<FeatureNode*> oldPath = clientPath(client); // "Old" path for this client.
	double* prob; // Will store the probability of choosing each path.
	int numNodes, selectedIndex, i;
	map<int,FeatureNode*> indexToNode; // Map node number to actual node.
	removeClient(client); // Remove client from the tree, but only temporarily.
	numNodes = nNodes();
	prob = new double[numNodes];
	// The samplePathKeptIn function allocates a number of temporary paths, which must be freed afterwards.	
	samplePathKeptIn(client, expr, prob, root_, 0, &indexToNode, 0.0, 0.0, oldPath);
	selectedIndex = SamplingUtils::discreteSample(numNodes, prob);
	// The client is automatically re-inserted when committing the temporary path.
	commitTemporaryPath(client, indexToNode[selectedIndex], expr); 					
	for(i = 0; i < numNodes; i++) // Delete auxiliary structures.
		// Calling this function on a temporary path that has been committed has no effect.		
		cancelTemporaryPath(indexToNode[i]);
	delete [] prob;
	return pathHasChanged(oldPath, indexToNode[selectedIndex]);
}

// Sample the activation status of a feature at a given node. 
// Computes the log-ratio of maintaining vs. switching the feature value, when it is currently set to one.
// Precondition (as called from Tree::sampleFeature): 'feature' is in node->newFeatures_, i.e. it switches on at 'node';
// child nodes therefore hold it in neither 'newFeatures_' nor 'nullFeatures_'.
// Switching it off at a non-leaf node moves the activation one level down: every child switches it on instead,
// each with its own freshly built Group. Groups that end up unused are freed here.
// NOTE(review): assumes SamplingUtils::sampleFromLogRatio(x) returns 1 (keep the state in the numerator, here
// feature = 1) with probability 1 / (1 + exp(-x)) — confirm against SamplingUtils.
// Returns true if the new activation status is different than the previous one; returns false otherwise.
bool Tree::sampleFeatureAuxPositive(FeatureNode* node, int feature, double *expr){

	bool valueHasChanged;
	double logRatioPrior, logRatioExpr;
	int nPlus = node->nNewFeatures() - 1, nMinus = node->nNullFeatures(); // Counts excluding 'feature' itself.
	set<FeatureNode*>::iterator nit; // Node iterator.
	set<int>::iterator cit; // Client iterator.
	map<FeatureNode*,Group*> childToGroup; // Hypothetical per-child groups (owned here until transferred or deleted).
	// Log-ratio for P(feature at node = 1) vs. P(feature at node = 0).
	logRatioPrior = log(alpha_ + nPlus) - log(beta_ + nMinus);
	// Log-ratio for (feature at child nodes = 1 | feature at node = 1) vs. (feature at child nodes = 1 | feature at node = 0).
	for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++){
		nPlus = (*nit)->newFeatures_.size();
		nMinus = (*nit)->nullFeatures_.size();
		logRatioPrior -= log(alpha_ + nPlus) - log(alpha_ + beta_ + nPlus + nMinus);		
	}	
	// Log-ratio for expression data.
	logRatioExpr = node->newFeatureToGroup_[feature]->logJointProb();
	if(nodeIsLeaf_[node]){
		// At a leaf, "feature = 0" means the clients' values are scored under a standard normal.
		for(cit = node->includedClients_.begin(); cit != node->includedClients_.end(); cit++)
			logRatioExpr -= stdNormLogProb(expr[getIndex(feature, *cit, nFeatures_)]);		
		if(SamplingUtils::sampleFromLogRatio(logRatioPrior + logRatioExpr) == 0){ // Switch feature to 0.
			valueHasChanged = true;
			node->newFeatures_.erase(feature);
			delete node->newFeatureToGroup_[feature];
			node->newFeatureToGroup_.erase(feature);
			node->nullFeatures_.insert(feature);
		}else{
			valueHasChanged = false;
		}
	}else{ // Node is not a leaf.
		// "Feature = 0" here means each child starts its own group for the feature.
		for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++){
			childToGroup.insert(pair<FeatureNode*,Group*> 
				(*nit, new Group(&((*nit)->includedClients_), feature, nFeatures_, expr, k0_, aV_, bV_)));
			logRatioExpr -= childToGroup[*nit]->logJointProb();
		}
		if(SamplingUtils::sampleFromLogRatio(logRatioPrior + logRatioExpr) == 1){ // Keep feature at 1.
			valueHasChanged = false;
			// The hypothetical child groups are not needed.
			for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
				delete childToGroup[*nit];
		}else{ // Switch feature to 0.		
			valueHasChanged = true;
			node->newFeatures_.erase(feature);
			delete node->newFeatureToGroup_[feature];
			node->newFeatureToGroup_.erase(feature);
			node->nullFeatures_.insert(feature);
			for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++){
				// Make each child node point to its own group (ownership moves to the child).
				(*nit)->newFeatures_.insert(feature);				
				(*nit)->newFeatureToGroup_.insert(pair<int,Group*>(feature, childToGroup[*nit]));
				// Notice that before this function was called, for all child nodes the feature was contained in neither 
				// 'newFeatures_' nor 'nullFeatures_', because it had been set to one at the parent node.	
			}
		}
	}
	return valueHasChanged;
}

// Sample the activation status of a feature at a given node. 
// Computes the log-ratio of maintaining vs. switching the feature value, when it is currently set to zero.
// Precondition (as established by Tree::randomFeatureNode/sampleFeature): 'feature' is in node->nullFeatures_, and
// either 'node' is a leaf or every child of 'node' switches the feature on (has it in 'newFeatures_').
// Switching it on at a non-leaf node moves the activation one level up: the children's groups are deleted and a
// single group over all of the node's clients replaces them.
// NOTE(review): assumes SamplingUtils::sampleFromLogRatio(x) returns 1 (keep the state in the numerator, here
// feature = 0) with probability 1 / (1 + exp(-x)) — confirm against SamplingUtils.
// Returns true if the new activation status is different than the previous one; returns false otherwise. 
bool Tree::sampleFeatureAuxNegative(FeatureNode* node, int feature, double *expr){
	
	bool valueHasChanged;
	double logRatioPrior, logRatioExpr = 0.0;	
	int nPlus = node->newFeatures_.size(), nMinus = node->nullFeatures_.size() - 1; // Counts excluding 'feature' itself.
	Group* group = NULL;
	set<int>::iterator cit; // Client iterator.
	set<FeatureNode*>::iterator nit; // Node iterator.
	map<FeatureNode*,Group*> childToGroup;		
	// Log-ratio for P(feature at node = 0) vs. P(feature at node = 1).
	logRatioPrior = log(beta_ + nMinus) - log(alpha_ + nPlus);
	// Log-ratio for (feature at child nodes = 1| feature at node = 0) vs. (feature at child nodes = 1 | feature at node = 1).
	for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++){
		nPlus = (*nit)->newFeatures_.size() - 1; // Excluding 'feature', which every child has switched on.
		nMinus = (*nit)->nullFeatures_.size();
		logRatioPrior += log(alpha_ + nPlus) - log(alpha_ + beta_ + nPlus + nMinus);		
	}
	// Log-ratio for expression data.		
	if(nodeIsLeaf_[node])
		// Leaf with feature = 0: clients' values scored under a standard normal.
		for(cit = node->includedClients_.begin(); cit != node->includedClients_.end(); cit++)
			logRatioExpr += stdNormLogProb(expr[getIndex(feature, *cit, nFeatures_)]);		
	else
		// Non-leaf with feature = 0: the children's existing groups.
		for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
			logRatioExpr += (*nit)->newFeatureToGroup_[feature]->logJointProb();
	// Hypothetical group corresponding to switching the feature to 1 at the input node. 	
	group = new Group(&(node->includedClients_), feature, nFeatures_, expr, k0_, aV_, bV_);
	logRatioExpr -= group->logJointProb();
	// Sample activation status.
	if(SamplingUtils::sampleFromLogRatio(logRatioPrior + logRatioExpr) == 1){ // Maintain feature = 0.
		valueHasChanged = false;
		delete group;
	}else{ // Switch feature to 1.
		valueHasChanged = true;
		if(nodeIsLeaf_[node]){
			node->nullFeatures_.erase(feature);			
			node->newFeatures_.insert(feature);
			node->newFeatureToGroup_[feature] = group; // Ownership moves to the node.
		}else{ // Node is not a leaf.
			for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++){
				// Delete induced groups from the children.
				(*nit)->newFeatures_.erase(feature);
				delete (*nit)->newFeatureToGroup_[feature];
				(*nit)->newFeatureToGroup_.erase(feature);
			}
			node->nullFeatures_.erase(feature);
			node->newFeatures_.insert(feature);
			node->newFeatureToGroup_.insert(pair<int,Group*>(feature, group)); // Ownership moves to the node.
		}
	}
	return valueHasChanged;
}

// Recursive auxiliary function for Tree::randomFeatureNode.
// Appends to 'nodes' every node of the subtree rooted at 'node' whose value for 'feature' could be
// either 0 or 1:
//   - a node where the feature switches on (everything below it is forced to 1, so the search stops there);
//   - a node where the feature is off while every child switches it on (vacuously true for a leaf).
void Tree::randomFeatureNodeAux(FeatureNode* node, int feature, vector<FeatureNode*>* nodes){	
	if(node->newFeatures_.count(feature) > 0){
		nodes->push_back(node);
		return;
	}
	set<FeatureNode*>* kids = nodeToChildren_[node];
	set<FeatureNode*>::iterator kid = kids->begin();
	while(kid != kids->end() && (*kid)->newFeatures_.count(feature) > 0)
		++kid; // Skip children that switch the feature on.
	if(kid == kids->end())
		nodes->push_back(node);
	for(kid = kids->begin(); kid != kids->end(); ++kid)
		randomFeatureNodeAux(*kid, feature, nodes);
}

// Choose, uniformly at random, a node for the current feature such that both feature values (0/1) have
// non-zero probability. The search starts at the root's children, so the root is never chosen.
// Every root-to-leaf path contributes a candidate: either the node where the feature switches on, or (if it
// never does) its leaf. Hence no candidate exists only when the root has no children; NULL is returned then.
FeatureNode* Tree::randomFeatureNode(int feature){
	vector<FeatureNode*> nodes;
	set<FeatureNode*>::iterator nit; // Node iterator.
	for(nit = nodeToChildren_[root_]->begin(); nit != nodeToChildren_[root_]->end(); nit++)
		randomFeatureNodeAux(*nit, feature, &nodes);
	if(nodes.empty()) // Indexing an empty vector would be undefined behavior.
		return NULL;
	return nodes[SamplingUtils::uniformSample(nodes.size())];
}

// Sample the activation value of a given feature at a random node, chosen so that both values 0 and 1 have
// non-zero probability (see Tree::randomFeatureNode).
// Return true if the sampling process changes the value of the feature; return false otherwise, including
// when no candidate node exists (randomFeatureNode returned NULL).
bool Tree::sampleFeature(int feature, double *expr)
{		
	FeatureNode* node = randomFeatureNode(feature);
	if(node == NULL) // E.g. the root has no children: nothing can change.
		return false;
	if(node->newFeatures_.find(feature) != node->newFeatures_.end()) // Feature is currently set to 1.
		return sampleFeatureAuxPositive(node, feature, expr);
	return sampleFeatureAuxNegative(node, feature, expr); // Feature is currently set to 0.
}

///////////////////////////////
//
// Log-probability computations.
//
///////////////////////////////

// Return the prior log-probability of gamma_ under Gamma(shape aGamma_, rate bGamma_).
// GSL's gamma pdf is parameterized by a scale (exp(-x/b)) whereas bGamma_ is a rate (exp(-x * b)),
// hence the inversion.
double Tree::logProbGamma()
{
	const double scale = 1.0 / bGamma_;
	const double density = gsl_ran_gamma_pdf(gamma_, aGamma_, scale);
	return log(density);
}

// Return the log-probability of the client tree structure below 'node' under the nCRP. Depth-first recursive.
// Each non-leaf node with n clients, split among K children with n_1..n_K clients, contributes the CRP
// (Ewens) partition probability
//     gamma^K * Gamma(gamma) / Gamma(n + gamma) * prod_k Gamma(n_k).
// Leaves contribute nothing.
double Tree::logProbNcrp(FeatureNode* node)
{
	double out = 0.0;	
	set<FeatureNode*>::iterator nit; // Node iterator.
	// Recursion base case.
	if(nodeIsLeaf_[node])
		return 0.0;
	// Compute contribution of current node.	
	out += nodeToChildren_[node]->size() * log(gamma_);
	out += gsl_sf_lngamma(gamma_); // Normalizing term Gamma(gamma): depends on gamma_, so it cannot be dropped.
	out -= gsl_sf_lngamma(node->nClients() + gamma_);	
	for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
		out += gsl_sf_lngamma((*nit)->nClients());		
	// Recurse to child nodes.
	for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
		out += logProbNcrp(*nit);	
	return out;
}

// Return the log-probability of the feature activation pattern in a single node, with the node's edge
// length (Beta(alpha_, beta_)) integrated out:
//     log [ B(nPlus + alpha, nMinus + beta) / B(alpha, beta) ]
// where nPlus counts the features that switch on at this node and nMinus those kept at zero.
// This is an auxiliary function for Tree::logProbFeatures().
double Tree::logProbFeaturesSingleNode(FeatureNode* node)
{
	const int nOn = node->newFeatures_.size();
	const int nOff = node->nullFeatures_.size();
	const double logBetaCounts = gsl_sf_lngamma(nOn + alpha_) + gsl_sf_lngamma(nOff + beta_)
		- gsl_sf_lngamma(nOn + nOff + alpha_ + beta_);
	const double logBetaPrior = gsl_sf_lngamma(alpha_) + gsl_sf_lngamma(beta_);
	return (logBetaCounts - logBetaPrior) + gsl_sf_lngamma(alpha_ + beta_);
}

// P(Z | C, gamma_). Implemented as breadth-first search.
double Tree::logProbFeatures()
{
	double out = 0.0;
	set<FeatureNode*>::iterator nit; // Node iterator.
	queue<FeatureNode*> q;
	FeatureNode* node;
	// Initialize queue with the root's children.
	for(nit = nodeToChildren_[root_]->begin(); nit != nodeToChildren_[root_]->end(); nit++)
		q.push(*nit);
	// Browse tree in breath-first manner.
	while(!q.empty()){
		node = q.front();
		q.pop();
		out += logProbFeaturesSingleNode(node);
		for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
			q.push(*nit);
	}
	return out;
}

// Return the log-probability of the expression data Y_ given all other variables. 
double Tree::logProbExpr(double *expr)
{
	int i;
	double out = 0.0;
	set<int>::iterator fit; // Feature iterator.	
	set<FeatureNode*>::iterator nit; // Node iterator.
	queue<FeatureNode*> q;
	FeatureNode* node;
	// Explore features that switch to 1 and corresponding induced groups.
	q.push(root_);
	while(!q.empty()){
		node = q.front();
		q.pop();
		for(fit = node->newFeatures_.begin(); fit != node->newFeatures_.end(); fit++)
			out += node->newFeatureToGroup_[*fit]->logJointProb();
		for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
			q.push(*nit);
	}
	// Explore null features.
	for(i = 0; i < nClients_; i++){
		node = clientToLeafNode_[i];
		for(fit = node->nullFeatures_.begin(); fit != node->nullFeatures_.end(); fit++)
			out += stdNormLogProb(expr[getIndex(*fit, i, nFeatures_)]);	
	}
	return out;
}

///////////////////////////////
//
// Etc.
//
///////////////////////////////

// Return the number of nodes in the tree. Every node, the root included, has exactly one entry in
// nodeToParent_ (the root's parent is NULL).
int Tree::nNodes()
{
	const int count = nodeToParent_.size();
	return count;
}

// Return the number of leaf nodes in the tree.
int Tree::nLeafNodes()
{
	set<FeatureNode*> leafNodes; // A set contains only unique items. 
	map<int, FeatureNode*>::iterator mit; // Map iterator. 
	for(mit = clientToLeafNode_.begin(); mit != clientToLeafNode_.end(); mit++)
		leafNodes.insert((*mit).second);
	return leafNodes.size();	
}

// Fill (*nodeToNChildren) with the number of direct children of every node reachable from root_.
// Existing entries for other keys are left untouched.
void Tree::computeNodeToNChildren(map<FeatureNode*, int> *nodeToNChildren)
{
	stack<FeatureNode*> pending;
	pending.push(root_);
	while(!pending.empty()){
		FeatureNode* current = pending.top();
		pending.pop();
		set<FeatureNode*>* kids = nodeToChildren_[current];
		(*nodeToNChildren)[current] = kids->size();
		for(set<FeatureNode*>::iterator kid = kids->begin(); kid != kids->end(); ++kid)
			pending.push(*kid);
	}
}

// Recursive. Count the clients in the subtree rooted at 'node', store that count for 'node' (and, through the
// recursion, for every node below it) in '*nodeToNClients', and return it.
// For a leaf, the count is the number of entries of 'clientToLeafNode_' that point at it; for an internal node,
// it is the sum of its children's counts.
int Tree::computeNodeToNClients(map<FeatureNode*, int> *nodeToNClients, FeatureNode* node)
{
	int total = 0;
	if(!nodeIsLeaf_[node]){
		set<FeatureNode*>* children = nodeToChildren_[node];
		for(set<FeatureNode*>::iterator child = children->begin(); child != children->end(); ++child)
			total += computeNodeToNClients(nodeToNClients, *child);
	}
	else{
		for(map<int, FeatureNode*>::iterator entry = clientToLeafNode_.begin(); entry != clientToLeafNode_.end(); ++entry){
			if(entry->second == node)
				++total;
		}
	}
	(*nodeToNClients)[node] = total;
	return total;
}

// Recursive.
void Tree::clientPathAux(int client, FeatureNode *node, stack<FeatureNode*> *s)
{	
	s->push(node);
	if(nodeToParent_[node] != NULL) // The recursion stops when the root is found.
		clientPathAux(client, nodeToParent_[node], s);
}

// Return a stack of nodes, representing the path taken by the client from the root to a leaf node. 
// The node in the top position (i.e. the node obtained when calling stack.top() for the first time) is the root node.  
stack<FeatureNode*> Tree::clientPath(int client)
{
	stack<FeatureNode*> s;
	clientPathAux(client, clientToLeafNode_[client], &s);
	return s;
}

// Check whether a client's path through the tree has changed.
// Returns true if the old and the new paths differ, false if they are the same.
//
// 'oldPath' is the client's path from before it was removed from the tree, laid out as clientPath() builds it:
// root on top, leaf at the bottom. It is taken by value, so the pops below only consume a local copy.
// It must not be empty, because top() is called unconditionally.
// 'newNode' is the node of the new path described in the comment before the return statement.
bool Tree::pathHasChanged(stack<FeatureNode*> oldPath, FeatureNode *newNode)
{
	FeatureNode *parent = oldPath.top(); // Always the root.	
	FeatureNode *node = NULL; // Stays NULL only if 'oldPath' contains nothing but the root.
	oldPath.pop();
	// Get the deepest node in the "old" path that still exists in the tree.
	// This tells you that, in the old path, all the nodes below 'node' included only the current client, so they were 
	// deleted when the client was removed from the tree.
	// "Exists" means: still has an entry in 'nodeToLevel_' (lookup by pointer value).
	// If every node of the old path still exists, the loop runs to completion and 'node' ends as the old leaf.
	// NOTE(review): if a deleted node's address were reused by a node allocated for the new path, it would look as if
	// it still existed; confirm this cannot happen between the client's removal and this call.
	while(!oldPath.empty()){
		node = oldPath.top();
		oldPath.pop();
		if(nodeToLevel_.find(node) == nodeToLevel_.end()){ // 'node' does not exist in the tree.
			// The parent of 'node' is the deepest non-singleton node in the old path.
			node = parent;
			break;
		}
		parent = node;
	}
	// 'newNode' is the node in the new path such that all nodes below it contain only the current client. 
	// This node existed before in the tree. If it is the same as 'node', then the only change that occurred between 
	// the old and the new path was the substitution of "singleton" (i.e. with a single client) nodes by other singleton nodes, 
	// that is, the new and old paths are the same.
	return newNode != node;
}

void Tree::estimateEdgeLengths()
{
	int nPlus, nMinus;
	double len;
	FeatureNode *parent;
	set<FeatureNode*>::iterator nit; // Node iterator.
	queue<FeatureNode*> q;
	q.push(root_);
	while(!q.empty()){
		parent = q.front();
		q.pop();
		for(nit = nodeToChildren_[parent]->begin(); nit != nodeToChildren_[parent]->end(); nit++){
			q.push(*nit); // Add child node to queue.
			nPlus = (*nit)->nNewFeatures();
			nMinus = (*nit)->nNullFeatures();
			len = (nPlus + alpha_) / (nPlus + nMinus + alpha_ + beta_);
			edgeLength_[pair<FeatureNode*, FeatureNode*> (parent, *nit)] = len;
		}
	}
}

void Tree::writeToFile(string stem, map<int,string> *indexToClientLabel, map<int,string> *indexToFeatureLabel)
{	
	int i;
	string edgesFn = stem + "tree_edges.txt"; // Map parent nodes to child nodes and respective edge lengths.	
	string featuresFn = stem + "tree_features.txt"; // Map features to nodes.
	string clientsFn = stem + "tree_clients.txt"; // Map clients to leaf nodes.	
	ofstream edgesFile(edgesFn.c_str());
	ofstream featuresFile(featuresFn.c_str());
	ofstream clientsFile(clientsFn.c_str());	
	queue<FeatureNode*> q;
	FeatureNode* node;
	set<FeatureNode*>::iterator nit; // Node iterator.
	set<int>::iterator fit; // Feature iterator.	
	estimateEdgeLengths();
	// Map each node to a unique integer identifier, in breadth-first order.
	map<FeatureNode*, int> nodeToIndex;
	int id = 0; // Current integer identifier.
	q.push(root_);
	while(!q.empty()){
		node = q.front();
		q.pop();
		nodeToIndex[node] = id;
		id++;
		for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++)
			q.push(*nit);
	}	
	// Go through nodes again, in breadth-first order, in order to write edge length and feature activation files.
	edgesFile << "Parent" << "\t" << "Child" << "\t" << "Edge length" << endl;	
	featuresFile << "Feature\tNode" << endl;
	q.push(root_);
	while(!q.empty()){
		node = q.front();
		q.pop();
		for(fit = node->newFeatures_.begin(); fit != node->newFeatures_.end(); fit++)
			featuresFile << (*indexToFeatureLabel)[*fit] << "\t" << nodeToIndex[node] << endl;		
		for(nit = nodeToChildren_[node]->begin(); nit != nodeToChildren_[node]->end(); nit++){
			q.push(*nit);
			edgesFile << nodeToIndex[node] << "\t";
			edgesFile << nodeToIndex[*nit] << "\t";
			edgesFile << edgeLength_[pair<FeatureNode*, FeatureNode*> (node, *nit)] << endl;						
		}	
	}	
	// Clients file.	
	clientsFile << "Client" << "\t" << "Leaf node" << endl;
	for(i = 0; i < nClients_; i++)
		clientsFile << (*indexToClientLabel)[i] << "\t" << nodeToIndex[clientToLeafNode_[i]] << endl;
}
