// -*- C++ -*-

//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: Net.h 7 2006-10-26 10:26:41Z ah $

#ifndef NET_H
#define NET_H

#ifdef WITH_PYTHON
#ifndef __PYTHON_H_INCLUDED__
#error "Python.h must be included before this file.  (It must also be included before any system include files.)"
#endif // not __PYTHON_H_INCLUDED__
#endif // WITH_PYTHON

#include "Templates.h"
#include "Saver.h"
#include "Loader.h"
#include "Decay.h"
#include <string>
#include <map>

// Forward declarations: Net only stores and passes these types by pointer,
// so their full definitions are not needed here.
class DecayCounter;
class Node;
class OnLineDelay;
class Proxy;
class Variable;

// Pointer lists used by Net for its nodes and variables.
typedef vector<Node *> NodeVector;
typedef vector<Variable *> VariableVector;

// Iterators over the lists above (forward, and reverse for variables).
typedef NodeVector::iterator NodeIterator;
typedef VariableVector::iterator VariableIterator;
typedef VariableVector::reverse_iterator VariableRIterator;

class Net
{
public:
  Net(size_t ti=1);
  ~Net();

  double Cost();
  void NotifyDeath(Node *ptr) {deadnodes.push_back(ptr);}
  void CleanUp();

  void UpdateAll();
  void UpdateTimeInd();
  void UpdateTimeDep();
  void SaveAllStates();
  void SaveAllSteps();
  void RepeatAllSteps(double alpha);
  void ClearAllStatesAndSteps();
  void StepTime();
  void ResetTime();
  size_t Time() {return t;}

  void Save(NetSaver *saver);
  void SaveToXMLFile(string fname, bool debugsave=false);
  void SaveNodeToXMLFile(string fname, Node *node, bool debugsave=false);
  void SaveToMatFile(string, string, bool debugsave=false);
#ifdef WITH_PYTHON
  PyObject *SaveToPyObject(bool debugsave=false);
#endif  // WITH_PYTHON

#ifndef BUILDING_SWIG_INTERFACE
  Net( NetLoader * loader );
#endif

  void AddVariable(Variable *ptr, Label label);
  Variable *GetVariable(Label label) { return variableindex[label]; }
  Variable *GetVariableByIndex(int i) { return variables[i]; }
  size_t VariableCount() { return variables.size(); }

  void AddNode(Node *ptr, Label label) {
    nodes.push_back(ptr);
    nodeindex[label] = ptr;
  }
  Node *GetNode(Label label) { return nodeindex[label]; }
  Node *GetNodeByIndex(int i) { return nodes[i]; }
  size_t NodeCount() { return nodes.size(); }

  void AddOnLineDelay(OnLineDelay *ptr, Label label) {
    oldelays.push_back(ptr);
    oldelayindex[label] = ptr;
  }
  OnLineDelay *GetOnLineDelay(Label label) { return oldelayindex[label]; }
  OnLineDelay *GetOnLineDelayByIndex(int i) { return oldelays[i]; }
  size_t OnLineDelayCount() { return oldelays.size(); }

  void AddProxy(Proxy *ptr, Label label) {
    proxies.push_back(ptr);
    proxyindex[label] = ptr;
  }
  Proxy *GetProxy(Label label) { return proxyindex[label]; }
  Proxy *GetProxyByIndex(int i) { return proxies[i]; }
  size_t ProxyCount() { return proxies.size(); }
  bool ConnectProxies();
  void SortNodes();
  void CheckStructure();

  Label GetNextLabel(Label label);

  void SetLabelconst(int i) { labelconst = i; }

  double GetOldCost() { return oldcost; }
  void SetOldCost(double c) { oldcost = c; }
  void SetDebugLevel(int l) { debuglevel = l; }
  int GetDebugLevel() { return debuglevel; }
  double Decay();
  void SetDecayCounter(DecayCounter *d);
  DecayCounter *dc;

  bool RegisterDecay(Decayer *d, string hook);
  bool UnregisterDecay(Decayer *d);
  bool UnregisterDecayFromHook(Decayer *d, string hook);
  bool ProcessDecayHook(string hook);
  void SetSumNKeepUpdated(bool keepupdated);


  bool HasVariableGroup(Label group);
  void DefineVariableGroup(Label group);
  size_t NumGroupVariables(Label group);
  Variable *GetGroupVariable(Label group, size_t index);
  void UpdateGroup(Label group);
  double CostGroup(Label group);


  bool HasTimeIndexGroup(Label group);
  void DefineTimeIndexGroup(Label group, IntV &indices);
  void EnableTimeIndexGroup(Label group);
  void DisableTimeIndexGroup(Label group);

private:
  void RemoveVariableFromGroups(Variable *ptr);
  void AddVariableToGroups(Variable *ptr);

  void RemovePtr(Node *ptr);

  size_t t;
  int labelconst;
  int debuglevel;
  map<Label, Node *> nodeindex;
  map<Label, Variable *> variableindex;
  map<Label, OnLineDelay *> oldelayindex;
  map<Label, Proxy *> proxyindex;
  NodeVector nodes, deadnodes;
  VariableVector variables;
  vector<OnLineDelay *> oldelays;
  vector<Proxy *> proxies;
  multimap<string, Decayer *> decay_hooks;
  double oldcost;
  bool sumnkeepupdated;

  map<Label, VariableVector *> variablegroups;

  map<Label, IntV *> timeindexgroups;
  IntV *activetimeindexgroup;
};

// Free-standing factory: presumably builds a Net from the variable named
// varname inside the MAT file fname (the counterpart of
// Net::SaveToMatFile) -- TODO confirm against the implementation.
Net *LoadFromMatFile(string fname, string varname);

#ifdef WITH_PYTHON
// Python-only factory: presumably rebuilds a Net from an object produced
// by Net::SaveToPyObject -- TODO confirm.
Net *CreateNetFromPyObject(PyObject *obj);
#endif  // WITH_PYTHON
 
#endif // NET_H
