//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: PythonSaver.cc 7 2006-10-26 10:26:41Z ah $

#include "config.h"
#ifdef WITH_PYTHON
#include <Python.h>
#define __PYTHON_H_INCLUDED__

#ifdef WITH_NUMPY
#include "numpy/arrayobject.h"
#endif
#ifdef WITH_NUMERIC
#define NO_IMPORT_ARRAY
#define PY_ARRAY_UNIQUE_SYMBOL __py_numeric_array__
#include "Numeric/arrayobject.h"
#endif
#endif

#include "PythonSaver.h"
#ifdef WITH_PYTHON
#include <iostream>

// Creates an empty saver. `_usearray` is stored in `usearray`.
PythonSaver::PythonSaver(bool _usearray) : usearray(_usearray)
{
  // Nothing has been built or written out yet.
  saved = false;
  root_node = 0;

  // Bottom-of-stack sentinel for the container-kind stack: while NONE is on
  // top there is no enclosing container, so the next container opened by
  // StartEnumeratedContainer/StartNamedContainer becomes root_node.
  last.push_back(NONE);
}

// Destructor: calls SaveIt() unless it has already been called.
PythonSaver::~PythonSaver()
{
  if (saved)
    return;

  SaveIt();

  // NOTE(review): root_node is not released here. If the saver owns that
  // reference it should be Py_XDECREF'd (never `delete`d, it is a PyObject);
  // confirm ownership against the callers before changing this.
}


// Placeholder for writing the collected structure out. It currently only
// prints a notice, then marks the saver as saved so the destructor does not
// call it again.
void PythonSaver::SaveIt()
{
  std::cout << "I would now save the net to a file if I just knew how..."
            << std::endl;
  // TODO: actually persist *root_node, e.g.
  //   save(filename.c_str(), varname.c_str(), *root_node);
  saved = true;
}


void PythonSaver::StartEnumeratedContainer(int size, string name)
{
  switch (last.back()) {
  case NONE:
    root_node = PyList_New(0);
    open_enum_containers.push_back(root_node);
    last.push_back(ENUMER);
    break;
  default:
    open_enum_containers.push_back(PyList_New(0));
    last.push_back(ENUMER);
  }
}

// Opens a new, empty Python dict and makes it the innermost named container.
// `name` is not used here; the key under which a named parent stores the
// dict is the one passed to CloseNamedContainer.
void PythonSaver::StartNamedContainer(string name)
{
  PyObject *fresh = PyDict_New();

  // No enclosing container yet: this dict becomes the top-level object.
  if (last.back() == NONE)
    root_node = fresh;

  open_named_containers.push_back(fresh);
  last.push_back(NAMED);
}

// Closes the innermost enumerated container (a Python list) and attaches it
// to its parent:
//   - parent is a named container (dict): stored under `name`;
//   - parent is an enumerated container (list): appended;
//   - no parent (NONE): the list is root_node and is left as is.
void PythonSaver::CloseEnumeratedContainer(string name)
{
  (void)last.pop_back();
  PyObject *closer = open_enum_containers.back();
  open_enum_containers.pop_back();

  // After the pop, last.back() describes the parent context.
  switch (last.back()) {
  case NONE:
    // Top-level list: it is root_node, so the saver keeps its reference.
    break;
  case NAMED:
    SetNamedArray(name, closer);
    // SetNamedArray uses PyDict_SetItem, which does not steal a reference;
    // the parent now holds its own, so release the one from PyList_New.
    Py_DECREF(closer);
    break;
  case ENUMER:
    SetArray(closer);
    // PyList_Append does not steal either; drop the creation reference.
    Py_DECREF(closer);
    break;
  }
}

// Closes the innermost named container (a Python dict) and attaches it to
// its parent:
//   - parent is a named container (dict): stored under `name`;
//   - parent is an enumerated container (list): appended;
//   - no parent (NONE): the dict is root_node and is left as is.
void PythonSaver::CloseNamedContainer(string name)
{
  (void)last.pop_back();
  PyObject *closer = open_named_containers.back();
  open_named_containers.pop_back();

  // After the pop, last.back() describes the parent context.
  switch (last.back()) {
  case NONE:
    // Top-level dict: it is root_node, so the saver keeps its reference.
    break;
  case NAMED:
    SetNamedArray(name, closer);
    // SetNamedArray uses PyDict_SetItem, which does not steal a reference;
    // the parent now holds its own, so release the one from PyDict_New.
    Py_DECREF(closer);
    break;
  case ENUMER:
    SetArray(closer);
    // PyList_Append does not steal either; drop the creation reference.
    Py_DECREF(closer);
    break;
  }
}

// Begins a node: a named container whose "type" entry records `type`.
void PythonSaver::StartNode(string type)
{
  const char *const type_key = "type";

  this->StartNamedContainer(type);
  this->SetNamedString(type_key, type);
}

// Ends the node opened by StartNode. `type` is forwarded as the key under
// which a named parent stores the node.
void PythonSaver::CloseNode(string type)
{
  this->CloseNamedContainer(type);
}


// Stores `val` under key `name` in the innermost open named container (a
// Python dict). PyDict_SetItem does not steal `val`; the caller keeps its
// own reference.
//
// Throws std::runtime_error if the key cannot be created or the insertion
// fails. The temporary key string is released in every case.
void PythonSaver::SetNamedArray(string name, PyObject *val)
{
  PyObject *cont = open_named_containers.back();
  PyObject *str = PyString_FromString(name.c_str());

  // Never hand a NULL key to PyDict_SetItem.
  if (str == NULL)
    throw std::runtime_error("Trouble: PythonSaver::SetNamedArray");

  const int status = PyDict_SetItem(cont, str, val);
  // Release the key before a possible throw so it cannot leak.
  Py_DECREF(str);

  if (status != 0)
    throw std::runtime_error("Trouble: PythonSaver::SetNamedArray");
}

// Appends `val` to the innermost open enumerated container (a Python list).
// PyList_Append does not steal `val`; the caller keeps its own reference.
// Throws std::runtime_error if the append fails.
void PythonSaver::SetArray(PyObject *val)
{
  if (PyList_Append(open_enum_containers.back(), val) != 0)
    throw std::runtime_error("Trouble: PythonSaver::SetArray");
}


// Converts the DFlags fields mean/var/ex into a new Python dict
// {"mean": int, "var": int, "ex": int}. Returns a new reference.
PyObject *PythonSaver::DumpDFlags(const DFlags f)
{
  PyObject *res = PyDict_New();
  PyObject *item;

  // PyDict_SetItemString does not steal a reference to the value, so each
  // freshly created int is released after insertion (the dict keeps its own).
  item = PyInt_FromLong(f.mean);
  PyDict_SetItemString(res, "mean", item);
  Py_XDECREF(item);

  item = PyInt_FromLong(f.var);
  PyDict_SetItemString(res, "var", item);
  Py_XDECREF(item);

  item = PyInt_FromLong(f.ex);
  PyDict_SetItemString(res, "ex", item);
  Py_XDECREF(item);

  return res;
}

// Converts the DSSet fields mean/var/ex into a new Python dict
// {"mean": float, "var": float, "ex": float}. Returns a new reference.
PyObject *PythonSaver::DumpDSSet(const DSSet f)
{
  PyObject *res = PyDict_New();
  PyObject *item;

  // PyDict_SetItemString does not steal a reference to the value, so each
  // freshly created float is released after insertion to avoid leaking it.
  item = PyFloat_FromDouble(f.mean);
  PyDict_SetItemString(res, "mean", item);
  Py_XDECREF(item);

  item = PyFloat_FromDouble(f.var);
  PyDict_SetItemString(res, "var", item);
  Py_XDECREF(item);

  item = PyFloat_FromDouble(f.ex);
  PyDict_SetItemString(res, "ex", item);
  Py_XDECREF(item);

  return res;
}

#ifdef WITH_NUMERIC
// Numeric build: copies `f` into a new 1-D Numeric array of doubles and
// returns it through PyArray_Return (which takes over the array reference).
PyObject *PythonSaver::DumpDV_array(const DV f)
{
  // PyArray_FromDims takes the shape as a plain int array.
  int shape[1] = { static_cast<int>(f.size()) };
  PyArrayObject *arr =
    reinterpret_cast<PyArrayObject *>(PyArray_FromDims(1, shape, PyArray_DOUBLE));

  // Walk the raw data buffer one stride (in bytes) at a time.
  char *cursor = arr->data;
  const int step = arr->strides[0];
  for (size_t k = 0; k < f.size(); ++k, cursor += step)
    *reinterpret_cast<double *>(cursor) = f[k];

  return PyArray_Return(arr);
}
#else
// Build without Numeric: there is no array type to produce, so the vector
// is returned in its plain Python-list form instead.
PyObject *PythonSaver::DumpDV_array(const DV f)
{
  return this->DumpDV_list(f);
}
#endif

// Converts `f` into a new Python list of floats (new reference).
// PyList_SetItem steals each freshly created float, so nothing leaks here.
PyObject *PythonSaver::DumpDV_list(const DV f)
{
  const size_t count = f.size();
  PyObject *lst = PyList_New(count);

  size_t idx = 0;
  while (idx < count) {
    PyList_SetItem(lst, idx, PyFloat_FromDouble(f[idx]));
    ++idx;
  }

  return lst;
}

// Converts a DVSet into a new Python dict with "mean", "var" and "ex" entries,
// each produced by DumpDV. Returns a new reference.
PyObject *PythonSaver::DumpDVSet(const DVSet f)
{
  PyObject *res = PyDict_New();
  PyObject *item;

  // PyDict_SetItemString does not steal the value; release our reference to
  // each converted vector once the dict holds its own.
  item = DumpDV(f.mean);
  PyDict_SetItemString(res, "mean", item);
  Py_XDECREF(item);

  item = DumpDV(f.var);
  PyDict_SetItemString(res, "var", item);
  Py_XDECREF(item);

  item = DumpDV(f.ex);
  PyDict_SetItemString(res, "ex", item);
  Py_XDECREF(item);

  return res;
}

// Converts a DVH into a new Python dict. "scalar" always holds the
// DumpDSSet form of f.scalar; "vec" is added only when f.vec is non-null.
// Returns a new reference.
PyObject *PythonSaver::DumpDVH(const DVH f)
{
  PyObject *res = PyDict_New();
  PyObject *item;

  // PyDict_SetItemString does not steal the value; release the nested dicts
  // after insertion so they are owned only by `res`.
  item = DumpDSSet(f.scalar);
  PyDict_SetItemString(res, "scalar", item);
  Py_XDECREF(item);

  if (f.vec) {
    item = DumpDVSet(*f.vec);
    PyDict_SetItemString(res, "vec", item);
    Py_XDECREF(item);
  }

  return res;
}

// Converts a DD into a new Python dict {"val": DumpDV(*f.GetDV())}.
// Returns a new reference.
PyObject *PythonSaver::DumpDD(const DD f)
{
  PyObject *res = PyDict_New();

  // PyDict_SetItemString does not steal the value; drop our reference after
  // insertion so the vector is owned only by `res`.
  PyObject *item = DumpDV(*(f.GetDV()));
  PyDict_SetItemString(res, "val", item);
  Py_XDECREF(item);

  return res;
}

// Converts each element of `f` with DumpDD and collects the results in a new
// Python list (new reference). PyList_SetItem steals each element.
PyObject *PythonSaver::DumpVDD(const VDD f)
{
  PyObject *lst = PyList_New(f.size());

  for (size_t pos = 0; pos != f.size(); ++pos)
    PyList_SetItem(lst, pos, DumpDD(f[pos]));

  return lst;
}

// Converts `f` into a new Python list of ints (new reference).
// PyList_SetItem steals each freshly created int.
PyObject *PythonSaver::DumpIntV(const IntV f)
{
  const size_t count = f.size();
  PyObject *lst = PyList_New(count);

  for (size_t pos = 0; pos != count; ++pos) {
    PyObject *num = PyInt_FromLong(f[pos]);
    PyList_SetItem(lst, pos, num);
  }

  return lst;
}

#endif // WITH_PYTHON
