#include "HOGDescriptor.hpp"
#include "math.hpp"
#include "util.hpp"

namespace slmotion {
  // Computes the HOG descriptor of m, treating the whole matrix as a single
  // detection window, and caches the layout figures (block counts, cells per
  // block, bins, values per block) that the accessors rely on.
  //
  // Throws std::invalid_argument unless both dimensions of m are multiples
  // of 16.
  HOGDescriptor::HOGDescriptor(const cv::Mat& m) {
    const bool rowsDivisible = (m.rows % 16 == 0);
    const bool colsDivisible = (m.cols % 16 == 0);
    if (!(rowsDivisible && colsDivisible))
      throw std::invalid_argument("Input matrix dimensions must be divisible by 16!");

    // Fixed geometry: 16x16 blocks moved in 8x8 steps, 8x8 cells, 9 bins.
    const cv::Size blockDims(16, 16);
    const cv::Size strideDims(8, 8);
    const cv::Size cellDims(8, 8);
    hog = cv::HOGDescriptor(m.size(), blockDims, strideDims, cellDims, 9);
    hog.compute(m, hogValues);

    assert(hog.getDescriptorSize() == hogValues.size());

    cellSize = hog.cellSize;
    assert(cellSize == cv::Size(8,8));

    // With a one-cell stride and two-cell blocks there is one block fewer
    // than there are cells along each axis.
    const cv::Size window = hog.winSize;
    assert(window == m.size());
    assert(window.width % cellSize.width == 0);
    nBlocksX = window.width / cellSize.width - 1;
    assert(window.height % cellSize.height == 0);
    nBlocksY = window.height / cellSize.height - 1;

    nBlocks = nBlocksX * nBlocksY;

    blockSize = hog.blockSize;
    assert(blockSize.width % cellSize.width == 0);
    assert(blockSize.height % cellSize.height == 0);
    const int cellsAcross = blockSize.width / cellSize.width;
    const int cellsDown = blockSize.height / cellSize.height;
    cellsPerBlock = cellsAcross * cellsDown;

    nBins = hog.nbins;

    assert((int)hog.getDescriptorSize() == nBins * cellsPerBlock * nBlocks);
    valuesPerBlock = nBins * cellsPerBlock;
  }
  


  // Splits the descriptor values of a single block into its cells.
  //
  // vals    - the valuesPerBlock values that belong to one block
  // blockTl - top-left corner of that block
  //
  // Returns cellsPerBlock cells. Each cell carries its size, its top-left
  // and bottom-right corners, and its hog.nbins histogram values. Cells are
  // numbered with x as the outer and y as the inner index, the same
  // convention getBlock uses for blocks.
  std::vector<HOGDescriptor::Cell> HOGDescriptor::constructCells(const std::vector<float>& vals,
                                                                 const cv::Point& blockTl) const {
    assert((int)vals.size() == valuesPerBlock);
    assert(valuesPerBlock == cellsPerBlock * hog.nbins);
    const int cellsX = hog.blockSize.width / hog.cellSize.width;
    const int cellsY = hog.blockSize.height / hog.cellSize.height;
    std::vector<HOGDescriptor::Cell> cells(cellsPerBlock);
    for (int x = 0; x < cellsX; ++x) {
      for (int y = 0; y < cellsY; ++y) {
        // y is the fast index, so one step in x skips a whole column of
        // cellsY cells. (Multiplying by cellsX is only correct for square
        // blocks.)
        const int cellNr = x*cellsY + y;
        Cell& c = cells[cellNr];
        c.size = hog.cellSize;
        c.tl = cv::Point(blockTl.x + x*hog.cellSize.width,
                         blockTl.y + y*hog.cellSize.height);
        c.br = cv::Point(c.tl.x + hog.cellSize.width,
                         c.tl.y + hog.cellSize.height);
        // Each cell owns a contiguous run of nbins values inside the block.
        c.bins = std::vector<float>(vals.begin() + cellNr*hog.nbins,
                                    vals.begin() + (cellNr+1)*hog.nbins);
      }
    }
    return cells;
  }

  

  // Returns the block at block coordinates (x, y): its size, its top-left
  // and bottom-right corners, and its cells.
  //
  // x, y must satisfy 0 <= x < nBlocksX and 0 <= y < nBlocksY (asserted).
  HOGDescriptor::Block HOGDescriptor::getBlock(int x, int y) const {
    assert(x >= 0 && y >= 0 && x < nBlocksX && y < nBlocksY);

    // Blocks are laid out with x as the outer and y as the inner index.
    const int blockIdx = x * nBlocksY + y;
    std::vector<float> blockVals(hogValues.begin() + blockIdx*valuesPerBlock,
                                 hogValues.begin() + (blockIdx+1)*valuesPerBlock);

    const cv::Point tl(x*hog.blockStride.width, y*hog.blockStride.height);
    // Build br from plain values; do not write to br inside its own
    // initializer, as the object does not exist yet at that point.
    const cv::Point br(tl.x + hog.blockSize.width,
                       tl.y + hog.blockSize.height);

    std::vector<Cell> cells = constructCells(blockVals, tl);

    return Block { hog.blockSize, tl, br, cells };
  }



  // Draws one line segment per histogram bin onto m, all centred on p.
  //
  // The half circle [0, PI) is divided evenly among the bins; the segment
  // of a bin points along the centre angle of that bin and extends
  // scaleFactor * (bin value) to either side of p. Segments are drawn with
  // colour cv::Scalar(0,255,0) and thickness 1.
  //
  // grads must contain exactly 9 values (asserted).
  void HOGDescriptor::drawOrientedGradients(cv::Mat& m, const cv::Point2d& p, 
                                            const std::vector<float>& grads,
                                            double scaleFactor) {
    assert(grads.size() == 9);
    const double binWidth = math::PI/grads.size();
    const cv::Scalar colour(0,255,0);
    size_t bin = 0;
    for (float magnitude : grads) {
      const double theta = binWidth/2 + bin*binWidth;
      const cv::Point2d direction(cos(theta), sin(theta));
      const double halfLength = scaleFactor*magnitude;
      cv::line(m, p - halfLength*direction, p + halfLength*direction, colour, 1);
      ++bin;
    }
  }



  // Overlays the descriptor on m: for every cell of every block, the cell's
  // histogram is drawn at the cell centre via drawOrientedGradients with a
  // scale factor of 5.0.
  //
  // Throws std::invalid_argument if m does not have the size the descriptor
  // was computed for.
  void HOGDescriptor::visualise(cv::Mat& m) const {
    const bool sizeMatches = (m.size() == hog.winSize);
    if (!sizeMatches)
      throw std::invalid_argument("Wrong size matrix given as input!");

    for (int bx = 0; bx < nBlocksX; ++bx) {
      for (int by = 0; by < nBlocksY; ++by) {
        const Block block = getBlock(bx, by);
        for (auto cell = block.cells.begin(); cell != block.cells.end(); ++cell) {
          const cv::Point middle = 0.5 * (cell->tl + cell->br);
          drawOrientedGradients(m, middle, cell->bins, 5.0);
        }
      }
    }
  }



  // Describes the i-th descriptor value: which block and cell it belongs
  // to, the corners and centroid of that cell, its bin number and the
  // centre angle of that bin.
  //
  // i is an index into hogValues; no range check is performed.
  HOGDescriptor::ValueIdentifier HOGDescriptor::identify(int i) const {
    ValueIdentifier vi;
    const int blockIdx = i / valuesPerBlock;
    const int offsetInBlock = i % valuesPerBlock;

    // Blocks, and cells within a block, are numbered with x as the outer
    // and y as the inner index.
    vi.blockXIdx = blockIdx / nBlocksY;
    vi.blockYIdx = blockIdx % nBlocksY;
    vi.cellIdx = offsetInBlock / nBins;

    // Derive the cell position from the actual block geometry rather than
    // assuming two cells per column and a stride of one cell.
    const int cellsPerColumn = blockSize.height / cellSize.height;
    const int cellX = vi.cellIdx / cellsPerColumn;
    const int cellY = vi.cellIdx % cellsPerColumn;
    vi.tl.x = vi.blockXIdx*hog.blockStride.width + cellX*cellSize.width;
    vi.tl.y = vi.blockYIdx*hog.blockStride.height + cellY*cellSize.height;
    vi.br.x = vi.tl.x + cellSize.width;
    vi.br.y = vi.tl.y + cellSize.height;

    vi.binNumber = i % nBins;
    // Bins divide [0, PI) evenly; the angle is the centre of the bin.
    const double radsPerBin = math::PI/nBins;
    vi.angle = radsPerBin/2 + vi.binNumber * radsPerBin;
    vi.centroid = 0.5 * (vi.tl + vi.br);
    return vi;
  }



  // Distance between two descriptor values: the norm of the difference of
  // their centroids plus angleWeight times (1 - cos^2) of the difference of
  // their angles.
  double HOGDescriptor::ValueIdentifier::distanceTo(const ValueIdentifier& that, double angleWeight) const {
    const double cosDelta = cos(angle - that.angle);
    const double angularPart = angleWeight * (1.0 - cosDelta*cosDelta);
    return cv::norm(centroid - that.centroid) + angularPart;
  }


  // Shorthand for the nested type, used by the file-local helper below.
  using ValueIdentifier = HOGDescriptor::ValueIdentifier;
  
  // Builds a valsThis.size() x valsThat.size() CV_32FC1 matrix whose entry
  // (r, c) is valsThis[r].distanceTo(valsThat[c], w).
  static cv::Mat generateDistanceMatrixFromValueIdentifiers(double w, 
                                                            const std::vector<ValueIdentifier>& valsThis,
                                                            const std::vector<ValueIdentifier>& valsThat) {
    const int rowCount = valsThis.size();
    const int colCount = valsThat.size();
    cv::Mat distances(rowCount, colCount, CV_32FC1, cv::Scalar::all(0));
    for (int r = 0; r < rowCount; ++r) {
      const ValueIdentifier& from = valsThis[r];
      for (int c = 0; c < colCount; ++c)
        distances.at<float>(r, c) = from.distanceTo(valsThat[c], w);
    }
    return distances;
  }



  // Builds the square CV_32FC1 matrix of pairwise distances between all
  // values of this descriptor; entry (i, j) is
  // identify(i).distanceTo(identify(j), w).
  //
  // w - weight of the angular component of the distance
  cv::Mat HOGDescriptor::generateDistanceMatrix(double w) const {
    const int N = hogValues.size();

    std::vector<ValueIdentifier> valis(N);
    for (int i = 0; i < N; ++i)
      valis[i] = identify(i);

    // Same computation as the two-descriptor overload, with this descriptor
    // on both sides; share the helper instead of repeating the formula.
    return generateDistanceMatrixFromValueIdentifiers(w, valis, valis);
  }



  // Builds the CV_32FC1 matrix of distances between every value of this
  // descriptor (rows) and every value of that descriptor (columns).
  //
  // w - weight of the angular component of the distance
  cv::Mat HOGDescriptor::generateDistanceMatrix(double w, const HOGDescriptor& that) const {
    std::vector<ValueIdentifier> ownIds;
    ownIds.reserve(hogValues.size());
    for (int k = 0, count = hogValues.size(); k < count; ++k)
      ownIds.push_back(identify(k));

    std::vector<ValueIdentifier> otherIds;
    otherIds.reserve(that.hogValues.size());
    for (int k = 0, count = that.hogValues.size(); k < count; ++k)
      otherIds.push_back(that.identify(k));

    return generateDistanceMatrixFromValueIdentifiers(w, ownIds, otherIds);
  }



  // Selects the indices of the largest entries of values: entries are
  // ranked in descending order, each is replaced by the cumulative
  // proportion of the total it and all larger entries account for, and the
  // leading run whose cumulative proportion stays below threshold is kept,
  // capped at maxSize entries. The surviving indices are returned in
  // ascending order. An empty input or a zero total selects nothing.
  static std::vector<int> selectDominantValueIndices(const std::vector<float>& values,
                                                     double threshold,
                                                     size_t maxSize) {
    if (values.empty())
      return std::vector<int>();

    // (value, original index) pairs, largest value first
    std::vector< std::pair<float, size_t> > ranked(values.size());
    for (size_t i = 0; i < values.size(); ++i)
      ranked[i] = std::make_pair(values[i], i);
    std::sort(ranked.rbegin(), ranked.rend());

    float total = 0.0;
    for (size_t i = 0; i < ranked.size(); ++i)
      total += ranked[i].first;

    // Exact comparison is intended: this only guards the divisions below.
    if (total == 0.0f)
      return std::vector<int>();

    // turn the values into cumulative proportions of the total
    ranked[0].first /= total;
    for (size_t i = 1; i < ranked.size(); ++i)
      ranked[i].first = ranked[i].first / total + ranked[i-1].first;

    // length of the leading run that stays below the threshold
    size_t kept = 0;
    while (kept < ranked.size() && ranked[kept].first < threshold)
      ++kept;
    kept = std::min(kept, maxSize);

    std::vector<int> indices(kept);
    for (size_t i = 0; i < kept; ++i)
      indices[i] = ranked[i].second;
    std::sort(indices.begin(), indices.end());
    return indices;
  }



  // Computes a distance matrix restricted to the dominant values of this
  // descriptor (rows) and of that descriptor (columns).
  //
  // w              - weight of the angular distance component, >= 0
  // threshold      - cumulative-proportion cut-off used to select values, >= 0
  // maxSize        - upper bound on the number of values kept per descriptor
  // that           - the descriptor to compare against
  // outDm          - receives the distance matrix between the selected values
  // outThisWeights - receives, as a column, the selected values of this
  // outThatWeights - receives, as a column, the selected values of that
  // indicesThisPtr - if non-null, receives the selected indices of this
  // indicesThatPtr - if non-null, receives the selected indices of that
  //
  // Selected indices are in ascending order, and rows/columns of the outputs
  // follow that order. Throws std::invalid_argument if w or threshold is
  // negative.
  void HOGDescriptor::generateTrimmedDistanceMatrix(double w, double threshold, 
                                                    size_t maxSize,
                                                    const HOGDescriptor& that, 
                                                    cv::Mat& outDm,
                                                    cv::Mat& outThisWeights,
                                                    cv::Mat& outThatWeights,
                                                    std::vector<int>* indicesThisPtr,
                                                    std::vector<int>* indicesThatPtr
                                                    ) const {
    if (w < 0.0)
      throw std::invalid_argument("Angle component weight must be non-negative!");

    if (threshold < 0.0)
      throw std::invalid_argument("The threshold must be non-negative!");

    std::vector<int> indicesThis = selectDominantValueIndices(hogValues, threshold, maxSize);
    std::vector<int> indicesThat = selectDominantValueIndices(that.hogValues, threshold, maxSize);

    // the original (not the cumulative) values of the selected entries
    outThisWeights = cv::Mat(indicesThis.size(), 1, CV_32FC1);
    float* p = outThisWeights.ptr<float>();
    for (size_t i = 0; i < indicesThis.size(); ++i)
      *p++ = hogValues[indicesThis[i]];

    outThatWeights = cv::Mat(indicesThat.size(), 1, CV_32FC1);
    p = outThatWeights.ptr<float>();
    for (size_t i = 0; i < indicesThat.size(); ++i)
      *p++ = that.hogValues[indicesThat[i]];

    std::vector<ValueIdentifier> valisThis(indicesThis.size());
    for (size_t i = 0; i < indicesThis.size(); ++i)
      valisThis[i] = identify(indicesThis[i]);

    std::vector<ValueIdentifier> valisThat(indicesThat.size());
    for (size_t i = 0; i < indicesThat.size(); ++i)
      valisThat[i] = that.identify(indicesThat[i]);

    outDm = generateDistanceMatrixFromValueIdentifiers(w, valisThis, valisThat);

    // the index vectors are no longer needed locally, so hand them over
    if (indicesThisPtr)
      *indicesThisPtr = std::move(indicesThis);
    if (indicesThatPtr)
      *indicesThatPtr = std::move(indicesThat);
  }



  // Returns a copy of this descriptor in which the values of every block
  // have been divided by the sum of that block's values. Blocks whose sum
  // is not positive are left untouched.
  HOGDescriptor HOGDescriptor::getBlockNormalisedCopy() const {
    HOGDescriptor normalised(*this);
    std::vector<float>& vals = normalised.hogValues;

    size_t blockStart = 0;
    for (int block = 0; block < nBlocks; ++block, blockStart += valuesPerBlock) {
      double blockTotal = 0.0;
      for (int k = 0; k < valuesPerBlock; ++k)
        blockTotal += vals[blockStart + k];

      if (blockTotal > 0) {
        for (int k = 0; k < valuesPerBlock; ++k)
          vals[blockStart + k] /= blockTotal;
      }
    }
    // every value must belong to exactly one block
    assert(blockStart == vals.size());
    return normalised;
  }
}
