// -*- C++ -*-  $Id: KeypointFilter.C,v 1.13 2014/01/16 15:18:15 jorma Exp $
//
// Copyright 2009-2014 PicSOM Development Group <picsom@cis.hut.fi>
// Aalto University School of Science and Technology
// Department of Information and Computer Science
// P.O.BOX 15400, FI-00076 Aalto, FINLAND
//

#include "KeypointFilter.h"

#include <algorithm>
#include <functional>

// RCS/CVS identification string for this translation unit. The "@(#)" and
// "$Id...$" markers are kept verbatim so version tools can find them.
std::string KeypointFilter_C_vcid(
  "@(#)$Id: KeypointFilter.C,v 1.13 2014/01/16 15:18:15 jorma Exp $");

// ---------------------------------------------------------------------

bool ClusterFilter::LoadFilter(int size) {
  
  if (debug)
    cout << "ClusterFilter::LoadFilter(): Loading keypoint filter "
	 << "from file " << name << endl;

  ifstream infile(name.c_str());
  if (!infile) {
    cerr << "ERROR: ClusterFilter::LoadFilter(): File " << name 
	 << " not found" << endl;
    exit(-1);
  }

  if (size==0) {
    cerr << "ERROR: ClusterFilter::LoadFilter(): " 
	 << "Descriptor size must be specified" << endl;
    exit(-1);
  }

  size_t nclusters;
  infile >> nclusters;
  if (threshold >= nclusters) {
    cerr << "ERROR: ClusterFilter::LoadFilter(): Not enough clusters(" << nclusters 
	 << ") in " << name << " for threshold=" 
	 << threshold << endl;
    exit(-1);
  }

  codebook = new cv::Mat(nclusters, size, CV_32F);

  for (size_t i=0; i<nclusters; i++)
    for (int j=0; j<codebook->cols; j++)
      infile >> codebook->at<float>(i, j);

  if (debug)
    cout << "ClusterFilter::LoadFilter(): Successfully loaded " << nclusters
	 << " clusters" << endl;

  return true;
}

// ---------------------------------------------------------------------

void ClusterFilter::FilterDescriptors(const vector<cv::KeyPoint>&, 
				      const cv::Mat &descs) {

  double tt = (double)cvGetTickCount();
  keypoint_ok.clear();

  if (debug)
    cout << "ClusterFilter::FilterDescriptors(): Filtering " << descs.rows
	 << " descriptors using " << name << " with threshold=" 
	 << threshold << endl;

  static const cv::Mat I = cv::Mat::eye(descs.cols, descs.cols, CV_32F);  
  cv::Mat descv(descs.cols, 1, CV_32F);
  //float *descv_ptr = descv.ptr<float>(0);
  
  //CvSeqReader desc_reader;
  //cvStartReadSeq(descs, &desc_reader);

  size_t accepted = 0, rejected = 0;
  for (int i = 0; i < descs.rows; i++) {
    //const float* descriptor = (const float*)desc_reader.ptr;
    //CV_NEXT_SEQ_ELEM(desc_reader.seq->elem_size, desc_reader);
    //memcpy(descv_ptr, descriptor, 64*sizeof(float));
    descv = descs.row(i);

    cv::Size cbsize = codebook->size();
    double mindist = 99999999999.9;
    int mincl = -1;
    for (int j=0; j<cbsize.height; j++) {
      cv::Mat filtv = codebook->row(j);
      double dist = cv::Mahalanobis(descv, filtv.t(), I);
      if (dist<mindist) {
	mindist = dist;
	mincl = j;
      }
    }
    bool kp_ok = (mincl >= (int)threshold);
    keypoint_ok.push_back(kp_ok);
    kp_ok ? ++accepted : ++rejected;

    if (debug>1)
      cout << "ClusterFilter::FilterDescriptors(): Minimum distance=" << mindist
	   << " with cluster=" << mincl << " : " 
	   << (kp_ok ? "ACCEPTED" : "REJECTED") << endl;
  }

  if (debug) {
    cout << "ClusterFilter::FilterDescriptors(): Filtering done, accepted: " 
	 << accepted << ", rejected: " << rejected << endl;
    tt = (double)cvGetTickCount() - tt;
    printf("ClusterFilter::FilterDescriptors(): Total time = %gms\n", 
	   tt/(cvGetTickFrequency()*1000.));
  }

}

// ---------------------------------------------------------------------

#ifdef USE_LIBSVM

/// Loads the libsvm model stored in the file `name` and prepares
/// svmthreshold for FilterDescriptors(). threshold is read as a percentage
/// (svmthreshold = threshold/100), and threshold == 0 yields -1.0.
/// The model must provide probability estimates. The descriptor-size
/// argument is not used.
///
/// @return true on success; every failure prints an error and calls exit(-1).
bool SVMFilter::LoadFilter(int) {
  const string msg = "SVMFilter::LoadFilter(): ";

  if (debug)
    cout << msg << "Loading SVM keypoint filter from file " << name << endl;

  model = svm_load_model(name.c_str());
  if (!model) {
    cerr << "ERROR: " << msg << "svm_load_model() failed" << endl;
    exit(-1);
  }

  const bool has_probabilities = (svm_check_probability_model(model) != 0);
  if (!has_probabilities) {
    cerr << "ERROR: " << msg
         << "Model does not support probability estimates" << endl;
    exit(-1);
  }

  svm_type = svm_get_svm_type(model);
  nr_class = svm_get_nr_class(model);

  if (debug)
    cout << msg << "Successfully loaded SVM model of type=" << svm_type
         << ", nr_class=" << nr_class << endl;

  // A zero threshold gives a negative svmthreshold, which
  // FilterDescriptors() treats as "accept everything".
  svmthreshold = (threshold == 0) ? -1.0 : threshold/100.0;

  return true;
}

// ---------------------------------------------------------------------

void SVMFilter::FilterDescriptors(const vector<cv::KeyPoint>&, 
				  const cv::Mat &descs) {

  double tt = (double)cvGetTickCount();
  keypoint_ok.clear();

  if (debug)
    cout << "SVMFilter::FilterDescriptors(): Filtering " << descs.rows
	 << " descriptors using " << name << " with svmthreshold=" 
	 << svmthreshold << endl;

  static const cv::Mat I = cv::Mat::eye(descs.cols, descs.cols, CV_32F);  
  cv::Mat descv(descs.cols, 1, CV_32F);
  //float *descv_ptr = descv.ptr<float>(0);
  
  //CvSeqReader desc_reader;
  //cvStartReadSeq(descs, &desc_reader);

  size_t accepted = 0, rejected = 0;
  for (int i = 0; i < descs.rows; i++) {
    //const float* descriptor = (const float*)desc_reader.ptr;
    //CV_NEXT_SEQ_ELEM(desc_reader.seq->elem_size, desc_reader);
    //memcpy(descv_ptr, descriptor, 64*sizeof(float));
    descv = descs.row(i);

    double prob_estimates[2];
    svm_node* x_space = ConvertVector(descv);

    svm_predict_probability(model, x_space, &prob_estimates[0]);

    delete [] x_space;
    
    bool kp_ok =  svmthreshold < 0.0 ? 
      true : (prob_estimates[1] >= svmthreshold);
    keypoint_ok.push_back(kp_ok);
    kp_ok ? ++accepted : ++rejected;

    if (debug>1)
      cout << "SVMFilter::FilterDescriptors(): Probability estimates: " 
	   << prob_estimates[0] << "," << prob_estimates[1] << " : "
	   << (kp_ok ? "ACCEPTED" : "REJECTED") << endl;
  }
  
  if (debug) {
    cout << "SVMFilter::FilterDescriptors(): Filtering done, accepted: " 
	 << accepted << ", rejected: " << rejected << endl;
    tt = (double)cvGetTickCount() - tt;
    printf("SVMFilter::FilterDescriptors(): Total time = %gms\n", 
	   tt/(cvGetTickFrequency()*1000.));
  }
  
}

// ---------------------------------------------------------------------

/// Converts a single-row or single-column CV_32F matrix into a libsvm
/// feature vector: one node per element with 1-based indices, followed by
/// a terminating node whose index is -1.
///
/// @return array allocated with new[]; the caller must delete[] it.
struct svm_node* SVMFilter::ConvertVector(const cv::Mat &v) {
  // Use the element count rather than v.rows so that both D x 1 and 1 x D
  // vectors are converted in full.
  const int l = (int)v.total();
  struct svm_node* x_space = new svm_node[1+l];

  for (int i=0; i<l; i++) {
    x_space[i].index = i+1;
    // at<float>(i) addresses element i of a single-row or single-column
    // matrix, independent of its orientation.
    x_space[i].value = v.at<float>(i);
  }
  x_space[l].index = -1;
  return x_space;
}

#endif

// ---------------------------------------------------------------------

// RandomFilter decides per descriptor with rand(), so there is nothing to
// load; the descriptor size is ignored and loading always succeeds.
bool RandomFilter::LoadFilter(int /*size*/) {
  return true;
}

// ---------------------------------------------------------------------

/// Randomly keeps descriptors: for every row of `descs` a value
/// rand() % 100 is drawn, and the row is accepted when that value is not
/// below threshold (so threshold acts roughly as a rejection percentage).
/// One entry per descriptor row is appended to keypoint_ok (cleared first).
/// Neither the keypoints nor the descriptor values are inspected.
void RandomFilter::FilterDescriptors(const vector<cv::KeyPoint>&, 
                                     const cv::Mat &descs) {

  const double start_ticks = (double)cvGetTickCount();
  keypoint_ok.clear();

  if (debug)
    cout << "RandomFilter::FilterDescriptors(): Filtering " << descs.rows
         << " descriptors with threshold=" << threshold << endl;

  const int cutoff = (int)threshold;
  size_t n_accepted = 0;
  size_t n_rejected = 0;

  int row = 0;
  while (row < descs.rows) {
    const int draw = rand() % 100;
    const bool keep = !(draw < cutoff);
    keypoint_ok.push_back(keep);
    if (keep)
      n_accepted++;
    else
      n_rejected++;

    if (debug>1)
      cout << "RandomFilter::FilterDescriptors(): Random number generated: " 
           << draw << " : "
           << (keep ? "ACCEPTED" : "REJECTED") << endl;
    row++;
  }

  if (debug) {
    cout << "RandomFilter::FilterDescriptors(): Filtering done, accepted: " 
         << n_accepted << ", rejected: " << n_rejected << endl;
    const double elapsed = (double)cvGetTickCount() - start_ticks;
    printf("RandomFilter::FilterDescriptors(): Total time = %gms\n", 
           elapsed/(cvGetTickFrequency()*1000.));
  }
}

// ---------------------------------------------------------------------

// HessianFilter only looks at keypoint responses, so there is no model to
// load; the descriptor size is ignored and loading always succeeds.
bool HessianFilter::LoadFilter(int /*size*/) {
  return true;
}

// ---------------------------------------------------------------------

void HessianFilter::FilterDescriptors(const vector<cv::KeyPoint> &keyps, 
				      const cv::Mat &) {

  double tt = (double)cvGetTickCount();
  keypoint_ok.clear();

  const string msg = "HessianFilter::FilterDescriptors(): ";

  float hessianthreshold = -1.0;
  if (keyps.size() > threshold) {
    vector<float> hessians(keyps.size());
    for (size_t i = 0; i < keyps.size(); i++) {
      // CvSURFPoint* r = (CvSURFPoint*)cvGetSeqElem(keyps, i);
      // hessians.at(i) = r->hessian;
      hessians.at(i) = keyps.at(i).response;
    }
#ifndef _MSC_VER
    sort(hessians.begin(), hessians.end(), greater<float>());
#else
    cerr << "ERROR: " << msg 
	 << "THIS FUNCTIONALITY CURRENTLY DISABLED FOR VISUAL STUDIO" 
	 << endl;
#endif //_MSC_VER
    hessianthreshold = hessians.at(threshold);
  }

  if (debug)
    cout << msg << "Filtering " << keyps.size()
	 << " descriptors with threshold=" << threshold 
	 << ", hessianthreshold=" << hessianthreshold << endl;     

  size_t accepted = 0, rejected = 0;
  for (size_t i = 0; i < keyps.size(); i++) {
    //CvSURFPoint* r = (CvSURFPoint*)cvGetSeqElem(keyps, i);
    const cv::KeyPoint &kp = keyps.at(i);

    bool kp_ok = (keyps.at(i).response > hessianthreshold);
    keypoint_ok.push_back(kp_ok);
    kp_ok ? ++accepted : ++rejected;
    
    if (debug>1)
      cout << msg << "Keypoint's hessian: " << kp.response << " : "
	   << (kp_ok ? "ACCEPTED" : "REJECTED") << endl;
  }
  
  if (debug) {
    cout << msg << "Filtering done, accepted: " 
	 << accepted << ", rejected: " << rejected << endl;
    tt = (double)cvGetTickCount() - tt;
    printf("HessianFilter::FilterDescriptors(): Total time = %gms\n", 
	   tt/(cvGetTickFrequency()*1000.));
  }
  
}

// ---------------------------------------------------------------------
// ---------------------------------------------------------------------
