#include "KLTTracker.hpp"
#include "BlackBoardDumpWriter.hpp"
#include "SkinDetector.hpp"
#include "ColourSpaceConverter.hpp"
#include "util.hpp"
#include <opencv2/opencv.hpp>
#include <algorithm>
#include <list>


using cv::Mat;
using cv::Scalar;
using std::vector;
using cv::Point2f;
using cv::Vec6f;
using std::cerr;
using std::endl;
using cv::Size;
using cv::TermCriteria;
using std::set;
using std::list;
using boost::program_options::value;



namespace slmotion {
  // Global debug verbosity level, defined elsewhere. This file prints its
  // cerr diagnostics only when it exceeds 1, 2 or 3, so larger values give
  // more output.
  extern int debug;



  // File-scope instance built with the KLTTracker(bool) constructor below.
  // Constructing it during static initialisation runs that constructor's
  // registrations of the TrackedPoint dump writers/readers.
  // NOTE(review): presumably Component(true) also registers this object as a
  // component prototype — confirm against Component.
  static KLTTracker DUMMY(true);

  /**
   * Registration constructor, used to build the file-scope DUMMY instance.
   * It tells the black board dump machinery how to serialise this
   * component's output. A std::set<TrackedPoint> is registered through
   * registerAnyWriter. Each TrackedPoint is written as a fixed-size binary
   * record:
   *
   * [uint64_t][float][float][char]
   *     |        |      |     |
   *     |        |      |     +---Body part identity
   *     |        |      +---y coordinate
   *     |        +---x coordinate
   *     +---id
   */
  KLTTracker::KLTTracker(bool) : Component(true) {
    BlackBoardDumpWriter::registerAnyWriter<std::set<TrackedPoint> >();

    // Size of one record in bytes; must agree with writeRecord below
    auto recordSize = [](const TrackedPoint&) {
      return sizeof(uint64_t) + sizeof(char) + 2*sizeof(float);
    };

    // Emits the record fields in the order documented above
    auto writeRecord = [](std::ostream& out, const TrackedPoint& point) {
      const uint64_t pointId = point.getId();
      const cv::Point2f location = static_cast<cv::Point2f>(point);
      const float xy[2] = { location.x, location.y };
      const char identity = point.getPartName();
      out.write(reinterpret_cast<const char*>(&pointId), sizeof(pointId));
      out.write(reinterpret_cast<const char*>(xy), sizeof(xy));
      out.write(&identity, sizeof(identity));
    };

    // Reads the fields back in the same order and rebuilds the point
    auto readRecord = [](BlackBoardDumpReader* const reader,
                         std::istream& in) {
      uint64_t pointId;
      float xy[2];
      char identity;
      reader->dumbRead(in, pointId);
      reader->dumbRead(in, xy);
      reader->dumbRead(in, identity);
      return TrackedPoint(pointId, cv::Point2f(xy[0], xy[1]),
                          static_cast<BodyPart::PartName>(identity));
    };

    BlackBoardDumpWriter::registerSizeComputer<TrackedPoint>(recordSize);
    BlackBoardDumpWriter::registerDumbWriter<TrackedPoint>(writeRecord);
    BlackBoardDumpReader::registerUnDumper<TrackedPoint>(readRecord);
  }
  

  // Definition of TrackedPoint's static id counter, initialised to 1.
  // NOTE(review): presumably this is advanced whenever a TrackedPoint is
  // created without an explicit id (e.g. TrackedPoint(corners[i]) in
  // initialise()). Check whether points rebuilt from dumps with explicit ids
  // advance it; otherwise fresh ids could collide with loaded ones.
  TrackedPoint::point_id_type_t TrackedPoint::uniqueIdCounter = 1;
  


    /**
     * Attempts to locate corners in the greyscale image within the area of
     * the mask if the mask is set (i.e. it is not empty). Detected corners
     * are refined to sub-pixel accuracy.
     *
     * Uses the tracker's maxPoints, qualityLevel and minDistance settings
     * for corner detection.
     *
     * @param gsImg Greyscale input image (CV_8UC1)
     * @param mask Binary mask of the same size and type as gsImg, or an
     * empty matrix
     *
     * @return a vector of detected corners (possibly empty)
     */
  vector<Point2f> KLTTracker::findCorners(const cv::Mat& gsImg, 
                                          const cv::Mat& mask) {
    assert(gsImg.type() == CV_8UC1);
    assert(mask.empty() || (mask.size() == gsImg.size() &&
                            mask.type() == gsImg.type()));
    
    vector<Point2f> corners;
    goodFeaturesToTrack(gsImg, corners, maxPoints, qualityLevel, 
                        minDistance, mask);

    // Nothing to refine. Skipping the call also avoids relying on
    // cornerSubPix's input validation accepting an empty point set, which
    // is not guaranteed across OpenCV versions.
    if (corners.empty())
      return corners;
    
    cornerSubPix(gsImg, corners, Size(5, 5), Size(-1, -1), 
                 TermCriteria(TermCriteria::COUNT + TermCriteria::EPS,
                              20, 0.03));

    return corners;
  }



  std::set<TrackedPoint> KLTTracker::initialise(const Mat& gsframe, const Mat& mask) {
    vector<Point2f> corners = findCorners(gsframe, mask);
    
    vector<uchar> featuresFound(corners.size(), 1);
    vector<float> errs(corners.size(), 0);
    removeInvalidPoints(corners, featuresFound, errs, mask);

    // convert points to identified tracked points
    std::set<TrackedPoint> trackedPoints;
    for (size_t i = 0; i < corners.size(); ++i)
      if (featuresFound[i])
        trackedPoints.insert(TrackedPoint(corners[i]));

    return trackedPoints;
  }


  
    /**
     * Tops up a tracked point list with newly detected corners and then
     * invalidates unwanted entries.
     *
     * New corners from findCorners() are appended to points with
     * featuresFound = 1 and error 0. removeInvalidPoints() then clears
     * featuresFound for invalid points and for later entries that lie too
     * close to an earlier one. Existing points precede the new corners, so
     * a new corner near a surviving old point is the one that gets dropped.
     * Vector lengths grow by the number of new corners; no entry is erased.
     *
     * @param img Greyscale frame to search for new corners
     * @param points Point locations (in/out)
     * @param featuresFound Per-point validity flags (in/out)
     * @param errors Per-point tracking errors (in/out)
     * @param mask Binary mask, or an empty matrix
     */
    void KLTTracker::purgeAndReplacePoints(const cv::Mat& img,
                                           std::vector<cv::Point2f>& points,
                                           vector<uchar>& featuresFound, 
                                           vector<float>& errors,
                                           const cv::Mat& mask) {
      const vector<cv::Point2f> candidates = findCorners(img, mask);

      for (size_t k = 0; k < candidates.size(); ++k) {
        points.push_back(candidates[k]);
        featuresFound.push_back(1);
        errors.push_back(0);
      }

      removeInvalidPoints(points, featuresFound, errors, mask);
    }



#if 0
  // Legacy replaceLostPoints(), compiled out by the #if 0 above and kept
  // for reference only; track() currently calls purgeAndReplacePoints().
  //
  // With originalEig, it has two phases.
  //  (1) It clears featuresFound for points that left the image, or whose
  //      first two cornerEigenValsAndVecs components differ from the
  //      original ones by more than maxFoundEigenError (sum of squared
  //      differences, see the comment inside).
  //  (2) For each lost point it picks the newly detected corner with the
  //      smallest such error below maxNewEigenError.
  // Without originalEig, it fills slots lost in both this and the previous
  // frame with new corners, or appends them up to maxPoints. A new corner
  // is only used if it is at least minDistance away from every currently
  // found point.
  // NOTE(review): ST_MAX_POINTS_UPPERBOUND, maxFoundEigenError and
  // maxNewEigenError are not referenced anywhere else in this file; check
  // they still exist before re-enabling this code.
  void KLTTracker::replaceLostPoints(const Mat& img,
                                       const vector<Vec6f>* originalEig, 
                                       vector<Point2f>& points,
                                       vector<uchar>& featuresFound,
                                       const vector<uchar>& oldFeaturesFound,
                                       const Mat& mask) {
    vector<Point2f> tempPoints;
  
    goodFeaturesToTrack(img, tempPoints, ST_MAX_POINTS_UPPERBOUND,
                        qualityLevel, minDistance, mask);
  
    if (originalEig) {
      // Assume preserve_old_points is on
      Mat eig2;
      cornerEigenValsAndVecs(img, eig2, 3, 3);
    
      // First remove any invalid points
      for (size_t i = 0; i < points.size(); i++) {
        if (featuresFound[i]) {
          const Point2f& point = points[i];
          if (point.x < 0 || point.y  < 0 || point.x >= img.size().width 
              || point.y >= img.size().height) {
            featuresFound[i] = 0;
            continue;
          }
        
          const Vec6f& new_eig = eig2.at<Vec6f>(point);
        
          // a value computed as follow:
          // assuming l1 and l2 are the minimal eigenvalues
          // at the new location and o1 and o2 at the original location
          // then, error = (o1 - l1)^2 + (o2 - l2)^2
          const float error_value = 
            ((*originalEig)[i][0] - new_eig[0]) *
            ((*originalEig)[i][0] - new_eig[0]) +
            ((*originalEig)[i][1] - new_eig[1]) *
            ((*originalEig)[i][1] - new_eig[1]);
        
          if (error_value > maxFoundEigenError)
            featuresFound[i] = 0;
        }
      }
    
      // Then attempt to rediscover them:
      // Compute the 'error value' for each new corner found
      // and pick the best one, assuming it goes above the limit
      int best_index;
      float best_error;
      Point2f point;
      float error_value;
    
      for (size_t i = 0; i < points.size(); i++) {
        if (!featuresFound[i]) {
          best_index = -1;
          best_error = FLT_MAX;
          for (size_t j = 0; j < tempPoints.size(); j++) {
            point = tempPoints[j];
            Vec6f& new_eig = eig2.at<Vec6f>(point);
            error_value = ((*originalEig)[i][0] - new_eig[0]) *
              ((*originalEig)[i][0] - new_eig[0]) +
              ((*originalEig)[i][1] - new_eig[1]) *
              ((*originalEig)[i][1] - new_eig[1]);	
            if (error_value < maxNewEigenError &&
                error_value < best_error) {
              best_index = j;
              best_error = error_value;
            }
          }
        
          if (best_index >= 0) {
            points[i] = tempPoints[best_index];
            featuresFound[i] = 1;
            if (slmotion::debug > 3) {
              point = points[i];
              cerr << "Rediscovered " << i << " at (" << point.x << "," <<
                point.y << " with error value " << best_error << endl;
            }
          }
        }
      }
    }
    else {
      // go through all newly found points of interest and see if they can be
      // put somewhere
      bool isGood;
      // double min_distance_sq = minDistance * minDistance;
      // double dx, dy;
      size_t j, rc = 0;
    
      /*
        if (::debug)
        fprintf(stderr, "Replacing lost features...\n");*/
    
      for (size_t i = 0; i < tempPoints.size(); i++) {
        // find a possible location and store in j
        isGood = false;
        const Point2f& point1 = tempPoints[i];
        for (j = 0; static_cast<int>(j) < maxPoints; j++) {
          if (j < points.size() && !featuresFound[j]
              && !oldFeaturesFound[j]) {
            isGood = true;
            break;
          }
          else if (j == points.size() && static_cast<int>(j) < maxPoints) {
            isGood = true; // in this case, we extend the amount of points
            break;
          }
        }
      
        if (static_cast<int>(j) == maxPoints) {
          if (slmotion::debug > 2)
            cerr << "Maximum number of points reached" << endl;
          break; // the array is full
        }
      
        if (!isGood)
          continue;
      
        // Found a suitable candidate,
        // just make sure it's far enough
        // from any other points
        for (size_t k = 0; k < points.size(); k++) {
          if (!featuresFound[k])
            continue;
          else {
            const Point2f& point2 = points[k];
            // dx = point2.x - point1.x;
            // dy = point2.y - point1.y;
            // printf("%i %i\n", point2.x, point1.x);
          
            // if (dx*dx + dy*dy < min_distance_sq) {
            if (norm(point2 - point1) < minDistance) {
              isGood = false;
              break;
            }
          } 
        }
      
        if (isGood) {
          if (j == points.size()) {
            points.push_back(tempPoints[i]);
            featuresFound.push_back(1);
          
            if (slmotion::debug > 2)
              cerr << "Increasing point count to " << points.size() << endl;
          }
          else {
            points[j] = tempPoints[i];
            featuresFound[j] = 1;
          }
          rc++;
        }
      }
      if (slmotion::debug > 2) {
        cerr << "Replaced " << rc << " points" << endl;
      }
    }
  }
#endif
  namespace {
    /**
     * Removes the point at the given index i if it is invalid. Returns true
     * if the feature was removed
     */
    bool removeIfInvalid(vector<Point2f>& corners,
                         vector<uchar>& featuresFound,
                         std::vector<float>& errors,
                         const cv::Mat& mask,
                         bool removeNonMaskedFeatures,
                         size_t i, double maxFrameError) {
      if (removeNonMaskedFeatures && !mask.empty()) {
        if (corners[i].x >= mask.size().width ||
            corners[i].y >= mask.size().height || 
            corners[i].x < 0 || corners[i].y < 0 ||
            mask.at<uchar>(corners[i]) == 0) {
          if (debug > 2) 
            cerr << "Removed point " << corners[i] << " because of mask" 
                 << endl;
          
          featuresFound[i] = 0;
          return true;
        }
      }

      if (errors[i] > maxFrameError) {
        if (debug > 2) 
          cerr << "Removed point because of error" << endl;

        featuresFound[i] = 0;
        return true;
      }

      return false;      
    }



    /**
     * Removes any too close neighbour points that occur after the point at
     * the given index
     *
     * @return The number of points removed
     */
    int removeTooCloseNeighbours(const vector<Point2f>& corners,
                                 vector<uchar>& featuresFound,
                                 size_t i, int minDistance) {
      int rc = 0;
      for (size_t j = i + 1; j < corners.size(); j++) 
        if (featuresFound[j]) 
          if (norm(corners[i] - corners[j]) < minDistance) 
            featuresFound[j] = 0, ++rc;
          
      return rc;
    }

  }



  /**
   * Invalidates unwanted points in place by clearing their featuresFound
   * flag; no vector is resized.
   *
   * Points are visited in index order. A found point is first checked
   * with removeIfInvalid (mask and tracking error). If it survives, every
   * later found point closer than minDistance to it is removed.
   *
   * @param corners Point locations
   * @param featuresFound Per-point validity flags (in/out)
   * @param errors Per-point tracking errors
   * @param mask CV_8UC1 binary mask, or an empty matrix
   */
  void KLTTracker::removeInvalidPoints(vector<Point2f>& corners,
                                       vector<uchar>& featuresFound,
                                       std::vector<float>& errors,
                                       const cv::Mat& mask) {
    assert(corners.size() == featuresFound.size());
    assert(errors.size() == corners.size());
    assert(mask.empty() || mask.type() == CV_8UC1);

    int invalidPointCount = 0;
    int tooCloseNeighbourCount = 0;

    for (size_t idx = 0; idx < corners.size(); ++idx) {
      if (!featuresFound[idx])
        continue;

      const bool dropped = removeIfInvalid(corners, featuresFound, errors,
                                           mask, removeNonMaskedFeatures,
                                           idx, maxFrameError);
      if (dropped)
        ++invalidPointCount;
      else
        tooCloseNeighbourCount +=
          removeTooCloseNeighbours(corners, featuresFound, idx, minDistance);
    }

    if (slmotion::debug > 1) {
      const int removedTotal = tooCloseNeighbourCount + invalidPointCount;
      cerr << "Removed " << removedTotal << " points: " << tooCloseNeighbourCount 
           << " were too near one another and " << invalidPointCount
           << " were otherwise invalid" << endl;
    }
  }
  
  
  
  /**
   * Tracks points from previousFrame into frame with pyramidal
   * Lucas-Kanade optical flow, then tops the result up with newly detected
   * corners.
   *
   * Old points that were tracked (status set) and survive validation are
   * moved to their new position via TrackedPoint::relocate(). New corners
   * are then added as fresh TrackedPoints while the result holds fewer
   * than maxPoints points.
   *
   * @param frame Current greyscale frame (CV_8UC1)
   * @param previousFrame Previous greyscale frame, same size and type
   * @param oldPoints Points tracked in the previous frame
   * @param mask Binary mask for the current frame, or an empty matrix
   * @return The tracked point set for the current frame
   */
  set<TrackedPoint> KLTTracker::track(const Mat& frame, 
                                      const Mat& previousFrame,
                                      const set<TrackedPoint>& oldPoints,
                                      const Mat& mask) {
    assert(frame.type() == CV_8UC1);
    assert(previousFrame.type() == frame.type());
    assert(previousFrame.size() == frame.size());
    assert(mask.empty() || 
           (mask.type() == frame.type() && mask.size() == frame.size()));

    vector<Point2f> previousPoints;
    insert_transform(oldPoints, previousPoints,
                     [](const TrackedPoint& p) {
                       return static_cast<Point2f>(p);
                     });
    vector<Point2f> nextPoints;
    vector<uchar> status;
    vector<float> err;

    // With no previous points there is nothing to track. Skipping the call
    // avoids relying on calcOpticalFlowPyrLK's input validation accepting
    // an empty point set, which is not guaranteed across OpenCV versions.
    // nextPoints/status/err then stay empty and only freshly detected
    // corners are considered below.
    if (!previousPoints.empty())
      calcOpticalFlowPyrLK(previousFrame, frame, previousPoints, 
                           nextPoints, status, err);

    purgeAndReplacePoints(frame, nextPoints, status, err, mask);

    // The indexing below relies on this layout: one entry per old point
    // (in set iteration order), followed by the newly detected corners.
    assert(status.size() == nextPoints.size());
    assert(nextPoints.size() >= oldPoints.size());

    std::set<TrackedPoint> newPoints;
    // update old points
    auto it = oldPoints.begin();
    for (size_t i = 0; i < oldPoints.size(); ++i, ++it)
      if (status[i])
        newPoints.insert(it->relocate(nextPoints[i]));

    // add new points but do not let the number of points overgrow the
    // maximum amount allowed
    for (size_t i = oldPoints.size(); i < nextPoints.size() &&
           static_cast<int>(newPoints.size()) < maxPoints; ++i) 
      if (status[i])
        newPoints.insert(TrackedPoint(nextPoints[i]));
    
    return newPoints;
  }



  std::set<TrackedPoint> KLTTracker::updateBodyPartIdentities(const std::set<TrackedPoint>& points,
                                            frame_number_t frnumber) {
    std::set<TrackedPoint> updatedPoints;
    BlackBoardPointer<std::list<BodyPart> > currentBodyPartsPtr = getBlackBoard().get<list<BodyPart>>(frnumber, BODYPARTCOLLECTOR_BLACKBOARD_ENTRY);
    std::list<BodyPart>& currentBodyParts = *currentBodyPartsPtr;

    std::set<TrackedPoint>* previousPoints = NULL;
    if (frnumber > 0 && 
        getBlackBoard().has(frnumber-1, KLTTRACKER_BLACKBOARD_TRACKED_POINTS_ENTRY))
      previousPoints = &*getBlackBoard().get<std::set<TrackedPoint>>(frnumber-1, KLTTRACKER_BLACKBOARD_TRACKED_POINTS_ENTRY);

    for (auto it = points.cbegin(); it != points.cend(); ++it) {
      double bestd, d;
      bestd = DBL_MAX;
      BodyPart::PartName b = BodyPart::UNKNOWN;
      for (list<BodyPart>::iterator bt = currentBodyParts.begin();
           bt != currentBodyParts.end(); ++bt) {
        d = bt->distanceTo(static_cast<Point2f>(*it));
        if (d < bestd) {
          bestd = d;
          b = bt->getIdentity();
        }
      }

      if (previousPoints) {
        // Compare body part identities with those from the previous frame
        // and associate ambiguous body parts with non-ambiguous ones
        auto jt = previousPoints->find(*it);
        if (jt != previousPoints->end())
          BodyPart::associateWithBetterIdentity(jt->getPartName(), b);
      }
     
      // do this the ugly way since set always returns const iterators
      TrackedPoint p = *it;
      p.setPartName(b);
      updatedPoints.insert(p);
    }

    return updatedPoints;
  }



  /**
   * Processes one frame and stores its tracked point set on the black
   * board under KLTTRACKER_BLACKBOARD_TRACKED_POINTS_ENTRY.
   *
   * Points are initialised from scratch on frame 0 or when the previous
   * frame has no tracked points on the black board. Otherwise the previous
   * frame's points are tracked into this one. The skin detector mask is
   * used when present for this frame. Body part identities are updated
   * before storing.
   *
   * @param frameNumber Frame to process
   * @throws KLTTrackerException if this frame's greyscale image, or the
   * previous frame's greyscale image when tracking, is not on the black
   * board
   */
  void KLTTracker::process(frame_number_t frameNumber) {
    // check if skin detection results are available and set the mask
    // if they are. Otherwise, use an empty matrix (treated as "no mask").
    const Mat* skinMask = nullptr;
    // Held at function scope so the entry stays referenced while skinMask
    // points into it
    BlackBoardPointer<cv::Mat> skinMaskBbPtr;
    cv::Mat emptyMat;
    if (getBlackBoard().has(frameNumber, SKINDETECTOR_BLACKBOARD_MASK_ENTRY)) {
      skinMaskBbPtr = getBlackBoard().get<cv::Mat>(frameNumber, 
                                                   SKINDETECTOR_BLACKBOARD_MASK_ENTRY);
      skinMask = &*skinMaskBbPtr;
    }
    else
      skinMask = &emptyMat;

    // ensure that source images are available
    if (!getBlackBoard().has(frameNumber, COLOURSPACECONVERTER_BLACKBOARD_GSIMAGE_ENTRY))
      throw KLTTrackerException("No greyscale image for the requested frame was found on black board!");

    BlackBoardPointer<Mat> greyScaleImage = getBlackBoard().get<cv::Mat>(frameNumber, COLOURSPACECONVERTER_BLACKBOARD_GSIMAGE_ENTRY);

    std::set<TrackedPoint> trackedPoints; // result set

    // initialisation is done always in the first frame, otherwise if
    // the previous frame had not been tracked
    if (frameNumber == 0 || !getBlackBoard().has(frameNumber-1, KLTTRACKER_BLACKBOARD_TRACKED_POINTS_ENTRY)) {
      if (slmotion::debug > 1)
        cerr << "Initialising Lucas-Kanade feature tracker..." << endl;
      trackedPoints = initialise(*greyScaleImage, *skinMask);
    }
    else {
      // Tracking also needs the previous frame's greyscale image
      if (!getBlackBoard().has(frameNumber-1, COLOURSPACECONVERTER_BLACKBOARD_GSIMAGE_ENTRY))
        throw KLTTrackerException("No greyscale image for the previous frame was found on black board!");
        
      BlackBoardPointer<Mat> previousGreyScaleImage = getBlackBoard().get<cv::Mat>(frameNumber-1, COLOURSPACECONVERTER_BLACKBOARD_GSIMAGE_ENTRY);
      BlackBoardPointer<std::set<TrackedPoint> > oldPoints = getBlackBoard().get<std::set<TrackedPoint>>(frameNumber-1, KLTTRACKER_BLACKBOARD_TRACKED_POINTS_ENTRY);
      trackedPoints = track(*greyScaleImage, *previousGreyScaleImage, 
                            *oldPoints, *skinMask);
    }

    // update body part identity information
    trackedPoints = updateBodyPartIdentities(trackedPoints, frameNumber);

    getBlackBoard().set(frameNumber, KLTTRACKER_BLACKBOARD_TRACKED_POINTS_ENTRY, trackedPoints);
  }



  /**
   * Processes every frame in the half-open range [first, last).
   *
   * After each frame the optional UI callback receives the completed
   * percentage. If it returns false, processing stops.
   *
   * @return false if the callback aborted processing, true otherwise
   */
  bool KLTTracker::processRangeImplementation(frame_number_t first, 
                                              frame_number_t last,
                                              UiCallback* uiCallback) {
    for (size_t i = first; i < last; ++i) {
      // One progress line per frame; previously the counters ran together
      // because no separator was printed
      if (slmotion::debug > 1)
        cerr << i + 1 - first << '/' << last - first << endl;
      this->process(i);
      if (uiCallback != nullptr && !(*uiCallback)(100.*(i+1-first)/(last-first)))
        return false;
    }
    return true;
  }



  /**
   * Describes the KLTTracker.* configuration file options together with
   * their default values.
   */
  boost::program_options::options_description KLTTracker::getConfigurationFileOptionsDescription() const {
    boost::program_options::options_description opts;
    opts.add_options()
      ("KLTTracker.removenonmaskedfeatures", 
       value<bool>()->default_value(true),
       "Remove features that have been tracked to a location that is "
       "outside mask boundaries as detected by the skin detector")
      ("KLTTracker.maxframeerror", value<double>()->default_value(1000),
       "Maximal acceptable tracking error between two consecutive frames.")
      ("KLTTracker.gfquality", value<double>()->default_value(0.01),
       "Sets the quality level for finding good features")
      ("KLTTracker.gfmindistance", value<double>()->default_value(3),
       "Sets the minimum distance between good features")
      ("KLTTracker.maxpoints", value<int>()->default_value(1000),
       "Sets the maximum number of points to track")
      ("KLTTracker.maxmove", value<int>()->default_value(30),
       "Sets the maximum number of pixels that a feature can move between "
       "two consecutive frames.");
    return opts;
  }



  /**
   * Creates a KLTTracker bound to the given black board and frame source.
   * Every KLTTracker.* option present in the configuration is passed to
   * the matching setter. Absent options leave the tracker's own defaults
   * untouched.
   *
   * @return A heap-allocated copy of the configured tracker (presumably
   * owned by the caller from here on)
   */
  Component* KLTTracker::createComponentImpl(const boost::program_options::variables_map& configuration, BlackBoard* blackBoard, FrameSource* frameSource) const {
    KLTTracker featureTracker(blackBoard, frameSource);

    // True when the key occurs in the parsed configuration
    auto isSet = [&configuration](const char* key) {
      return configuration.count(key) > 0;
    };

    if (isSet("KLTTracker.removenonmaskedfeatures")) {
      const bool removeNonMasked =
        configuration["KLTTracker.removenonmaskedfeatures"].as<bool>();
      featureTracker.setRemoveNonMaskedFeatures(removeNonMasked);
    }

    if (isSet("KLTTracker.maxframeerror")) {
      const double frameError =
        configuration["KLTTracker.maxframeerror"].as<double>();
      featureTracker.setMaxFrameError(frameError);
    }

    if (isSet("KLTTracker.gfquality")) {
      const double quality = configuration["KLTTracker.gfquality"].as<double>();
      featureTracker.setQualityLevel(quality);
    }

    if (isSet("KLTTracker.gfmindistance")) {
      const double distance =
        configuration["KLTTracker.gfmindistance"].as<double>();
      featureTracker.setMinDistance(distance);
    }

    if (isSet("KLTTracker.maxpoints")) {
      const int pointLimit = configuration["KLTTracker.maxpoints"].as<int>();
      featureTracker.setMaxPoints(pointLimit);
    }

    if (isSet("KLTTracker.maxmove")) {
      const int moveLimit = configuration["KLTTracker.maxmove"].as<int>();
      featureTracker.setMaxMove(moveLimit);
    }

    KLTTracker* const configured = new KLTTracker(featureTracker);
    return configured;
  }
}
