//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: MatlabSaver.cc 5 2006-10-26 09:44:54Z ah $

#include "MatlabSaver.h"
#include <sstream>
#include <iostream>
#ifdef WITH_MATLAB

#ifdef __OLD_MATLAB__
// Stand-in for mxCreateDoubleScalar, compiled only when __OLD_MATLAB__ is
// defined.  Creates a real 1x1 double matrix and stores val in it.
//
// Returns the new array, or NULL if mxCreateDoubleMatrix returned NULL
// (the value is only written when the allocation succeeded).
mxArray *mxCreateDoubleScalar(const double val)
{
  mxArray *p = mxCreateDoubleMatrix(1, 1, mxREAL);
  if (p)
    *mxGetPr(p) = val;

  return p;
}
#endif

// Remembers the target .mat file name and the variable name to use when
// saving.  Nothing is opened or written here; the constructor only sets up
// the bookkeeping state.
MatlabSaver::MatlabSaver(string filename, string varname)
{
  // No data has been built or written yet.
  root_node = NULL;
  saved = false;

  this->filename = filename;
  this->varname = varname;

  // NONE sits at the bottom of the container-kind stack; the Start*Container
  // methods treat last.back() == NONE as "this container is the root".
  last.push_back(NONE);
}

// Writes the data out if SaveIt() has not completed successfully yet, then
// releases the root array.
//
// SaveIt() reports failures by throwing.  An exception must not leave a
// destructor (it would terminate the program during stack unwinding, and
// under C++11 rules in every case), so any failure is caught here, reported
// on stderr, and the array is still released.
MatlabSaver::~MatlabSaver()
{
  if (!saved) {
    try {
      SaveIt();
    }
    catch (...) {
      std::cerr << "MatlabSaver: failed to save variable '" << varname
                << "' to file '" << filename << "'" << std::endl;
    }
  }
  mxDestroyArray(root_node);
  root_node = NULL;
}


// Writes root_node to the .mat file `filename` under the variable name
// `varname`, and sets `saved` once the file has been closed successfully.
//
// Throws MatlabException if the file cannot be opened, if writing the
// variable fails, or if closing the file fails.  The file handle is closed
// on the write-failure path as well, so it is not leaked.
void MatlabSaver::SaveIt()
{
  MATFile *fp = matOpen(filename.c_str(), "w");
  if (! fp)
    throw MatlabException("Error in opening .mat file for writing!");

#ifndef __OLD_MATLAB__
  const bool write_failed =
    (matPutVariable(fp, varname.c_str(), root_node) != 0);
#else
  // The older API takes the variable name from the array itself.
  mxSetName(root_node, varname.c_str());
  const bool write_failed = (matPutArray(fp, root_node) != 0);
#endif

  if (write_failed) {
    // Best-effort close: the write error is the one that gets reported.
    (void)matClose(fp);
    throw MatlabException("Error in writing to .mat file!");
  }

  if (matClose(fp))
    throw MatlabException("Error in closing .mat file!");

  saved = true;
}


void MatlabSaver::StartEnumeratedContainer(int size, string name)
{
  switch (last.back()) {
  case NONE:
    root_node = mxCreateCellMatrix(1, size);
    if (! root_node)
      throw MatlabException("mxCreateCellMatrix");

    open_enum_containers.push_back(root_node);
    enum_container_indices.push_back(0);
    last.push_back(ENUMER);
    break;
  default:
    mxArray *newnode = NULL;
    newnode = mxCreateCellMatrix(1, size);
    if (! newnode)
      throw MatlabException("mxCreateCellMatrix");

    open_enum_containers.push_back(newnode);
    enum_container_indices.push_back(0);
    last.push_back(ENUMER);
  }
}

// Opens a new named container, represented as a 1x1 struct with no fields.
//
// The new struct is pushed on the stack of open named containers and NAMED
// is recorded as the innermost container kind.  When no container is open
// yet (last.back() == NONE) the new struct also becomes root_node.  The
// `name` argument is not used here.
//
// Throws MatlabException if the struct cannot be created.
void MatlabSaver::StartNamedContainer(string name)
{
  mxArray *record = mxCreateStructMatrix(1, 1, 0, NULL);

  // At top level the result is stored as the root even before it is checked.
  if (last.back() == NONE)
    root_node = record;

  if (record == NULL)
    throw MatlabException("StartNamedContainer: mxCreateStructMatrix");

  open_named_containers.push_back(record);
  last.push_back(NAMED);
}

// Closes the innermost enumerated container and hands it to its parent.
//
// The container is popped from the open stacks, then attached according to
// the kind of the enclosing container:
//   NONE   - it becomes root_node,
//   NAMED  - it is stored as field `name` of the enclosing named container,
//   ENUMER - it is stored in the next free cell of the enclosing container.
//
// Throws MatlabException if there is no open enumerated container; without
// this check an unbalanced call would pop from empty stacks.
void MatlabSaver::CloseEnumeratedContainer(string name)
{
  if (open_enum_containers.empty())
    throw MatlabException("CloseEnumeratedContainer: no open container");

  (void)last.pop_back();
  (void)enum_container_indices.pop_back();
  mxArray *closer = open_enum_containers.back();
  open_enum_containers.pop_back();

  // last.back() now describes the parent of the container just closed.
  switch (last.back()) {
  case NONE:
    root_node = closer;
    break;
  case NAMED:
    SetNamedArray(name, closer);
    break;
  case ENUMER:
    SetArray(closer);
    break;
  }
}

// Closes the innermost named container and hands it to its parent.
//
// The container is popped from the open stack, then attached according to
// the kind of the enclosing container:
//   NONE   - it becomes root_node,
//   NAMED  - it is stored as field `name` of the enclosing named container,
//   ENUMER - it is stored in the next free cell of the enclosing container.
//
// Throws MatlabException if there is no open named container; without this
// check an unbalanced call would pop from empty stacks.
void MatlabSaver::CloseNamedContainer(string name)
{
  if (open_named_containers.empty())
    throw MatlabException("CloseNamedContainer: no open container");

  (void)last.pop_back();
  mxArray *closer = open_named_containers.back();
  open_named_containers.pop_back();

  // last.back() now describes the parent of the container just closed.
  switch (last.back()) {
  case NONE:
    root_node = closer;
    break;
  case NAMED:
    SetNamedArray(name, closer);
    break;
  case ENUMER:
    SetArray(closer);
    break;
  }
}

// Begins saving a node: opens a named container for it and then passes the
// node type to SetNamedString under the key "type" (SetNamedString is
// defined outside this file; presumably it stores a string field).
void MatlabSaver::StartNode(string type)
{
  this->StartNamedContainer(type);
  this->SetNamedString("type", type);
}

// Finishes saving a node by closing the innermost named container; `type`
// is forwarded as the name under which it is attached to its parent.
void MatlabSaver::CloseNode(string type)
{
  this->CloseNamedContainer(type);
}


// Adds a field called `name` to the innermost open named container and
// stores val in that field of the struct's element 0.
//
// Throws MatlabException if mxAddField reports failure (negative result).
void MatlabSaver::SetNamedArray(string name, mxArray *val)
{
  mxArray *target = open_named_containers.back();

  const int field_index = mxAddField(target, name.c_str());
  if (field_index < 0)
    throw MatlabException("mxAddField");

  mxSetFieldByNumber(target, 0, field_index, val);
}

// Stores val in the next unfilled cell of the innermost open enumerated
// container and advances that container's fill index by one.
void MatlabSaver::SetArray(mxArray *val)
{
  mxArray *cells = open_enum_containers.back();

  mxSetCell(cells, enum_container_indices.back(), val);
  ++enum_container_indices.back();
}


// Converts a DFlags value into a 1x1 struct with the fields "mean", "var"
// and "ex", each holding the corresponding member of f as a double scalar.
//
// Returns the new struct; throws MatlabException if it cannot be created.
mxArray *MatlabSaver::DumpDFlags(const DFlags f)
{
  const char* field_names[3] = {"mean", "var", "ex"};

  mxArray *record = mxCreateStructMatrix(1, 1, 3, field_names);
  if (record == NULL)
    throw MatlabException("mxCreateStructMatrix");

  mxArray *mean_flag = mxCreateDoubleScalar(f.mean);
  mxSetField(record, 0, "mean", mean_flag);

  mxArray *var_flag = mxCreateDoubleScalar(f.var);
  mxSetField(record, 0, "var", var_flag);

  mxArray *ex_flag = mxCreateDoubleScalar(f.ex);
  mxSetField(record, 0, "ex", ex_flag);

  return record;
}

// Converts a DSSet value into a 1x1 struct with the fields "mean", "var"
// and "ex", each holding the corresponding member of f as a double scalar.
//
// Returns the new struct; throws MatlabException if it cannot be created.
mxArray *MatlabSaver::DumpDSSet(const DSSet f)
{
  const char* field_names[3] = {"mean", "var", "ex"};

  mxArray *record = mxCreateStructMatrix(1, 1, 3, field_names);
  if (record == NULL)
    throw MatlabException("mxCreateStructMatrix");

  mxArray *mean_val = mxCreateDoubleScalar(f.mean);
  mxSetField(record, 0, "mean", mean_val);

  mxArray *var_val = mxCreateDoubleScalar(f.var);
  mxSetField(record, 0, "var", var_val);

  mxArray *ex_val = mxCreateDoubleScalar(f.ex);
  mxSetField(record, 0, "ex", ex_val);

  return record;
}

// Copies the elements of f into a new real 1 x f.size() double matrix.
//
// Returns the new matrix.  Throws MatlabException if the matrix cannot be
// created, so a failed allocation is reported instead of being dereferenced
// (this matches the check DumpIntV performs).
mxArray *MatlabSaver::DumpDV(const DV f)
{
  mxArray *mat = mxCreateDoubleMatrix(1, f.size(), mxREAL);
  if (! mat)
    throw MatlabException("mxCreateDoubleMatrix");

  double *temp = mxGetPr(mat);

  for (size_t i=0; i<f.size(); i++)
    temp[i] = f[i];

  return mat;
}

// Converts a DVSet value into a 1x1 struct with the fields "mean", "var"
// and "ex", each holding the DumpDV conversion of the matching member of f.
//
// Returns the new struct; throws MatlabException if it cannot be created.
mxArray *MatlabSaver::DumpDVSet(const DVSet f)
{
  const char* field_names[3] = {"mean", "var", "ex"};

  mxArray *record = mxCreateStructMatrix(1, 1, 3, field_names);
  if (record == NULL)
    throw MatlabException("mxCreateStructMatrix");

  mxArray *mean_vec = DumpDV(f.mean);
  mxSetField(record, 0, "mean", mean_vec);

  mxArray *var_vec = DumpDV(f.var);
  mxSetField(record, 0, "var", var_vec);

  mxArray *ex_vec = DumpDV(f.ex);
  mxSetField(record, 0, "ex", ex_vec);

  return record;
}

// Converts a DVH value into a 1x1 struct with two fields: "scalar" holds
// DumpDSSet(f.scalar) and "vec" holds DumpDVSet of the object f.vec points
// to (f.vec is dereferenced without a check).
//
// Returns the new struct; throws MatlabException if it cannot be created.
mxArray *MatlabSaver::DumpDVH(const DVH f)
{
  const char* field_names[2] = {"scalar", "vec"};

  mxArray *record = mxCreateStructMatrix(1, 1, 2, field_names);
  if (record == NULL)
    throw MatlabException("mxCreateStructMatrix");

  mxArray *scalar_part = DumpDSSet(f.scalar);
  mxSetField(record, 0, "scalar", scalar_part);

  mxArray *vector_part = DumpDVSet(*f.vec);
  mxSetField(record, 0, "vec", vector_part);

  return record;
}

// Converts a DD value into a 1x1 struct whose single field "val" holds the
// DumpDV conversion of the vector that f.GetDV() points to.
//
// Returns the new struct; throws MatlabException if it cannot be created.
mxArray *MatlabSaver::DumpDD(const DD f)
{
  const char* field_names[1] = {"val"};

  mxArray *record = mxCreateStructMatrix(1, 1, 1, field_names);
  if (record == NULL)
    throw MatlabException("mxCreateStructMatrix");

  mxArray *values = DumpDV(*(f.GetDV()));
  mxSetField(record, 0, "val", values);

  return record;
}

// Converts a VDD value into a 1 x f.size() cell matrix whose k-th cell
// holds the DumpDD conversion of f[k].
//
// Returns the new cell matrix; throws MatlabException if it cannot be
// created.
mxArray *MatlabSaver::DumpVDD(const VDD f)
{
  mxArray *cells = mxCreateCellMatrix(1, f.size());
  if (cells == NULL)
    throw MatlabException("mxCreateCellMatrix");

  for (size_t k = 0; k < f.size(); ++k) {
    const int slot = (int)k;
    mxSetCell(cells, slot, DumpDD(f[slot]));
  }

  return cells;
}

// Copies the elements of f into a new real 1 x f.size() double matrix,
// converting each element to double on assignment.
//
// Returns the new matrix; throws MatlabException if it cannot be created.
mxArray *MatlabSaver::DumpIntV(const IntV f)
{
  const size_t count = f.size();

  mxArray *row = mxCreateDoubleMatrix(1, count, mxREAL);
  if (row == NULL)
    throw MatlabException("mxCreateDoubleMatrix");

  double *out = mxGetPr(row);
  for (size_t k = 0; k != count; ++k)
    out[k] = f[k];

  return row;
}

mxArray *MatlabSaver::DumpLabel(const Label f)
{
  string str = f;
  size_t i, j;
  int val;

  if ((i = str.find("(")) == str.npos) {
    return mxCreateString(str.c_str());
  }
  
  //cout << str << " -> ";
  while ((j = str.find(",", i)) != str.npos) {
    val = atoi(str.substr(i+1, j-i-1).c_str());
    ostringstream ss;
    ss << val + 1;
    str.replace(i+1, j-i-1, ss.str());
    i = str.find(",", i)+1;
  }
  
  if ((j = str.find(")", i)) != str.npos) {
    val = atoi(str.substr(i+1, j-i-1).c_str());
    ostringstream ss;
    ss << val + 1;
    str.replace(i+1, j-i-1, ss.str());
  }
  //cout << str << endl;

  return mxCreateString(str.c_str());
}
#endif // WITH_MATLAB
