//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: PythonLoader.cc 7 2006-10-26 10:26:41Z ah $

#include "config.h"
#ifdef WITH_PYTHON
#include <Python.h>
#define __PYTHON_H_INCLUDED__

#ifdef WITH_NUMPY
#include "numpy/arrayobject.h"
#endif
#ifdef WITH_NUMERIC
#define NO_IMPORT_ARRAY
#define PY_ARRAY_UNIQUE_SYMBOL __py_numeric_array__
#include "Numeric/arrayobject.h"
#endif
#endif

#include "Templates.h"
#include "PythonLoader.h"
#ifdef WITH_PYTHON
#include <iostream>


// Construct a loader that reads from the Python object tree rooted at `root`.
// No reference count is taken on `root` here (no Py_INCREF), so presumably
// the caller must keep it alive for as long as the loader is used.
PythonLoader::PythonLoader(PyObject *root)
{
  // Seed the context stack with NONE ("no container open yet"); this makes
  // the first Start*Container call bind to the root object itself.
  last.push_back(NONE);
  root_node = root;
}

// Destructor. Deliberately empty: root_node is neither deleted nor
// Py_DECREF'd, mirroring the constructor, which took no reference on it.
PythonLoader::~PythonLoader()
{
}

// No-op for this loader: the object tree is supplied ready-made through the
// constructor, so nothing is read here. (A file-based load call used to be
// sketched in this body but was disabled.)
void PythonLoader::LoadIt()
{
}

// Open an enumerated (list-like) container and make it the current one.
//
// Where the container comes from depends on the enclosing context:
//   NONE   - top level: the root object itself;
//   ENUMER - the next unread element of the enclosing list (GetArray);
//   NAMED  - the entry called `name` in the enclosing dict (GetNamedArray).
// Returns 1 on success, or 0 if the requested element does not exist, in
// which case none of the context stacks are modified.
int PythonLoader::StartEnumeratedContainer(string name)
{
  switch (last.back()) {
  case NONE:
    open_cont.push_back(root_node);
    break;

  case ENUMER:
  case NAMED: {
    PyObject *child = (last.back() == ENUMER) ? GetArray()
                                              : GetNamedArray(name);
    if (child == 0)
      return 0;
    open_cont.push_back(child);
    break;
  }
  }

  // Start a fresh read cursor for the elements of the new list.
  ec_ix.push_back(0);
  last.push_back(ENUMER);
  return 1;
}

// Open a named (dict-like) container and make it the current one.
//
// The container is taken from the enclosing context exactly as in
// StartEnumeratedContainer: the root object at top level, the next list
// element inside an enumerated container, or the entry `name` inside a
// named one. Unlike enumerated containers, no read cursor is pushed.
// Returns 1 on success, or 0 if the requested element does not exist
// (no state is modified in that case).
int PythonLoader::StartNamedContainer(string name)
{
  switch (last.back()) {
  case NONE:
    // Not normally expected for a named container; bind to the root anyway.
    open_cont.push_back(root_node);
    break;

  case ENUMER:
  case NAMED: {
    PyObject *child = (last.back() == ENUMER) ? GetArray()
                                              : GetNamedArray(name);
    if (child == 0)
      return 0;
    open_cont.push_back(child);
    break;
  }
  }

  last.push_back(NAMED);
  return 1;
}

// Close the innermost enumerated container. This drops its container
// pointer, its read cursor and its context marker. The container pointer
// came from the root or from PyList_GetItem/PyDict_GetItem (borrowed
// references), so nothing is freed or DECREF'd. `name` is unused.
void PythonLoader::CloseEnumeratedContainer(string name)
{
  open_cont.pop_back();
  ec_ix.pop_back();
  last.pop_back();
}

// Close the innermost named container. This drops its container pointer
// and its context marker. The pointer is a borrowed reference, so nothing
// is freed or DECREF'd. `name` is unused.
void PythonLoader::CloseNamedContainer(string name)
{
  open_cont.pop_back();
  last.pop_back();
}

// Begin reading a node. A node is stored as a named container; its "type"
// entry is read into `type`. Returns 1 on success, or 0 if the container
// could not be opened (in which case `type` is not touched).
int PythonLoader::StartNode(string & type)
{
  const int opened = StartNamedContainer("");
  if (opened)
    GetNamedString("type", type);
  return opened ? 1 : 0;
}

// Finish reading a node by closing the named container that StartNode opened.
void PythonLoader::CloseNode(string type)
{
  this->CloseNamedContainer(type);
}

// Look up the entry `name` in the innermost open container (expected to be
// a dict). Returns a borrowed reference, or 0 if the key is absent.
//
// PyDict_GetItemString creates and releases the temporary key object
// itself and returns NULL if that allocation fails. Building the key by
// hand with PyString_FromString and then passing it unchecked to
// PyDict_GetItem/Py_DECREF would crash on such a failure. This also
// matches the lookup used everywhere else in this loader.
PyObject *PythonLoader::GetNamedArray(string name)
{
  return PyDict_GetItemString(open_cont.back(), name.c_str());
}

// Return the next unread element of the innermost open list and advance
// that list's read cursor. Returns 0 once the list is exhausted. It also
// returns 0 when the container is not a list, because PyList_Size then
// reports -1. The element is a borrowed reference.
PyObject *PythonLoader::GetArray()
{
  PyObject *list = open_cont.back();

  if (ec_ix.back() < PyList_Size(list))
    return PyList_GetItem(list, ec_ix.back()++);

  return 0;
}

// Convert a Python string object to a std::string.
//
// Returns an empty string in two cases:
//  - `obj` is NULL, e.g. a missing entry returned by PyDict_GetItemString;
//  - `obj` is not a string object. PyString_AsString then returns NULL and
//    sets a Python exception, which is cleared here so it does not leak into
//    later API calls.
// Constructing a std::string from a NULL pointer would be undefined
// behaviour, which is why both cases are handled explicitly.
string PythonLoader::ToString(PyObject *obj)
{
  if (!obj)
    return string();

  const char *s = PyString_AsString(obj);
  if (!s) {
    PyErr_Clear();
    return string();
  }
  return string(s);
}

// Fill a DFlags from a dict whose "mean", "var" and "ex" entries are
// decoded with ToBool. Entries are looked up and converted one at a time,
// in that order.
void PythonLoader::ToDFlags(PyObject *obj, DFlags &val)
{
  PyObject *entry;

  entry = PyDict_GetItemString(obj, "mean");
  val.mean = ToBool(entry);

  entry = PyDict_GetItemString(obj, "var");
  val.var = ToBool(entry);

  entry = PyDict_GetItemString(obj, "ex");
  val.ex = ToBool(entry);
}

// Fill a DSSet from a dict whose "mean", "var" and "ex" entries are
// decoded with ToDouble. Entries are looked up and converted one at a time,
// in that order.
void PythonLoader::ToDSSet(PyObject *obj, DSSet &val)
{
  PyObject *entry;

  entry = PyDict_GetItemString(obj, "mean");
  val.mean = ToDouble(entry);

  entry = PyDict_GetItemString(obj, "var");
  val.var = ToDouble(entry);

  entry = PyDict_GetItemString(obj, "ex");
  val.ex = ToDouble(entry);
}

// Decode a vector of doubles. When built with Numeric support, Numeric
// arrays are read directly; everything else goes through the list decoder.
void PythonLoader::ToDV(PyObject *obj, DV & val)
{
#ifdef WITH_NUMERIC
  if (PyArray_Check(obj))
    ToDV_fromArray(obj, val);
  else
#endif
    ToDV_fromList(obj, val);
}

#ifdef WITH_NUMERIC
// Fill `val` from a Numeric array, reading along its first dimension.
//
// Assumes, without checking, that the array is at least 1-D and holds C
// doubles (NOTE(review): confirm callers only pass Float64 arrays). The
// stride is honoured, so non-contiguous views are read element by element.
//
// The byte offset is computed in signed arithmetic because array strides can
// be negative (e.g. reversed views). Multiplying a negative stride by an
// unsigned size_t index wraps around, and the pointer arithmetic is then
// undefined.
void PythonLoader::ToDV_fromArray(PyObject *obj, DV & val)
{
  PyArrayObject *o = (PyArrayObject *)obj;
  const long n = o->dimensions[0];
  const long stride = o->strides[0];

  val.resize(n);

  for (long i = 0; i < n; i++)
    val[i] = *(double *)(o->data + i * stride);
}
#else
// Fallback compiled when Numeric support is disabled: any attempt to decode
// a Numeric array is reported by throwing a TypeException.
void PythonLoader::ToDV_fromArray(PyObject *obj, DV & val)
{
  (void)obj;
  (void)val;
  throw TypeException("Saved net contains Numeric arrays but library compiled without Numeric.");
}
#endif

// Fill `val` from a Python list, converting each element with ToDouble.
// For a non-list object PyList_Size reports -1, and the result is an empty
// vector.
void PythonLoader::ToDV_fromList(PyObject *obj, DV & val)
{
  const Py_ssize_t len = PyList_Size(obj);
  val.resize(len > 0 ? len : 0);

  for (Py_ssize_t k = 0; k < len; ++k)
    val[k] = ToDouble(PyList_GetItem(obj, k));
}

// Fill a DVSet from a dict whose "mean", "var" and "ex" entries are each
// decoded with ToDV, in that order.
void PythonLoader::ToDVSet(PyObject *obj, DVSet & val)
{
  static const char *const keys[3] = { "mean", "var", "ex" };
  DV *const targets[3] = { &val.mean, &val.var, &val.ex };

  for (int k = 0; k < 3; ++k)
    ToDV(PyDict_GetItemString(obj, keys[k]), *targets[k]);
}

// Fill a DVH from a dict. The "scalar" entry is always decoded into
// val.scalar. The "vec" entry is optional. When it is present, any existing
// val.vec is deleted and replaced by a newly allocated DVSet decoded from
// it. When it is absent, val.vec is left untouched.
void PythonLoader::ToDVH(PyObject *obj, DVH & val)
{
  ToDSSet(PyDict_GetItemString(obj, "scalar"), val.scalar);

  PyObject *vec_obj = PyDict_GetItemString(obj, "vec");
  if (!vec_obj)
    return;

  delete val.vec;   // deleting a null pointer is a no-op
  val.vec = new DVSet();
  ToDVSet(vec_obj, *val.vec);
}

// Fill a DD from a dict whose "val" entry holds the values. They are
// decoded with ToDV into a temporary vector and then copied element-wise.
void PythonLoader::ToDD(PyObject *obj, DD & val)
{
  DV values;
  ToDV(PyDict_GetItemString(obj, "val"), values);

  val.Resize(values.size());
  for (size_t k = 0; k < val.size(); ++k)
    val[k] = values[k];
}

// Fill a VDD from a Python list, decoding each element with ToDD. For a
// non-list object PyList_Size reports -1, and the result is an empty VDD.
void PythonLoader::ToVDD(PyObject *obj, VDD & val)
{
  const Py_ssize_t count = PyList_Size(obj);
  val.Resize(count < 0 ? 0 : count);

  for (size_t k = 0; k < val.size(); ++k)
    ToDD(PyList_GetItem(obj, k), val[k]);
}

// Fill an IntV from a Python list, converting each element with ToInt.
// For a non-list object PyList_Size reports -1, and the result is an empty
// vector.
//
// `val` is resized to the list length before reading, as ToDV_fromList
// does. Without this, the loop would run over whatever size `val` had on
// entry. It would either request indices past the end of the list
// (PyList_GetItem returns NULL and sets IndexError) or leave stale or
// missing elements.
void PythonLoader::ToIntV(PyObject *obj, IntV & val)
{
  Py_ssize_t size = PyList_Size(obj);

  if (size < 0)
    size = 0;

  val.resize(size);

  for (size_t i = 0; i < val.size(); i++)
    val[i] = ToInt( PyList_GetItem(obj, i) );
}

#endif // WITH_PYTHON
