// This small utility trains an ELM-based skin detector and prints the trained classifier to standard output.

#include "random.hpp"
#include "util.hpp"
#include "ELMClassifier.hpp"
#include "ColourSpace.hpp"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <stack>
#include <string>
#include <vector>

/// Prints command-line usage and a description of every parameter to
/// standard output.
/// @param progName program name shown in the usage line (usually argv[0]).
void printHelp(const std::string& progName) {
  std::cout << "usage:" << std::endl
            << progName << " [-v] [--mode wc|bw] [--sampling-mode uniform|balanced] <frames.txt> <gts.txt> "
            << "<N neurons>" 
            << std::endl
            << std::endl
            << "The first file contains a list of filenames (one per line) "
            << "where each file contains an input training sample image. "
            << "The second file contains a corresponding list of ground truths,"
            << " again one per line. The number of lines in both files should "
            << "match, and every input image should be of equal size with its "
            << "corresponding ground truth image. The third parameter should "
            << "be the number of neurons."
            << std::endl
            << "The mode parameter sets how the ground truth mask should be "
            << "interpreted: in bw mode, black pixels are considered non-skin "
            << "and white pixels are considered skin. In wc mode, white pixels "
            << "are considered non-skin and coloured pixels skin."
            << std::endl
            << "The sampling mode parameter specifies how the input data "
            << "vectors are sampled in case there is more data than can fit in "
            << "the main memory. Uniform sampling treats all samples equally. "
            << "Balanced sampling drops samples in proportion to their "
            << "prevalence, so that the samples chosen for training will be "
            << "equally distributed among classes."
            << std::endl    
            << std::endl
            << "-v enables verbose mode" << std::endl;
}

/// Reads every remaining line of `ifs` (without the trailing newline) and
/// returns them in file order. Empty lines are kept as empty strings.
std::vector<std::string> fillVector(std::ifstream& ifs) {
  std::vector<std::string> lines;
  for (std::string current; std::getline(ifs, current); )
    lines.emplace_back(std::move(current));
  return lines;
}

/// Kinds of command-line parameter the argument parser can expect next.
/// The numeric values are printed in the "unhandled parameter" error of
/// main(), so the order is kept stable.
enum class Parameter : int {
  MODE = 0,                    // value following --mode
  POSITIONAL_FRAMES = 1,       // file listing the training images
  POSITIONAL_GROUND_TRUTH = 2, // file listing the ground-truth masks
  POSITIONAL_NNEURONS = 3,     // number of neurons
  VERBOSE = 4,                 // -v is handled as a plain flag, never stacked
  SAMPLING_MODE = 5            // value following --sampling-mode
};

/// How a ground-truth mask is interpreted.
enum class Mode : int {
  BLACK_WHITE = 0,  // "bw": white = skin, black = non-skin
  WHITE_COLOUR = 1  // "wc": white = non-skin, any other colour = skin
};

/// How training samples are chosen when they exceed the memory budget.
enum class SamplingMode : int {
  UNIFORM = 0,  // "uniform": keep every k-th sample regardless of class
  BALANCED = 1  // "balanced": equal numbers of samples from each class
};

static const cv::Vec3b WHITE_PIXEL(255,255,255);
static const cv::Vec3b BLACK_PIXEL(255,255,255);

int main(int argc, char* argv[]) {
  slmotion::random::seed();

  std::stack<Parameter> parameterStack;
  parameterStack.push(Parameter::POSITIONAL_NNEURONS);
  parameterStack.push(Parameter::POSITIONAL_GROUND_TRUTH);
  parameterStack.push(Parameter::POSITIONAL_FRAMES);
    
  if (argc < 4) {
    printHelp(argv[0]);
    return 0;
  }

  bool verbose = false;

  int argumentNumber = 1;
  std::vector<std::string> frames;
  std::vector<std::string> gts;
  size_t nNeurons = 0;
  Mode mode = Mode::BLACK_WHITE;
  SamplingMode samplingMode = SamplingMode::UNIFORM;
  while (!parameterStack.empty() || argumentNumber < argc) {
    if (argumentNumber >= argc) {
      std::cerr << argv[0] << ": ERROR: too few arguments." << std::endl;
      return 1;
    }

    if (argv[argumentNumber] == std::string("--mode")) { 
      parameterStack.push(Parameter::MODE);
    }
    else if (argv[argumentNumber] == std::string("--sampling-mode")) {
      parameterStack.push(Parameter::SAMPLING_MODE);
    }
    else if (argv[argumentNumber] == std::string("-v")) {
      verbose = true;
    }
    else {
      std::string argument = argv[argumentNumber];
      Parameter p = parameterStack.top();
      if (p == Parameter::POSITIONAL_FRAMES) {
        std::ifstream framesFile(argv[argumentNumber]);
        if (!framesFile.good()) {
          std::cerr << "Could not read " << argv[argumentNumber] << std::endl;
          return -1;
        }
      
        frames = fillVector(framesFile);
      }
      else if (p == Parameter::POSITIONAL_GROUND_TRUTH) {
        std::ifstream gtsFile(argv[argumentNumber]);
        if (!gtsFile.good()) {
          std::cerr << "Could not read " << argv[argumentNumber] << std::endl;
          return -1;
        }
        gts = fillVector(gtsFile);
      
        if (gts.size() != frames.size()) {
          std::cerr << "Mismatch between the number of training images and ground "
                    << "truths!" << std::endl;
          return -1;
        }
      }
      else if (p == Parameter::POSITIONAL_NNEURONS) {
        nNeurons = atoi(argv[argumentNumber]);
      }
      else if (p == Parameter::MODE) {
        if (argument == "wc")
          mode = Mode::WHITE_COLOUR;
        else if (argument == "bw")
          mode = Mode::BLACK_WHITE;
        else {
          std::cerr << argv[0] << ": ERROR: \"" << argument 
                    << "\" is an invalid mode!" << std::endl;
          return 1;
        }
      }
      else if (p == Parameter::SAMPLING_MODE) {
        if (argument == "uniform")
          samplingMode = SamplingMode::UNIFORM;
        else if (argument == "balanced")
          samplingMode = SamplingMode::BALANCED;
        else {
          std::cerr << argv[0] << ": ERROR: \"" << argument 
                    << "\" is an invalid mode!" << std::endl;
          return 1;
        }
      }

      else {
        std::cerr << argv[0] << ": ERROR: \"" << (int)p 
                  << "\" is an unhandled parameter! This is probably a bug." 
                  << std::endl;
        return 1;
      }
      parameterStack.pop();
    }
    ++argumentNumber;
  }




  // read images
  // std::vector<cv::Mat> frameImgs(frames.size());
  // std::vector<cv::Mat> gtImgs(frames.size());
  if (verbose)
    std::cerr << "Counting number of samples..." << std::endl;
  size_t N = 0; // number of samples
  size_t posN = 0; // number of positive samples
  size_t negN = 0; // number of negative samples
  for (size_t i = 0; i < frames.size(); ++i) {
    if (verbose) 
      std::cerr << "\rReading files " << i+1 << "/" << frames.size() << std::flush;
    cv::Mat frameImg = cv::imread(frames[i]);
    cv::Mat gtImg = cv::imread(gts[i]);
    if (frameImg.cols != gtImg.cols ||
        frameImg.rows != gtImg.rows) {
      std::cerr << "Mismatch between image sizes with " << frames[i] << " and " 
                << gts[i] << std::endl;
      return -1;
    }
    N += frameImg.total();
    if (frameImg.type() != CV_8UC3) {
      std::cerr << frames[i] << " is not an RGB image!" << std::endl;
      return -1;
    }
    if (gtImg.type() != CV_8UC3) {
      std::cerr << gts[i] << " is not an RGB image!" << std::endl;
      return -1;
    }

    for (int u = 0; u < gtImg.rows; ++u) {
      for (int v = 0; v < gtImg.cols; ++v) {
        bool pos;
        if (mode == Mode::BLACK_WHITE) {
          pos = gtImg.at<cv::Vec3b>(u,v) == cv::Vec3b(255,255,255);
        }
        else if (mode == Mode::WHITE_COLOUR) {
          pos = gtImg.at<cv::Vec3b>(u,v) != cv::Vec3b(255,255,255);
        }
        else {
          std::cerr << "INVALID MODE!" << std::endl;
          return -1;
        }
        posN += pos ? 1 : 0;
        negN += !pos ? 1 : 0;
      }
    }
  }

  assert(N == posN + negN);

  if (verbose)  {
    std::cerr << std::endl;
    std::cerr << "Total number of samples: " << N << std::endl;
    std::cerr << "Positive sample count: " << posN << std::endl;
    std::cerr << "Negative sample count: " << negN << std::endl;
  }

  size_t MAXSIZE = 1024L*1024L*1024L*4L;
  size_t MAXSAMPLES = MAXSIZE/(sizeof(float)*nNeurons);
  cv::Mat samples, desiredOutputs;

  if (samplingMode == SamplingMode::UNIFORM) {
    size_t skip = 1;
    if (N > MAXSAMPLES) {
      skip = N / MAXSAMPLES;
      skip = skip > 0 ? skip : 1;
      N = N/skip;
      if (verbose)
        std::cerr << "Limiting to " << N << " samples. Taking every " << skip 
                  << "th sample." << std::endl;
    }

    // convert images to samples
    samples = cv::Mat(N, 3, CV_32FC1);
    desiredOutputs = cv::Mat(N, 1, CV_32FC1);
    int sampleNr = 0;
    int includedSamples = 0;
    for (size_t k = 0; k < frames.size(); ++k) {
      if (verbose)
        std::cerr << "\rProcessing sample " << includedSamples + 1 << "/" << N << 
          " from file " << k+1 << "/" << frames.size() << std::flush;
      cv::Mat frameImg = cv::imread(frames[k]);
      cv::Mat gtImg = cv::imread(gts[k]);
      // slmotion::convertColourInPlace(gtImg, CV_BGR2GRAY);
      assert(frameImg.type() == CV_8UC3);
      assert(gtImg.type() == CV_8UC3);
      assert(frameImg.size() == gtImg.size());

      for (int i = 0; i < frameImg.rows; ++i) {
        for (int j = 0; j < frameImg.cols; ++j) {
          if (sampleNr % skip == 0 && includedSamples < static_cast<int>(N)) {
            cv::Vec3b& v = frameImg.at<cv::Vec3b>(i,j); //frameImgs[k].at<cv::Vec3b>(i,j);
            for (int l = 0; l < 3; ++l)
              samples.at<float>(includedSamples, l) = v[l]; 

            if (mode == Mode::BLACK_WHITE) {
              if (gtImg.at<cv::Vec3b>(i,j) == WHITE_PIXEL) 
                desiredOutputs.at<float>(includedSamples,0) = 1.0;
              else if (gtImg.at<cv::Vec3b>(i,j) == BLACK_PIXEL)
                desiredOutputs.at<float>(includedSamples,0) = -1.0;
              else {
                std::cerr << argv[0] << ": ERROR: mode is set to be Black/White but"
                          << " file \"" << gts[k] << "\" contains a pixel with "
                          << "intensity value " << gtImg.at<uchar>(i,j) << " at"
                          << " (" << j << "," << i << ")" << std::endl;
                return 1;
              }
            }
            else if (mode == Mode::WHITE_COLOUR) {
              if (gtImg.at<cv::Vec3b>(i,j) == WHITE_PIXEL) 
                desiredOutputs.at<float>(includedSamples,0) = -1.0;
              else 
                desiredOutputs.at<float>(includedSamples,0) = 1.0;
            }
            ++includedSamples;
          }
          ++sampleNr;
        }
      }
    }
    if (verbose) {
      std::cerr << std::endl;
      std::cerr << "Processed " << includedSamples << " samples." << std::endl;
    }

    if (skip == 1)
      assert(sampleNr == static_cast<int>(N));
    assert(includedSamples == static_cast<int>(N));
  }
  else if (samplingMode == SamplingMode::BALANCED) {
    cv::Mat positiveSamples(posN, 3, CV_32FC1);
    cv::Mat negativeSamples(negN, 1, CV_32FC1);
    if (verbose) {
      std::cerr << "Reading all samples..." << std::endl;
    }

    int sampleNr = 0;
    int posSampleNr = 0;
    int negSampleNr = 0;
    for (size_t k = 0; k < frames.size(); ++k) {
      if (verbose)
        std::cerr << "\rProcessing sample " << sampleNr + 1 << "/" << N << 
          " from file " << k+1 << "/" << frames.size() << std::flush;
      cv::Mat frameImg = cv::imread(frames[k]);
      cv::Mat gtImg = cv::imread(gts[k]);
      //slmotion::convertColourInPlace(gtImg, CV_BGR2GRAY);
      assert(frameImg.type() == CV_8UC3);
      assert(gtImg.type() == CV_8UC3);
      assert(frameImg.size() == gtImg.size());

      for (int i = 0; i < frameImg.rows; ++i) {
        for (int j = 0; j < frameImg.cols; ++j) {
          cv::Vec3b& v = frameImg.at<cv::Vec3b>(i,j);
          bool pos;
          if (mode == Mode::BLACK_WHITE) 
            pos = gtImg.at<cv::Vec3b>(i,j) == WHITE_PIXEL;
          else if (mode == Mode::WHITE_COLOUR) 
            pos = gtImg.at<cv::Vec3b>(i,j) != WHITE_PIXEL;
          else {
            std::cerr << "INVALID MODE!" << std::endl;
            return -1;
          }

          if (pos) {
            for (int l = 0; l < 3; ++l) 
              positiveSamples.at<float>(posSampleNr, l) = v[l];
            ++posSampleNr;
          }
          else {
            for (int l = 0; l < 3; ++l) 
              negativeSamples.at<float>(negSampleNr, l) = v[l];
            ++negSampleNr;
          }

          ++sampleNr;
        }
      }
    }    
    assert(posSampleNr + negSampleNr == sampleNr);
    assert(negSampleNr == negN);
    assert(posSampleNr == posN);

    if (verbose) 
      std::cerr << std::endl
                << "Processed " << sampleNr << " samples (" << posSampleNr 
                << " positive, " << negSampleNr << " negative)" << std::endl;
    
    // shuffle
    posN = std::min(MAXSAMPLES/2, posN);
    negN = std::min(MAXSAMPLES/2, negN);
    if (posN != negN)
      negN = posN = std::min(posN, negN);

    if (verbose) 
      std::cerr << "Shuffling samples..." << std::endl
                << "Creating index vectors... " << std::flush;
          
    std::vector<int> posIndices(posSampleNr);
    for (size_t i = 0; i < posIndices.size(); ++i) 
      posIndices[i] = i;
    std::vector<int> negIndices(negSampleNr);
    for (size_t i = 0; i < negIndices.size(); ++i) 
      negIndices[i] = i;


    if (verbose)
      std::cerr << "Done." << std::endl
                << "Shuffling indices... " << std::flush;

    std::random_shuffle(posIndices.begin(), posIndices.end());    
    std::random_shuffle(negIndices.begin(), negIndices.end());    
    if (verbose)
      std::cerr << "Done." << std::endl
                << "Taking " << posN << " positive samples and "
                << negN << " negative samples... ";

    assert(posN == negN);
    samples = cv::Mat(posN + negN, 3, CV_32FC1);
    for (size_t i = 0; i < posN; ++i) {
      assert(i < samples.rows);
      assert(i < posIndices.size());
      assert(posIndices[i] < positiveSamples.rows);
      for (int k = 0; k < 3; ++k) {
        samples.at<float>(i, k) = positiveSamples.at<float>(posIndices[i],k);
      }
    }

    for (size_t i = 0; i < negN; ++i) {
      assert(posN + i < samples.rows);
      assert(i < samples.rows);
      assert(i < negIndices.size());
      assert(negIndices[i] < negativeSamples.rows);
      for (int k = 0; k < 3; ++k) {
        samples.at<float>(posN+i, k) = negativeSamples.at<float>(negIndices[i],k);
      }
    }

    desiredOutputs = cv::Mat(posN + negN, 1, CV_32FC1);
    for (int i = 0; i < posN; ++i) {
      assert(i < desiredOutputs.rows);
      desiredOutputs.at<float>(i,0) = 1.0;
    }
    // desiredOutputs.rowRange(0,posN) = 1.0;
    for (int i = posN; i < posN+negN; ++i) {
      assert(i < desiredOutputs.rows);
      desiredOutputs.at<float>(i,0) = -1.0;
    }
    // desiredOutputs.rowRange(posN,posN+negN) = -1.0;
    if (verbose)
      std::cerr << "Done." << std::endl;
  }
  else {
    std::cerr << "INVALID SAMPLING MODE" << std::endl;
    return -1;
  }
  

  slmotion::CvElmClassifier elm;
  if (verbose)
    std::cerr << "Training elm..." << std::endl;
  elm.train(samples, desiredOutputs, nNeurons, &tanhf);

  std::cout << elm.toString() << std::endl;

  return 0;
}
