#include "Asm.hpp"
#include "math.hpp"
#include "Visualiser.hpp"
#include "LinePointIterator.hpp"
#include "util.hpp"

using cv::Mat;
using cv::Point;
using cv::Scalar;
using cv::Point2d;
using std::vector;
using std::cerr;
using std::endl;

/**
 * Total length, in pixels, of the search segment that is scanned for a
 * stronger gradient. The segment is centred on a landmark and runs along the
 * landmark's normal (half of the length on either side).
 */
static constexpr double SEARCH_PATH_LENGTH = 20;

/**
 * Minimum fraction (0..1) of the mask's non-zero pixels that the fitted
 * instance must cover for its pose to be considered acceptable.
 */
static constexpr double MINIMUM_COVERED_AREA = 0.7;

/**
 * Maximum number of pixels the instance may cover outside the mask,
 * expressed as a fraction of the mask's non-zero pixel count.
 */
static constexpr double MAXIMUM_EXTRA_AREA = 0.3;

namespace slmotion {
  using namespace math;

  extern int debug;



  /**
   * Two instances are equal when their models, shape parameters, poses and
   * landmarks all compare equal. Comparison stops at the first mismatch.
   */
  bool Asm::Instance::operator==(const Asm::Instance& other) const {
    if (!(*this->pdm == *other.pdm))
      return false;
    if (!equal(this->bt, other.bt))
      return false;
    if (!(this->pose == other.pose))
      return false;
    return equal(this->landmarks, other.landmarks);
  }



  /**
   * Locates the lowest non-zero pixel of a CV_8UC1 mask.
   *
   * Rows are scanned from the bottom up and, within a row, from right to
   * left, so the rightmost pixel of the lowest non-empty row is returned.
   *
   * @return the pixel as (x = column, y = row), or (-1, -1) if the mask has
   *         no non-zero pixel
   */
  static Point2d findBottommostPointInMask(const Mat& initialMask) {
    for (int row = initialMask.rows - 1; row >= 0; --row) {
      const uchar* pixels = initialMask.ptr<uchar>(row);
      for (int col = initialMask.cols - 1; col >= 0; --col) {
        if (pixels[col] != 0)
          return Point2d(col, row);
      }
    }
    return Point2d(-1, -1);
  }



  /**
   * Applies a similarity transform (rotation by pose.theta, uniform scaling
   * by pose.scale, then translation by (pose.tx, pose.ty)) to every landmark.
   *
   * @param shape column vector of interleaved x, y landmark coordinates
   * @param pose  the transform to apply
   * @return a new CV_64FC1 column vector with the transformed landmarks
   */
  static Mat applyPoseParameters(const Mat& shape,
                                 const Pdm::PoseParameter& pose) {
    // the rotation is the same for every landmark
    const double cosTheta = cos(pose.theta);
    const double sinTheta = sin(pose.theta);

    Mat transformed(shape.rows, 1, CV_64FC1);
    for (int row = 0; row < shape.rows; row += 2) {
      const double px = shape.at<double>(row, 0);
      const double py = shape.at<double>(row + 1, 0);
      transformed.at<double>(row, 0) =
        px*pose.scale*cosTheta - py*pose.scale*sinTheta + pose.tx;
      transformed.at<double>(row + 1, 0) =
        px*pose.scale*sinTheta + py*pose.scale*cosTheta + pose.ty;
    }
    return transformed;
  }


  /**
   * Applies the given rotation and scale to a shape and adjusts the pose's
   * translation so that the first landmark (landmark zero) lands exactly on
   * the anchor.
   *
   * @param shape   CV_64FC1 column vector of interleaved x, y landmarks
   *                (at least one landmark)
   * @param pose    pose to apply; its translation is corrected by the
   *                residual between the anchor and the posed landmark zero
   * @param anchor  desired position of landmark zero
   * @param outPose if non-null, receives the corrected pose actually used
   * @return the posed shape with landmark zero at the anchor
   */
  static Mat createPoseAndAnchor(const Mat& shape, 
                                 Pdm::PoseParameter pose,
                                 const cv::Point2d& anchor,
                                 Pdm::PoseParameter* outPose = NULL) {
    assert(shape.type() == CV_64FC1);
    assert(shape.cols == 1 && shape.rows >= 2 && shape.rows % 2 == 0);

    const Mat posed = applyPoseParameters(shape, pose);
    const Point2d landmarkZero(posed.at<double>(0, 0),
                               posed.at<double>(1, 0));

    // Shift the existing translation by the residual. Adding (rather than
    // overwriting) keeps landmark zero on the anchor even when the incoming
    // pose already carries a non-zero translation.
    const Point2d offset = anchor - landmarkZero;
    pose.tx += offset.x;
    pose.ty += offset.y;

    if (outPose)
      *outPose = pose;

    return applyPoseParameters(shape, pose);
  }



  /**
   * Rasterises the polygon whose corners are the given landmark points.
   *
   * @param shape     CV_64FC1 column vector of interleaved x, y landmarks
   * @param frameSize size of the returned image
   * @return CV_8UC1 image: 255 inside the filled polygon, 0 elsewhere
   */
  static Mat shapeToMatrix(const cv::Mat& shape, 
                           const cv::Size& frameSize) {
    assert(shape.type() == CV_64FC1 && shape.cols == 1 &&
           shape.rows > 1 && shape.rows % 2 == 0);
    Mat binImg(frameSize, CV_8UC1, Scalar::all(0));
    vector<Point> polyPoints(shape.rows/2);
    for (size_t i = 0; i < polyPoints.size(); ++i) {
      // Round (with saturation) to the nearest pixel, as the Point2d -> Point
      // conversion in Asm::Instance::drawLine does; a plain double -> int
      // conversion would truncate towards zero and rasterise differently.
      polyPoints[i] = Point(cv::saturate_cast<int>(shape.at<double>(2*i, 0)),
                            cv::saturate_cast<int>(shape.at<double>(2*i+1, 0)));
    }

    const Point* p = &polyPoints[0];
    int npoints = static_cast<int>(polyPoints.size());
    cv::fillPoly(binImg, &p, &npoints, 1, Scalar::all(255));
    return binImg;
  }



  static void getInitialShape(const Mat& initialMask, const Pdm* pdm,
                              const Mat& inFrame, Mat& outInitialShape,
                              Point2d& outAnchorGuess, 
                              Pdm::PoseParameter& initialPose) {
    // locate the bottommost point
    Point2d anchorGuess = findBottommostPointInMask(initialMask);

    // If the mask is empty, use the default
    if (anchorGuess == Point2d(-1, -1))
      anchorGuess = pdm->getMeanAnchor();

    assert(anchorGuess.x >= 0 && anchorGuess.y >= 0);

    // estimate initial rotation
    double bestTheta = 0;
    // number of pixels at the intersection of the mean shape, rotated by 
    // the best value theta
    int numPixelsInIntersectionAtBestTheta = 0; 
    for (double theta = 0; theta <= 2*PI; theta += PI/20) {
      Mat shape = createPoseAndAnchor(pdm->generateShape(Mat()), 
                                      { theta, 1, 0, 0 }, anchorGuess);

      Mat temp = shapeToMatrix(shape, inFrame.size());
      temp = cv::min(temp, initialMask);
      int numPixelsAtIntersection = cv::countNonZero(temp);
      if (numPixelsInIntersectionAtBestTheta < numPixelsAtIntersection) {
        numPixelsInIntersectionAtBestTheta = numPixelsAtIntersection;
        bestTheta = theta;
      }
    }
    
    initialPose = { bestTheta, 1, 0, 0};
    Mat meanShape = pdm->generateShape(Mat());
    outInitialShape = createPoseAndAnchor(pdm->generateShape(Mat()),
                                          initialPose,
                                          anchorGuess,
                                          &initialPose);
    // initialPose.tx = anchorGuess.x - meanShape.at<double>(0,0);
    // initialPose.ty = anchorGuess.y - meanShape.at<double>(1,0);
    // initialPose.tx = anchorGuess.x - outInitialShape.at<double>(0,0);
    // initialPose.ty = anchorGuess.y - outInitialShape.at<double>(1,0);
    outAnchorGuess = anchorGuess;
  }



  /**
   * Computes a normalised edge-strength image for the frame.
   *
   * The frame is converted from BGR to grey scale and filtered with the
   * zero-sum 5x5 kernel LOGKERNEL. The absolute response is then rescaled to
   * [0, 1], and responses below 0.1 are set to zero.
   *
   * @param inFrame BGR input frame
   * @return CV_64FC1 image of the frame's size with values in [0, 1]. If the
   *         filter response is constant, the image is all zeros.
   */
  static Mat getGradientImage(const Mat& inFrame) {
    const static Mat LOGKERNEL = (cv::Mat_<double>(5, 5) <<
                                  0, 0, -1, 0, 0,
                                  0, -1, -2, -1, 0,
                                  -1, -2, 16, -2, -1,
                                  0, -1, -2, -1, 0,
                                  0, 0, -1, 0, 0);
    Mat greyImg;
    cv::cvtColor(inFrame, greyImg, CV_BGR2GRAY);
    Mat gradientImg;
    cv::filter2D(greyImg, gradientImg, CV_64F, LOGKERNEL);

    Mat temp = cv::abs(gradientImg);
    double gradMin, gradMax;
    cv::minMaxLoc(temp, &gradMin, &gradMax);
    if (gradMax > gradMin) {
      temp = (temp - gradMin) / (gradMax - gradMin);
    }
    else {
      // A constant response (e.g. a uniform frame) would otherwise be divided
      // by zero, filling the image with NaNs.
      temp.setTo(Scalar::all(0));
    }

    // Suppress weak responses. A mask is used instead of cv::threshold,
    // which did not work with CV_64F here.
    const Mat weakResponse = temp < 0.1;
    temp.setTo(Scalar::all(0), weakResponse);

    return temp;
  }



  /**
   * Proposes a new position for every landmark.
   *
   * For each landmark, the function scans a segment of length
   * SEARCH_PATH_LENGTH. The segment is centred on the landmark (clamped into
   * the image) and perpendicular to the line joining its two neighbours. The
   * landmark moves to the scanned point with the strictly highest gradient;
   * if no point is higher, it stays at its clamped position.
   *
   * @param shape   CV_64FC1 column vector of interleaved x, y landmarks
   * @param gradImg CV_64FC1 gradient image (see getGradientImage)
   * @return the proposed landmarks, in the same layout as shape
   */
  static Mat findTargetLandmarks(const cv::Mat& shape,
                                 const cv::Mat& gradImg) {
    assert(shape.rows % 2 == 0);
    assert(shape.cols == 1);
    assert(shape.type() == CV_64FC1);

    cv::Mat newShape(shape.size(), shape.type());

    // clamps a point into the image so it can be used to index gradImg
    auto sanitise = [&gradImg](Point2d& p) {
      p.x = std::min(std::max(p.x, 0.), gradImg.cols - 1.);
      p.y = std::min(std::max(p.y, 0.), gradImg.rows - 1.);
    };

    for (int i = 0; i < shape.rows; i += 2) {
      Point2d pCurr(shape.at<double>(i,0), shape.at<double>(i+1,0));
      sanitise(pCurr);

      // preceding and succeeding landmarks (the contour is closed)
      int iNext = math::mod(i+2, shape.rows);
      Point2d pNext(shape.at<double>(iNext,0), shape.at<double>(iNext+1,0));
      int iPrev = math::mod(i-2, shape.rows);
      Point2d pPrev(shape.at<double>(iPrev,0), shape.at<double>(iPrev+1,0));

      // unit vector perpendicular to the neighbour-to-neighbour direction
      Point2d pDiff = pNext - pPrev;
      pDiff = pDiff * (1./sqrt(pDiff.x*pDiff.x + pDiff.y*pDiff.y));
      std::swap(pDiff.x, pDiff.y);
      pDiff.x *= -1;

      // in case |pDiff| = 0, remove the NaNs
      if (pDiff.x != pDiff.x || pDiff.y != pDiff.y)
        pDiff = Point2d(0,0);

      double currentGradient = gradImg.at<double>(pCurr);
      Point2d newTarget = pCurr;

      Point2d pStart = pCurr - (SEARCH_PATH_LENGTH/2)*pDiff;
      Point2d pEnd = pCurr + (SEARCH_PATH_LENGTH/2)*pDiff;
      sanitise(pStart);
      sanitise(pEnd);

      // presumably visits the pixels on the segment pStart..pEnd
      LinePointIterator it(pStart, pEnd);
      for (size_t j = 0; j < it.length(); ++j, ++it) {
        if (gradImg.at<double>(*it) > currentGradient) {
          currentGradient = gradImg.at<double>(*it);
          newTarget = *it;
          assert(gradImg.at<double>(newTarget) == currentGradient);
        }
      }

      newShape.at<double>(i,0) = newTarget.x;
      newShape.at<double>(i+1,0) = newTarget.y;
    }

    return newShape;
  }



  /**
   * Counts the pixels that are non-zero in the binary image AND lie inside
   * the instance's filled polygon.
   *
   * @param binaryImage CV_8UC1 mask
   * @param instance    instance to rasterise at its landmark coordinates
   * @return number of overlapping pixels
   */
  static int getCoveredPixels(const cv::Mat& binaryImage, 
                              const Asm::Instance& instance) {
    Mat instanceMask = Mat::zeros(binaryImage.size(), CV_8UC1);
    instance.drawLine(instanceMask, CV_FILLED, cv::Scalar::all(255));

    // per-pixel minimum is non-zero only where both images are non-zero
    const Mat overlap = cv::min(instanceMask, binaryImage);
    return cv::countNonZero(overlap);
  }



  /**
   * Counts the pixels that lie inside the instance's filled polygon but are
   * zero in the binary image.
   *
   * Any non-zero pixel of the binary image counts as "present", the same
   * convention that getCoveredPixels and cv::countNonZero use.
   *
   * @param binaryImage CV_8UC1 mask
   * @param instance    instance to rasterise at its landmark coordinates
   * @return number of instance pixels outside the mask
   */
  static int getExtraPixels(const cv::Mat& binaryImage, 
                            const Asm::Instance& instance) {
    Mat instanceMask(binaryImage.size(), CV_8UC1, Scalar::all(0));
    instance.drawLine(instanceMask, CV_FILLED, cv::Scalar::all(255));

    // 255 where the binary image is zero, 0 elsewhere. Computing
    // 255 - binaryImage instead would count pixels valued 1..254 as both
    // present (in getCoveredPixels) and absent (here).
    const Mat background = (binaryImage == 0);
    Mat extra;
    cv::min(instanceMask, background, extra);

    return cv::countNonZero(extra);
  }

  /**
   * Decides whether the old instance fits the initial mask too poorly to be
   * reused as a starting point.
   *
   * NOTE: despite its name, this returns true when the old pose is NOT good
   * enough, i.e. when either:
   *   - it covers less than MINIMUM_COVERED_AREA of the mask's non-zero
   *     pixels, or
   *   - the pixels it covers outside the mask exceed MAXIMUM_EXTRA_AREA
   *     times the mask's non-zero pixel count.
   * Asm::fitImpl() re-initialises the shape when this returns true.
   *
   * For an empty mask, it returns true exactly when the instance covers any
   * pixel. This matches what the floating-point ratios yield in that case,
   * but without dividing by zero.
   */
  static bool isOldPoseGood(const cv::Mat& initialMask,
                            const Asm::Instance* oldInstance) {
    const double coveredPixels = getCoveredPixels(initialMask, *oldInstance);
    const double totalPixels = cv::countNonZero(initialMask);
    const double extraPixels = getExtraPixels(initialMask, *oldInstance);

    if (totalPixels == 0)
      return extraPixels > 0;

    return coveredPixels / totalPixels < MINIMUM_COVERED_AREA ||
      extraPixels / totalPixels > MAXIMUM_EXTRA_AREA;
  }

  /**
   * Fits the active shape model to the frame.
   *
   * The starting pose comes from the old instance if one is given and
   * isOldPoseGood() does not request re-initialisation. Otherwise it comes
   * from getInitialShape().
   *
   * The model is refined for NUM_FIT_ITERATIONS iterations. Then the pose is
   * re-initialised from the mask, keeping the shape parameters reached so
   * far, and the refinement is repeated. Of the two results, the one covering
   * more mask pixels is returned; on a tie, the second is returned.
   *
   * @param inFrame     BGR frame (see getGradientImage)
   * @param initialMask CV_8UC1 mask of the target region
   * @param oldInstance previous fit, or NULL
   */
  Asm::Instance Asm::fitImpl(const cv::Mat& inFrame,
                             const cv::Mat& initialMask,
                             const Asm::Instance* oldInstance) const { 
    assert(initialMask.type() == CV_8UC1);

    // number of refinement iterations per fitting pass
    const int NUM_FIT_ITERATIONS = 30;

    // current landmark estimate (interleaved x, y column vector)
    Mat shape;
    // current pose estimate
    Pdm::PoseParameter oldPose;
    // the b vector such that shape ~ meanshape + b
    Mat shapeB(nComponents, 1, CV_64FC1, Scalar::all(0));

    // isOldPoseGood() returns true when the old instance matches the mask
    // POORLY, in which case we start again from the mean shape.
    if (!oldInstance || isOldPoseGood(initialMask, oldInstance)) {
      Point2d anchorGuess;
      getInitialShape(initialMask, pdm.get(), inFrame, shape, 
                      anchorGuess, oldPose);
    }
    else {
      oldPose = oldInstance->getPose();
      shape = oldInstance->getLandmarks();
      // Deep copy: shapeB is updated in place below and must not write
      // through into the (const) old instance's shape parameters.
      shapeB = oldInstance->getBt().clone();
    }

    // The gradient image depends only on the frame; compute it once and
    // share it between both fitting passes.
    const Mat gradImg = getGradientImage(inFrame);
    const Pdm& model = *pdm;
    const Mat& eigenValues = model.getEigenValues();

    // Repeatedly moves the landmarks towards strong gradients, re-aligns the
    // pose and updates the clamped shape parameters. All three arguments are
    // updated in place.
    auto fitShapeAndPose = [&](Mat& b, Mat& x, Pdm::PoseParameter& pose) {
      x = transform(model.generateShape(b), pose);

      for (int iter = 0; iter < NUM_FIT_ITERATIONS; ++iter) {
        Mat targetLandmarks(findTargetLandmarks(x, gradImg));
        Mat displacement = targetLandmarks - x;
        Pdm::PoseParameter newPose;
        align(targetLandmarks, model.generateShape(b), newPose);

        // change in translation between the current and the new pose,
        // replicated for every landmark
        Mat deltaTranslation(x.size(), x.type());
        const double dtx = newPose.tx - pose.tx;
        const double dty = newPose.ty - pose.ty;
        for (int k = 0; k < deltaTranslation.rows; k += 2) {
          deltaTranslation.at<double>(k, 0) = dtx;
          deltaTranslation.at<double>(k+1, 0) = dty;
        }
        Mat approximateDisplacement = 
          transform(x + displacement - deltaTranslation,
                    Pdm::PoseParameter {-newPose.theta, 1./(newPose.scale), 
                        0, 0}) - x;

        b = b + model.generateApproximateShapeDifference(approximateDisplacement, nComponents, maxShapeParameterDeviation);

        // limit each parameter to +-maxShapeParameterDeviation*sqrt(eigenvalue)
        for (int k = 0; k < b.rows; ++k) {
          double limit = sqrt(eigenValues.at<double>(k)) *
            maxShapeParameterDeviation;
          b.at<double>(k) = std::max(std::min(b.at<double>(k), limit), -limit);
        }

        x = transform(model.generateShape(b), newPose);
        pose = newPose;
      }
    };

    // first pass, from the starting pose chosen above
    fitShapeAndPose(shapeB, shape, oldPose);

    // ensure that the limits are respected
    assert(shapeB.rows == static_cast<int>(nComponents));
    for (unsigned int i = 0; i < nComponents; ++i) {
      assert (shapeB.at<double>(i) >= -maxShapeParameterDeviation*std::sqrt(eigenValues.at<double>(i)));
      assert (shapeB.at<double>(i) <= maxShapeParameterDeviation*std::sqrt(eigenValues.at<double>(i)));
    }

    // Snapshot the first result. shapeB is cloned because the second pass
    // modifies it in place.
    Asm::Instance instance { this->pdm, shapeB.clone(), oldPose, shape };

    // second pass: re-initialise the pose from the mask and fit again
    Point2d anchorGuess;
    getInitialShape(initialMask, pdm.get(), inFrame, shape, 
                    anchorGuess, oldPose);
    fitShapeAndPose(shapeB, shape, oldPose);

    Asm::Instance instance2 { this->pdm, shapeB, oldPose, shape };

    // keep whichever candidate covers more of the mask
    Asm::Instance best = getCoveredPixels(initialMask, instance) >
      getCoveredPixels(initialMask, instance2) ? instance : instance2;

    return best;
  }



  /**
   * Copy assignment. The members are taken from a temporary copy of `other`
   * (made with the copy constructor); self-assignment does nothing.
   */
  Asm::Instance& Asm::Instance::operator=(const Asm::Instance& other) {
    if (this == &other)
      return *this;

    const Instance copy(other);
    this->pdm = copy.pdm;
    this->bt = copy.bt;
    this->pose = copy.pose;
    this->landmarks = copy.landmarks;
    return *this;
  }


  
  /**
   * Returns the arithmetic mean of all landmark positions.
   *
   * Requires at least one landmark (asserted, consistently with getMinSize);
   * an empty landmark vector would otherwise divide by zero.
   */
  cv::Point2d Asm::Instance::computeCentroid() const {
    const cv::Mat& points = this->getLandmarks();
    assert(points.rows % 2 == 0 && points.cols == 1 && points.rows > 0);
    assert(points.type() == CV_64FC1);

    cv::Point2d sum(0, 0);
    for (int i = 0; i < points.rows; i += 2) {
      sum.x += points.at<double>(i, 0);
      sum.y += points.at<double>(i + 1, 0);
    }
    const int numPoints = points.rows / 2;
    return sum * (1. / static_cast<double>(numPoints));
  }

  /**
   * Computes the smallest canvas that holds the instance once it is shifted
   * so that its bounding box starts at the origin.
   *
   * @param outMinSize     receives (ceil(maxX) - floor(minX) + 1,
   *                       ceil(maxY) - floor(minY) + 1)
   * @param outTranslation receives (-minX, -minY), the shift to apply
   */
  void Asm::Instance::getMinSize(cv::Size& outMinSize, cv::Point2d& outTranslation) const {
    assert(landmarks.cols == 1 && landmarks.rows % 2 == 0 &&
           landmarks.rows > 0);

    // bounding box of all landmarks, seeded with landmark zero
    double minX = landmarks.at<double>(0, 0);
    double minY = landmarks.at<double>(1, 0);
    double maxX = minX;
    double maxY = minY;

    for (int row = 2; row < landmarks.rows; row += 2) {
      const double px = landmarks.at<double>(row, 0);
      const double py = landmarks.at<double>(row + 1, 0);
      minX = std::min(minX, px);
      minY = std::min(minY, py);
      maxX = std::max(maxX, px);
      maxY = std::max(maxY, py);
    }

    outMinSize = cv::Size(std::ceil(maxX) - std::floor(minX) + 1,
                          std::ceil(maxY) - std::floor(minY) + 1);
    outTranslation = cv::Point2d(-minX, -minY);
  }

  /**
   * Returns the number of pixels covered by both instances, with each
   * rasterised as a filled polygon at its own landmark coordinates.
   *
   * Both instances are drawn in one common frame: the union of their
   * bounding boxes, shifted to the origin. Translating each instance to its
   * own origin would instead measure the overlap of the two shapes with their
   * top-left corners aligned, ignoring where they actually are.
   */
  double intersectionArea(const Asm::Instance& a,
                          const Asm::Instance& b) {
    cv::Size sizeA, sizeB;
    cv::Point2d translateA, translateB;
    a.getMinSize(sizeA, translateA);
    b.getMinSize(sizeB, translateB);

    // getMinSize() reports translation = -topLeft and
    // size = ceil(bottomRight) - floor(topLeft) + 1. From these, recover
    // each box and form their union.
    const cv::Point2d tlA(-translateA.x, -translateA.y);
    const cv::Point2d tlB(-translateB.x, -translateB.y);
    const cv::Point2d tl(std::min(tlA.x, tlB.x), std::min(tlA.y, tlB.y));
    const double right = std::max(std::floor(tlA.x) + sizeA.width - 1,
                                  std::floor(tlB.x) + sizeB.width - 1);
    const double bottom = std::max(std::floor(tlA.y) + sizeA.height - 1,
                                   std::floor(tlB.y) + sizeB.height - 1);

    const cv::Size canvasSize(static_cast<int>(right - std::floor(tl.x)) + 1,
                              static_cast<int>(bottom - std::floor(tl.y)) + 1);
    const cv::Point2d shift(-tl.x, -tl.y);

    cv::Mat maskA(canvasSize, CV_8UC1, cv::Scalar::all(0));
    cv::Mat maskB(canvasSize, CV_8UC1, cv::Scalar::all(0));
    a.drawLine(maskA, CV_FILLED, cv::Scalar::all(255), shift);
    b.drawLine(maskB, CV_FILLED, cv::Scalar::all(255), shift);

    cv::Mat overlap;
    cv::min(maskA, maskB, overlap);
    return cv::countNonZero(overlap);
  }

  /**
   * Returns the direction, in radians, of the vector from landmark zero to
   * the landmark farthest from it.
   *
   * The angle comes from atan2 with the image y axis flipped, so it is
   * measured counter-clockwise as seen on screen.
   */
  double computeOrientation(const Asm::Instance& inst) {
    const cv::Mat landmarks(inst.getLandmarks());

    assert(landmarks.rows % 2 == 0 && landmarks.rows > 0 &&
           landmarks.cols == 1 && landmarks.type() == CV_64FC1);
    const cv::Point2d bottomPoint(landmarks.at<double>(0,0),
                                  landmarks.at<double>(1,0));
    cv::Point2d topPoint(bottomPoint);
    double maxDistance = 0;
    for (int i = 2; i < landmarks.rows; i += 2) {
      // y is the next row of the single-column vector (not column 1)
      const cv::Point2d p(landmarks.at<double>(i, 0),
                          landmarks.at<double>(i + 1, 0));
      const double distance = cv::norm(p - bottomPoint);
      if (distance > maxDistance) {
        maxDistance = distance;
        topPoint = p;
      }
    }

    // the y axis must be inverted
    return std::atan2(bottomPoint.y - topPoint.y,
                      topPoint.x - bottomPoint.x);
  }

  /**
   * Draws the instance onto target after adding translation to every
   * landmark.
   *
   * With thickness == CV_FILLED the polygon is filled; otherwise its closed
   * outline is drawn with the given thickness. Landmark coordinates are
   * converted from Point2d to integer Points. An instance without landmarks
   * draws nothing.
   */
  void Asm::Instance::drawLine(cv::Mat& target, int thickness, 
                               const cv::Scalar& colour, 
                               const cv::Point2d& translation) const {
    std::vector<cv::Point> points;
    points.reserve(landmarks.rows / 2);
    for (int i = 0; i < landmarks.rows; i += 2) 
      points.push_back(cv::Point2d(landmarks.at<double>(i,0),
                                   landmarks.at<double>(i+1,0)) + 
                       translation);

    // nothing to draw; also avoids taking &points[0] of an empty vector
    if (points.empty())
      return;

    const cv::Point* p = &points[0];
    int numPoints = static_cast<int>(points.size());

    if (thickness == CV_FILLED)
      cv::fillPoly(target, &p, &numPoints, 1, colour);
    else
      cv::polylines(target, &p, &numPoints, 1, true, colour, thickness);
  }



  // Placeholder instance: a one-landmark, all-zero model, shape and pose.
  const Asm::Instance Asm::EMPTY([]() -> Asm::Instance {
      std::shared_ptr<Pdm> zeroModel(
          new Pdm(1,
                  cv::Mat::zeros(2, 1, CV_64FC1),
                  cv::Mat::zeros(2, 2, CV_64FC1),
                  cv::Mat::zeros(2, 1, CV_64FC1),
                  cv::Point2d(0, 0)));
      const Pdm::PoseParameter zeroPose {0, 0, 0, 0};
      Asm::Instance placeholder(zeroModel,
                                cv::Mat::zeros(2, 1, CV_64FC1),
                                zeroPose,
                                cv::Mat::zeros(2, 1, CV_64FC1));
      return placeholder;
    }());
}
