//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas stman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: MatlabLoader.cc 5 2006-10-26 09:44:54Z ah $

#include "Templates.h"
#include "MatlabLoader.h"
#ifdef WITH_MATLAB
#include <matrix.h>
#include <mat.h>
#include <iostream>
#include <sstream>


MatlabLoader::MatlabLoader( string _fname, string _varname )
{
  // Only remember where to read from; the file is opened later by LoadIt().
  root_node = NULL;
  fname     = _fname;
  varname   = _varname;

  // Bottom-of-stack marker: with NONE on top, the next Start*Container
  // call opens root_node itself.
  last.push_back(NONE);
}

/// Releases the array loaded by LoadIt().
///
/// root_node stays NULL when LoadIt() was never called or failed, so only
/// destroy it when it was actually obtained.
/// NOTE(review): the class owns root_node through a raw pointer; copying a
/// MatlabLoader would destroy the same array twice -- confirm copies are
/// disabled in the header.
MatlabLoader::~MatlabLoader()
{
  if (root_node != NULL)
    mxDestroyArray(root_node);
}

/// Opens the .mat file `fname` and reads the variable `varname` into
/// root_node.
///
/// @throws MatlabException if the file cannot be opened or the variable
///         cannot be read. The file handle is closed in both the success
///         and the "bad variable" paths.
void MatlabLoader::LoadIt(void)
{
  MATFile *fp = matOpen(fname.c_str(), "r");
  if (! fp)
    throw MatlabException("Error in loading .mat file!");

#ifndef __OLD_MATLAB__
  root_node = matGetVariable(fp, varname.c_str());
#else
  root_node = matGetArray(fp, varname.c_str());
#endif

  // Close before checking the result so a failed lookup does not leak the
  // MAT-file handle; the returned array is independent of the open file.
  matClose(fp);

  if (! root_node)
    throw MatlabException("Bad variable name when loading .mat file!");
}

/// Opens an enumerated (cell-array) container and makes it current.
///
/// The array that is opened depends on the kind of the enclosing container:
///  - NONE:   the root array loaded by LoadIt();
///  - ENUMER: the next cell of the enclosing cell array (advances its cursor);
///  - NAMED:  the field `name` of the enclosing struct.
///
/// A fresh element cursor (starting at 0) is pushed for the new container.
///
/// @param name  Field name; only used when the enclosing container is NAMED.
/// @return Always 1.
/// @throws MatlabException if the cell or field cannot be fetched.
int MatlabLoader::StartEnumeratedContainer(string name)
{
  mxArray *temp = NULL;

  switch (last.back()) {
  case NONE:
    open_cont.push_back(root_node);
    break;

  case ENUMER:
    temp = mxGetCell(open_cont.back(), ec_ix.back()++);
    if (! temp)
      throw MatlabException("StartEnumeratedContainer: mxGetCell");

    open_cont.push_back(temp);
    break;

  case NAMED:
    temp = mxGetField(open_cont.back(), 0, name.c_str());
    if (! temp) {
      // Diagnostic dump of the enclosing container before failing. It is
      // written to stderr so library diagnostics do not pollute stdout.
      mxArray *parent = open_cont.back();
      cerr << "Open_cont.back() = " << parent << endl;
      cerr << "IsStruct: " << mxIsStruct(parent) << endl;
      // Named field_name (not fname) to avoid shadowing the member that
      // holds the .mat file name.
      const int nfields = mxGetNumberOfFields(parent);
      for (int i = 0; i < nfields; i++) {
        const char *field_name = mxGetFieldNameByNumber(parent, i);
        if (field_name)
          cerr << "Fieldname: " << field_name << endl;
      }
      throw MatlabException("StartEnumeratedContainer: mxGetField " + name);
    }

    open_cont.push_back(temp);
    break;
  }
  ec_ix.push_back(0);
  last.push_back(ENUMER);
  return 1;
}

/// Opens a named (struct) container and makes it current.
///
/// The array that is opened depends on the kind of the enclosing container:
///  - NONE:   the root array loaded by LoadIt();
///  - ENUMER: the next cell of the enclosing cell array (advances its cursor);
///  - NAMED:  the field `name` of the enclosing struct.
///
/// @param name  Field name; only used when the enclosing container is NAMED.
/// @return Always 1.
/// @throws MatlabException if the cell or field cannot be fetched.
int MatlabLoader::StartNamedContainer(string name)
{
  mxArray *temp = NULL;

  switch (last.back()) {
  case NONE:
    open_cont.push_back(root_node);
    break;

  case ENUMER:
    temp = mxGetCell(open_cont.back(), ec_ix.back()++);
    if (! temp)
      throw MatlabException("StartNamedContainer: mxGetCell");

    open_cont.push_back(temp);
    break;

  case NAMED:
    temp = mxGetField(open_cont.back(), 0, name.c_str());
    if (! temp)
      // Include the field name, matching StartEnumeratedContainer, so the
      // failing lookup can be identified.
      throw MatlabException("StartNamedContainer: mxGetField " + name);

    open_cont.push_back(temp);
    break;
  }
  last.push_back(NAMED);
  return 1;
}

void MatlabLoader::CloseEnumeratedContainer(string name)
{
  // Undo the three pushes made by StartEnumeratedContainer: the open
  // array, its element cursor and its container-kind marker.
  // `name` is not needed to close the container.
  open_cont.pop_back();
  ec_ix.pop_back();
  last.pop_back();
}

void MatlabLoader::CloseNamedContainer(string name)
{
  // Undo the two pushes made by StartNamedContainer (named containers
  // have no element cursor). `name` is not needed to close the container.
  open_cont.pop_back();
  last.pop_back();
}

int MatlabLoader::StartNode(string & type)
{
  // A node is read as a named container; its "type" entry is returned
  // through `type`.
  const int opened = StartNamedContainer("");
  if (opened == 0)
    return 0; // ???

  GetNamedString("type", type);
  return 1;
}

void MatlabLoader::CloseNode(string type)
{
  // StartNode() opens a named container, so closing a node closes it.
  CloseNamedContainer(type);
}

bool MatlabLoader::GetNamedArray(string name, mxArray ** mar)
{
  // Looks up field `name` of the current struct. A missing field is
  // reported through the return value (with *mar set to NULL), not by
  // throwing.
  mxArray *field = mxGetField(open_cont.back(), 0, name.c_str());
  *mar = field;
  return field != NULL;
}

/// Fetches the next element of the current enumerated (cell-array)
/// container and advances its cursor.
///
/// @param mar  Receives the element on success.
/// @return false when all elements have been consumed, true otherwise.
/// @throws MatlabException if the cell cannot be fetched.
bool MatlabLoader::GetArray(mxArray ** mar)
{
  mxArray *cont = open_cont.back();

  // Bound by the total number of cells: mxGetN only counts columns, which
  // would stop after the first element of an M-by-1 cell array.
  if (ec_ix.back() >= mxGetNumberOfElements(cont))
    return 0;

  *mar = mxGetCell(cont, ec_ix.back()++);
  if (! *mar)
    throw MatlabException("GetArray: mxGetCell");

  return 1;
}


/// Converts a MATLAB char array into a string.
///
/// @throws MatlabException if `mar` is NULL or is not convertible
///         (mxArrayToString returns NULL in that case, and constructing a
///         string from NULL is undefined behavior).
string MatlabLoader::ToString(mxArray *mar)
{
  char *temp = mar ? mxArrayToString(mar) : NULL;
  if (! temp)
    throw MatlabException("ToString: mxArrayToString");

  // Free the MATLAB-allocated buffer even if the copy throws.
  string res;
  try {
    res = temp;
  } catch (...) {
    mxFree(temp);
    throw;
  }
  mxFree(temp);

  return res;
}

void MatlabLoader::ToDFlags(mxArray *mar, DFlags & val)
{
  // Each flag lives in a same-named field of the struct `mar`.
  mxArray *field = mxGetField(mar, 0, "mean");
  val.mean = ToBool(field);

  field = mxGetField(mar, 0, "var");
  val.var = ToBool(field);

  field = mxGetField(mar, 0, "ex");
  val.ex = ToBool(field);
}

void MatlabLoader::ToDSSet(mxArray *mar, DSSet & val)
{
  // Each scalar statistic lives in a same-named field of the struct `mar`.
  mxArray *field = mxGetField(mar, 0, "mean");
  val.mean = ToDouble(field);

  field = mxGetField(mar, 0, "var");
  val.var = ToDouble(field);

  field = mxGetField(mar, 0, "ex");
  val.ex = ToDouble(field);
}

/// Copies the elements of a real double MATLAB array into `val`.
///
/// All elements are read in MATLAB's column-major order, so row and column
/// vectors load the same way (mxGetN alone would truncate a column vector
/// to one element).
///
/// @throws MatlabException if `mar` is NULL or non-empty non-double data.
void MatlabLoader::ToDV(mxArray *mar, DV & val)
{
  if (! mar)
    throw MatlabException("ToDV: NULL array");

  const size_t n = mxGetNumberOfElements(mar);
  // mxGetPr reinterprets the data as double, so reject other classes
  // instead of silently loading garbage.
  if (n > 0 && ! mxIsDouble(mar))
    throw MatlabException("ToDV: array is not of class double");

  val.resize(n);

  const double *pr = mxGetPr(mar);
  for (size_t i = 0; i < n; i++)
    val[i] = pr[i];
}

void MatlabLoader::ToDVSet(mxArray *mar, DVSet & val)
{
  // Each vector statistic lives in a same-named field of the struct `mar`.
  mxArray *field = mxGetField(mar, 0, "mean");
  ToDV(field, val.mean);

  field = mxGetField(mar, 0, "var");
  ToDV(field, val.var);

  field = mxGetField(mar, 0, "ex");
  ToDV(field, val.ex);
}

void MatlabLoader::ToDVH(mxArray *mar, DVH & val)
{
  ToDSSet( mxGetField(mar, 0, "scalar"), val.scalar );

  // The "vec" part is optional: when the struct has no such field,
  // val.vec is left untouched.
  const bool has_vec = mxGetFieldNumber(mar, "vec") >= 0;
  if (! has_vec)
    return;

  delete val.vec;  // deleting NULL is a no-op
  val.vec = new DVSet();
  ToDVSet( mxGetField(mar, 0, "vec"), *val.vec );
}

void MatlabLoader::ToDD(mxArray *mar, DD & val)
{
  // The values are stored as a plain vector in the "val" field.
  DV values;
  ToDV( mxGetField(mar, 0, "val"), values );

  val.Resize( values.size() );

  size_t k = 0;
  while (k < val.size()) {
    val.Set(k, values[k]);
    ++k;
  }
}

/// Loads a cell array of DD structs into `val`, one DD per cell.
///
/// The MATLAB C API indexes cells from 0, so cell i maps to val[i]. The
/// previous `i+1` skipped the first cell and read one past the end.
///
/// @throws MatlabException if `mar` or any cell is NULL.
void MatlabLoader::ToVDD(mxArray *mar, VDD & val)
{
  if (! mar)
    throw MatlabException("ToVDD: NULL array");

  // Count every cell, regardless of whether the cell array is a row or a
  // column.
  const size_t size = mxGetNumberOfElements(mar);

  val.Resize(size);

  for ( size_t i = 0; i < size; i++ ) {
    mxArray *cell = mxGetCell(mar, i);
    if (! cell)
      throw MatlabException("ToVDD: mxGetCell");
    ToDD( cell, val[i] );
  }
}

/// Copies the elements of a real double MATLAB array into `val`,
/// truncating each value to int.
///
/// All elements are read in column-major order, consistent with ToDV
/// (mxGetN alone would truncate a column vector to one element).
///
/// @throws MatlabException if `mar` is NULL or non-empty non-double data.
void MatlabLoader::ToIntV(mxArray *mar, IntV & val)
{
  if (! mar)
    throw MatlabException("ToIntV: NULL array");

  const size_t n = mxGetNumberOfElements(mar);
  // mxGetPr reinterprets the data as double, so reject other classes
  // instead of silently loading garbage.
  if (n > 0 && ! mxIsDouble(mar))
    throw MatlabException("ToIntV: array is not of class double");

  val.resize(n);

  const double *p = mxGetPr(mar);
  for (size_t i = 0; i < n; i++) {
    val[i] = static_cast<int>(p[i]);
  }
}

void MatlabLoader::ToLabel(mxArray *mar, Label & val)
{
  string str = ToString(mar);
  size_t i, j;
  int intval;

  if ((i = str.find("(")) == str.npos)
  {
      val = str;
      return;
  }
  
  //cout << str << " -> ";
  while ((j = str.find(",", i)) != str.npos) {
    intval = atoi(str.substr(i+1, j-i-1).c_str());
    ostringstream ss;
    ss << intval - 1;
    str.replace(i+1, j-i-1, ss.str());
    i = str.find(",", i)+1;
  }
  
  if ((j = str.find(")", i)) != str.npos) {
    intval = atoi(str.substr(i+1, j-i-1).c_str());
    ostringstream ss;
    ss << intval - 1;
    str.replace(i+1, j-i-1, ss.str());
  }
  //cout << str << endl;

  val = str;
  return;

}
#endif // WITH_MATLAB
