//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: DecayCounter.cc 7 2006-10-26 10:26:41Z ah $

#include "config.h"
#ifdef WITH_PYTHON
#include <Python.h>
#define __PYTHON_H_INCLUDED__
#endif

#include "Templates.h"
#include "DecayCounter.h"
#include <iostream>

/*
 * Create a TraditionalDecayCounter and restore its state from `loader`.
 * The returned object is heap-allocated and not stored anywhere here, so
 * the caller becomes responsible for deleting it.
 */
DecayCounter *DecayCounter::GlobalLoader(NetLoader *loader)
{
  TraditionalDecayCounter *counter = new TraditionalDecayCounter();
  DecayCounter *dc = counter;
  try {
    dc->Load(loader);
  } catch (...) {
    // Loading failed part-way: release the counter instead of leaking it.
    // Delete through the concrete type so this does not depend on the base
    // class having a virtual destructor.
    delete counter;
    throw;
  }
  return dc;
}


/*
 * Initialise the counter: the sample count starts at 1, while the running
 * sum and the most recently computed decay start at 0.  `r` becomes the
 * ratio used in StepTime()'s decay formula.
 */
TraditionalDecayCounter::TraditionalDecayCounter(double r)
{
  ratio = r;
  samples = 1;
  decay = 0;
  cumsum = 0;
}

/*
 * Advance the counter by one step and return the new decay factor.
 *
 * Both `samples` and `cumsum` are incremented.  The raw factor is then
 * (samples * ratio - 1) / cumsum.  If that is below 0.99 it is replaced by
 * exp(factor - 1).  Because exp(0.99 - 1) = exp(-0.01) ~= 0.990, the
 * replacement is nearly continuous at the threshold, and it is always
 * positive.  `cumsum` is then scaled by the factor, and the factor is
 * stored in `decay` and returned.
 */
double TraditionalDecayCounter::StepTime()
{
  ++samples;
  cumsum += 1;

  double factor = (samples * ratio - 1) / cumsum;
  if (factor < 0.99)
    factor = exp(factor - 1);

  cumsum *= factor;
  return decay = factor;
}

/*
 * Write the complete counter state to `saver` as named doubles.
 * TraditionalDecayCounter::Load reads back the same four names.
 */
void TraditionalDecayCounter::Save(NetSaver *saver)
{
  NetSaver &out = *saver;

  out.SetNamedDouble("samples", samples);
  out.SetNamedDouble("cumsum", cumsum);
  out.SetNamedDouble("ratio", ratio);
  out.SetNamedDouble("decay", decay);
}

/*
 * Restore the complete counter state from `loader`.  This reads the same
 * four named doubles that TraditionalDecayCounter::Save writes.
 */
void TraditionalDecayCounter::Load(NetLoader *loader)
{
  NetLoader &in = *loader;

  in.GetNamedDouble("samples", samples);
  in.GetNamedDouble("cumsum", cumsum);
  in.GetNamedDouble("ratio", ratio);
  in.GetNamedDouble("decay", decay);
}

#ifdef WITH_PYTHON
/*
 * Wrap a Python callable that StepTime() will invoke with no arguments.
 * Throws TypeException if `stepfunc` is not callable.  On success a
 * reference to `stepfunc` is held until the destructor releases it.
 */
PythonDecayCounter::PythonDecayCounter(PyObject *stepfunc)
{
  // Validate before taking a reference, so nothing needs undoing on throw.
  if (PyCallable_Check(stepfunc) == 0)
    throw TypeException("DecayCounter function not callable");

  Py_INCREF(stepfunc);
  step_function = stepfunc;
}

/*
 * Release the reference to the step function taken in the constructor.
 * NOTE(review): if this class is copyable (its declaration is not visible
 * here), a copy would release the same reference twice -- confirm that
 * copying is disabled.
 */
PythonDecayCounter::~PythonDecayCounter()
{
  PyObject *const held = step_function;
  Py_DECREF(held);
}

/*
 * Call the wrapped Python step function with no arguments.  Its float
 * result is stored in `decay` and returned.
 *
 * Throws TypeException in two cases:
 *   - the Python call itself failed (raised an exception);
 *   - the call returned something other than a float.
 * In the first case the Python error indicator is not cleared here.
 */
double PythonDecayCounter::StepTime()
{
  // New reference on success, NULL if the Python function raised.
  PyObject *val = PyObject_CallObject(step_function, 0);
  if (!val)
    throw TypeException("DecayCounter function call failed");

  if (!PyFloat_Check(val)) {
    // Drop our reference to the result before throwing so it is not leaked.
    Py_DECREF(val);
    throw TypeException("DecayCounter return type not float");
  }

  decay = PyFloat_AsDouble(val);
  Py_DECREF(val);

  return decay;
}

void PythonDecayCounter::Save(NetSaver *saver)
{
}

void PythonDecayCounter::Load(NetLoader *loader)
{
}
#endif  // WITH_PYTHON
