#include "BinClassifier.hpp"
#include "Classifier.hpp"
#include <cassert>
#include <cmath>

using std::vector;
using std::cout;
using std::cerr;
using std::endl;
using std::string;
using std::map;
using cv::Mat;

namespace slmotion {
  void BinClassifier::train(const vector<vector<float> > &negSamples,
			    const vector<vector<float> > &posSamples,
			    const vector<bool> *dimMask, 
			    const vector<float> *negWeights,
			    const vector<float> *posWeights) {
    assert(negWeights==NULL || negWeights->size()==negSamples.size());
    assert(posWeights==NULL || posWeights->size()==posSamples.size());
    
    
    size_t dim=posSamples[0].size();
    vector<size_t> activeDim;
    
    for(size_t d=0;d<dim;d++)
      if(dimMask==NULL || (*dimMask)[d]) activeDim.push_back(d);
    
    dim=activeDim.size();
    
    
    cerr << "active dimensionality: " << dim << endl;
    
    cerr << "active components: " << endl;
    
    for(size_t d=0;d<dim;d++){
      cerr << activeDim[d]<< " ";
    }
    cerr << endl;
    
    
    
    
    // pointers to data vectors
    vector<const vector<vector<float> > *> trainSamples(2);
    
    trainSamples[0] = &negSamples;
    trainSamples[1] = &posSamples;
    
    vector<size_t> classSizes(2);
    
    for(size_t cls=0;cls<2;cls++)
      classSizes[cls]=trainSamples[cls]->size();
    
    if(znorm)
      estimateZNorm(negSamples,posSamples,dimMask);
    
    // set the quantiser bins, if used

    if (useQuantiser) {
      if (quantiserMode==TSSOMQUANT) {
	if (!tssomtrained)
	  traintssom(negSamples,posSamples,dimMask);	
      } 
      else if (quantiserMode==VECTORQUANT) {
	// select codebook with k-means

	size_t K=200;
	if (quantiserBinCount.size() > 0)
	  K=quantiserBinCount[0];
	
	vector<vector<float> > samples;
	
	vector<float> samplevec(dim);
	
	vector<float> samplefrac(2);
	samplefrac[0]=0.002;
	samplefrac[1]=0.01;

	for (size_t cls=0; cls < 2; cls++)
	  for (size_t i=0; i < classSizes[cls]; i++) {
	    float r=rand();
	    r /= RAND_MAX;

	    if (r < samplefrac[cls] ) {
	      for (size_t d=0; d < dim; d++) {
		samplevec[d]=(*trainSamples[cls])[i][activeDim[d]];
		if (znorm) {
		  samplevec[d] -= mean[d]; 
		  samplevec[d] /= sqrt(variance[d]);
		}
	      }
	      samples.push_back(samplevec);
	    }
	  }

	vector<size_t> labels;
	
	size_t tolerance=10;

	cerr << "Sampled " << samples.size() << " examples for codebook determination" << endl;

	kmeans(samples,NULL,labels,K,tolerance,&codebook);
      }
      else {
      // expand the bin count specification if it's not complete

      if (quantiserBinCount.size() == 0) {
	quantiserBinCount=vector<size_t>(dim,3); // default value
      }
      else if(quantiserBinCount.size() == 1) {
	size_t nbins=quantiserBinCount[0];
	quantiserBinCount=vector<size_t>(dim,nbins);
      }


      quantiserBinUpperLimits= vector<vector<float> >(dim);
      for (size_t d=0; d < dim; d++){

	quantiserBinUpperLimits[d]=vector<float>(quantiserBinCount[d]-1);

	switch(quantiserMode){
	case CONSTANT:

	  // determine the spread of values in dim. d and
	  // divide the interval into equal parts
	  {
	    float minval=negSamples[0][activeDim[d]];
	    float maxval=minval;
	    
	    for(size_t cls=0;cls<2;cls++)
	      for(size_t i=0;i<classSizes[cls];i++){
		float val=(*trainSamples[cls])[i][activeDim[d]];
		if(znorm){
		  val -= mean[d];
		  val /= sqrt(variance[d]);
		}
		if(val<minval) minval=val;
		if(val>maxval) maxval=val;
	      }
	    
	    float binWidth=maxval-minval;
	    binWidth /= quantiserBinCount[d];
	    
	    for(size_t bin=0;bin<quantiserBinCount[d]-1;bin++)
	      quantiserBinUpperLimits[d][bin]=minval+(bin+1)*binWidth; 
	    
	  }
	  //	  cerr << "Dim " << (int)d <<": limits ";
	  //	  for(size_t bin=0;bin<quantiserBinCount[d]-1;bin++)
	    //	    cerr << " " << quantiserBinUpperLimits[d][bin]; 
	  //	  cerr << endl;
	  break;
	case ADAPTIVE:
	  throw string("ADAPTIVE quantiser mode not yet implemented");
	  break;
	default:
	  throw string("unknown quantiser mode");
	}


      }
      }


    } // if(useQuantiser)

    // now count the number of occurrences of each training vector type,
    // using quantiser if instructed
    
    vector<map<vector<float>,float> *> countMaps(2);
    vector<const vector<float> *> weights(2);

    
    countMaps[0]=&negCount;
    countMaps[1]=&posCount;

    weights[0]=negWeights;
    weights[1]=posWeights;

    vector<float> newSample(dim);
    vector<float> qSample;

    size_t tgtsize=150000;

    for(size_t cls=0;cls<2;cls++){
      countMaps[cls]->clear();

      size_t incr=classSizes[cls]/tgtsize;
      if(incr<1) incr=1;

      float samplingweight=1.0/incr;

      for(size_t i=0;i<classSizes[cls];i+=incr){

	const vector<float> *sptr=&((*trainSamples[cls])[i]);
	
	if(dimMask){
	  for(size_t d=0;d<dim;d++)
	    newSample[d]=(*sptr)[activeDim[d]];
	  sptr=&newSample;
	}

	if(useQuantiser){
	  quantise(*sptr,qSample);
	  sptr=&qSample;
	}
	

	(*countMaps[cls])[*sptr]+= samplingweight*((weights[cls]!=NULL)?(*weights[cls])[i]:1);
      }
    }

    cerr << "finished counting" << endl;
    cerr << posCount.size() << " positive and " << negCount.size() << " negative sample types." << endl;

    aprioriprob=classSizes[1];
    aprioriprob /= classSizes[0]+classSizes[1];

    // in case of tssoms, visualise the count maps

    if (useQuantiser && quantiserMode == TSSOMQUANT) {

      TsSomCodebook* bottom = tssomlevels[tssomlevels.size()-1];
      
      int mag=10;

      float discount=3;

      Mat vismat(bottom->h*mag,bottom->w*mag,cv::DataType<float>::type);
      Mat discmat(bottom->h*mag,bottom->w*mag,cv::DataType<float>::type);
      Mat classprob(bottom->h,bottom->w,cv::DataType<float>::type);
      Mat discprob(bottom->h,bottom->w,cv::DataType<float>::type);

	for(int y=0;y<bottom->h;y++)
	  for(int x=0;x<bottom->w;x++){

	    vector<float> idx(1,x+y*bottom->w);

	    float val=posCount[idx];
	    if(negCount[idx])
	      val /= negCount[idx]+posCount[idx];

	    float discval=posCount[idx];
	    discval /= negCount[idx]+posCount[idx]+discount;

	    classprob.at<float>(y,x)=val; 
	    discprob.at<float>(y,x)=discval; 

// 	    cerr << "Map unit ["<<x<<","<<y<<"]: posCount="<<posCount[idx]
// 		 << " negCount=" << negCount[idx] << " -> discval="
// 		 << discval << endl;

	    for(int dy=0;dy<mag;dy++)
	      for(int dx=0;dx<mag;dx++){
		int i=mag*y+dy;
		int j=mag*x+dx;

		vismat.at<float>(i,j)=val;
		discmat.at<float>(i,j)=discval;
	      }
	  }

	//		cv::imshow("class probability on bottom level", vismat);
	//	cv::imshow("discounted class probability on bottom level", discmat);
	//	cv::waitKey(0);


	// spread class probability heuristically to neighbours on the map


	bool leftmask[]={0,0,0,1,1,1,0,0};
	bool rightmask[]={1,1,0,0,0,0,0,1};
	bool upmask[]={0,0,0,0,0,1,1,1};
	bool downmask[]={0,1,1,1,0,0,0,0};
	
	bool nwmask[]={0,0,0,0,1,1,1,0};
	bool nemask[]={1,0,0,0,0,0,1,1};
	bool semask[]={1,1,1,0,0,0,0,0};
	bool swmask[]={0,0,1,1,1,0,0,0};
	
	int spreadcount=0;

	Mat classprob2=discprob.clone();

	for(int y=1;y<bottom->h-1;y++)
	  for(int x=1;x<bottom->w-1;x++){
	    float centerprob=discprob.at<float>(y,x);

	    float spreadstrength=0;

	    spreadstrength=fmax(spreadstrength,
				determineStrengthUnderMasks(discprob,y,x,leftmask,rightmask));
	    spreadstrength=fmax(spreadstrength,
				determineStrengthUnderMasks(discprob, y,x,upmask,downmask));
	    spreadstrength=fmax(spreadstrength,
				determineStrengthUnderMasks(discprob, y,x,nwmask,semask));
	    spreadstrength=fmax(spreadstrength,
				determineStrengthUnderMasks(discprob, y,x,nemask,swmask));

	    if(spreadstrength>centerprob){

	      //	      cerr << "spreading w/ strength" << spreadstrength << endl;

	      spreadcount++;
	      classprob2.at<float>(y,x)=0.3*discprob.at<float>(y,x)+0.7*spreadstrength;
	    }
	  }

	cerr << "spread probability to " << spreadcount << " map units." << endl;

	tssomprob=vector<float>(bottom->h*bottom->w);

	for(int y=0;y<bottom->h;y++)
	  for(int x=0;x<bottom->w;x++){

	    tssomprob[x+y*bottom->w]=classprob2.at<float>(y,x);
	    //	    cerr << "setting tssomprob["<<x+y*bottom->w<<"] = " << tssomprob[x+y*bottom->w] << endl;

	    for(int dy=0;dy<mag;dy++)
	      for(int dx=0;dx<mag;dx++){
		int i=mag*y+dy;
		int j=mag*x+dx;
		vismat.at<float>(i,j)=classprob2.at<float>(y,x);
	      }
	  }

       cv::imshow("heuristically spread class probability on bottom level", vismat);
	cv::waitKey(0);




    }






    predictioncache.clear();

 }

void  BinClassifier::traintssom(const vector<vector<float> > &negSamples,
				const vector<vector<float> > &posSamples,
				const vector<bool> *dimMask){

  size_t dim=posSamples[0].size();
  vector<size_t> activeDim;
  
  for(size_t d=0;d<dim;d++)
    if(dimMask==NULL || (*dimMask)[d]) activeDim.push_back(d);
  
  dim=activeDim.size();

// pointers to data vectors
  vector<const vector<vector<float> > *> trainSamples(2);
    
  trainSamples[0] = &negSamples;
  trainSamples[1] = &posSamples;

  DataSet dat;
  TsSomCodebook cod;
  dat.dim=dim;
  vector<float> toadd(dim);

  // sample targets

  vector<size_t> sampletgt(2,500000);

  for (size_t cls=0; cls < 1; cls++) {
    for (size_t i=0; i < trainSamples[cls]->size(); i++) {
      // randomint(n) is in the range of [0,n-1], so the cast should be valid (to 
      // suppress the warning)
      if (static_cast<size_t>(randomint(trainSamples[cls]->size())) < sampletgt[cls]) {
	for (size_t d=0; d < dim; d++){
	  toadd[d]=(*trainSamples[cls])[i][activeDim[d]];
	  if(znorm){
	    toadd[d] -= mean[d];
	    toadd[d] /= sqrt(variance[d]);
	  }
	}
	dat.vec.push_back(toadd);
      }
    }
  }

  int nlevels=3;
	  
  int w=4,h=4,level,dmin=0;
	
  vector<float> alpha(3,0.1);
  vector<int> nit(3);
  nit[0]=10;
  nit[1]=10;
  nit[2]=7;
  vector<int> d(4);
  d[0]=2;
  d[1]=2;
  d[2]=3;

  for(level=0;level<nlevels;level++){
    tssomlevels.push_back(new TsSomCodebook);
  }

  bool torus=false;
  
  for(level=0;level<nlevels;level++){
    
    if(level==nlevels-1) dmin=1;
    
    
    cout << "Training ts-som level "<<level<<endl;
    
    tssomlevels[level]->torus=torus;
    
    if(level==0){
      tssomlevels[level]->randominit(w,h,dat);
    }
    else{
      tssomlevels[level]->upper=tssomlevels[level-1];
      tssomlevels[level]->initfromabove();
    }
	  
    cout << "initialisation finished" << endl;
    cout << "w="<<tssomlevels[level]->w<<" h="<<tssomlevels[level]->h<<endl;

    // 	  cerr << "initial map values:" << endl;

    // 	  for(int i=0;i<tssomlevels[level]->w*tssomlevels[level]->h;i++){
    // 	    cerr << "unit " << i << endl;
    // 	    for(int d=0;d<tssomlevels[level]->dim;d++)
    // 	      cerr << tssomlevels[level]->units[i][d] << " ";
    // 	    cerr << endl;
    // 	  }
	  
	  
    // construct the training region
    
    int n_upper;
    n_upper= (level)?tssomlevels[level-1]->w*tssomlevels[level-1]->h:1;
    
    vector<TrainingRegion> r(n_upper);
    
    int nd=dat.vec.size();
    for(int i=0;i<nd;i++){
      //cout << "data point "<<i<<endl;
      int bmu=(level)?tssomlevels[level-1]->tssom_findbmu(dat.vec[i]):0;
      // cout << "found bmu "<<bmu<<endl;
      if(r[bmu].data.empty()){
	if(level)
	  r[bmu].limits=tssomlevels[level-1]->expanddown(bmu);
	else{
	  r[bmu].limits.x1=r[bmu].limits.y1=0;
	  r[bmu].limits.x2=w-1;
	  r[bmu].limits.y2=h-1;
	}
	
	r[bmu].current_data=0;
      }
      r[bmu].data.push_back(i);
    }
	  
    cout << n_upper << " training regions constructed" << endl;
	  
    // iterate training
    //    for(int region=0;region<n_upper;region++)
    //  r[region].dump();
    
    
    int n_rounds = nit[level]*nd/n_upper;
    
    for(int rnd=0; rnd<n_rounds; rnd++){
      if(rnd%1000==0)
	cout << "training iteration "<<rnd<<"/"<<n_rounds<<endl;
      float fraction=1-((float)rnd)/n_rounds;
      float alpha_t = alpha[level]*fraction;
      int d_t = (int)(d[level]*fraction);
      if(d_t<dmin) d_t=dmin;
      for(int region=0;region<n_upper;region++){
	//	  cout << "training region "<<region<<":"<<endl;
	if(!r[region].data.empty())
	  tssomlevels[level]->regiontrainiteration(r[region],dat,alpha_t,d_t);
	else{
	  // cout << "training region "<<region<<" has no data"<<endl;
	}
      }
    }
    w *= 4;
    h *= 4;
  } // for level

  // show the levels of the trained map
  
  if(true){
    for(int level=0;level<nlevels;level++){
      
      TsSomCodebook* bottom=tssomlevels[level];

      if(bottom->dim != 3) continue;
      
      int mag=10;
      
      Mat levelmat(bottom->h*mag,bottom->w*mag,cv::DataType<cv::Vec3b>::type);
      
      bool hsv=false;
      bool hsvtrig=false;
      float pii=3.141592654;
      vector<float> valf(3);
      
      for(int y=0;y<bottom->h;y++)
	for(int x=0;x<bottom->w;x++){
	  cv::Vec3b val;
	  
	  
	  for(size_t d=0;d<3;d++){
	    valf[d] =bottom->units[x+y*bottom->w][d];
	    if(znorm){
	      valf[d] *= sqrt(variance[d]);
	      valf[d] += mean[d];
	    }
	    
	  }
	  
	  if(hsv){
	    
	    
	    if(hsvtrig){
	      valf=hsvtrig2hsv(valf);
	    }
	    
	    valf[0]*=180/(2*pii);
	    valf[1]*=255;
	    
	    // 	      cerr << "scaling hsv values from map: ";
	    // 	      for(size_t dim=0;dim<3;dim++)
	    // 		cerr << bottom->units[x+y*bottom->w][dim] << " -> " << (float)valf[dim] << "  ";
	    // 	      cerr << endl;
	    
	  } else{
	    // for bgr do nothing
	  }
	  
	  for(size_t d=0;d<3;d++)
	    val[d]=valf[d];
	  
	  for(int dy=0;dy<mag;dy++)
	    for(int dx=0;dx<mag;dx++){
	      int i=mag*y+dy;
	      int j=mag*x+dx;
	      
	      levelmat.at<cv::Vec3b>(i,j)=val;
	    }
	}
      
      if(hsv)
	cvtColor(levelmat,levelmat,CV_HSV2BGR);
      
      	cv::imshow("trained bottom level map", levelmat);
      	cv::waitKey(0);
    }
  }

  tssomtrained=true;

}
    

  float BinClassifier::predict(const std::vector<float>  &sample,
				      const vector<bool> *dimMask){

    const vector<float> *sptr=&sample;
    size_t dim=sample.size();
    
    vector<float> newSample;
    vector<float> qSample;
    
    if(dimMask){
      for(size_t d=0;d<dim;d++)
	if((*dimMask)[d])
	  newSample.push_back(sample[d]);
      sptr=&newSample;
    }

    if(!use_cache || predictioncache.count(sample)==0){

    if(useQuantiser){
      quantise(*sptr,qSample);
      sptr=&qSample;
    }





    float ret;

    if(useQuantiser&&quantiserMode==TSSOMQUANT){
      ret=tssomprob[(*sptr)[0]];
      //      cerr << "binclassifier predicting p="<<ret<<endl;
    }
    else{

      size_t pcount=posCount[*sptr];
      size_t ncount=negCount[*sptr];
      
      if(pcount>0){
	ret=pcount;
	ret /= pcount+ncount;
	
      }
      else 
	ret = (ncount>0)? 0.0 : aprioriprob;
    }
      //        cerr << "counts: pos " << pcount << " neg " << ncount << " -> p " << ret << endl;
    if(use_cache) predictioncache[sample]= ret;
    else return ret;
    }

    return predictioncache[sample];

  }


  void BinClassifier::quantise(const std::vector<float>  &src, 
			  std::vector<float>  &dst)
  {
    vector<float> norm;
    const vector<float> *vec=&src;
    if(znorm){
      norm=src;
      size_t dim=src.size();
      for(size_t d=0;d<dim;d++){
	norm[d] -= mean[d];
	norm[d] /= sqrt(variance[d]);
      }
      vec=&norm;
    }

    if(quantiserMode==TSSOMQUANT){

      dst=vector<float>(1,tssomlevels[tssomlevels.size()-1]->tssom_findbmu(*vec));
      
    }
    else if(quantiserMode==VECTORQUANT){

      size_t minidx=0;
      float mindist=sqrdist(*vec,codebook[0],-1);

      for(size_t k=1;k<codebook.size();k++){
	float d=sqrdist(*vec,codebook[k],mindist);
	if(d<mindist){
	  mindist=d;
	  minidx=k;
	}
      }

      dst=vector<float>(1,minidx);

    } else{
    
    size_t dim=vec->size();
    dst=*vec;

    for(size_t d=0;d<dim;d++){
      // quantise each dimension separately
      size_t bin;
      for(bin=0;
	  bin<quantiserBinUpperLimits[d].size() && 
	    (*vec)[d]>quantiserBinUpperLimits[d][bin];
	  bin++);

      // now determined bin \in [0,quantiserBinUpperLimits[d].size()]
      dst[d]=bin;
    }

//     cerr << "quantiser: ";
//     for(size_t d=0;d<dim;d++)
//       cerr << " " << src[d];

//     cerr << " ->";
//     for(size_t d=0;d<dim;d++)
//       cerr << " " << dst[d];
//     cerr << "endl";  
    }
  }
}
