// -*- C++ -*-

//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas stman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: Node.h 7 2006-10-26 10:26:41Z ah $

#ifndef NODE_H
#define NODE_H

#include <string>
#include <map>
#include "Templates.h"
#include "Saver.h"
#include "Loader.h"
#include "Decay.h"
#include "Net.h"

// Forward declaration: NodeBase (below) refers to Node before it is defined.
class Node;

// BOOLASOBJ is a plain bool in the C++ build. When BUILDING_SWIG_INTERFACE is
// defined this typedef is skipped; presumably the SWIG interface supplies its
// own mapping for BOOLASOBJ (it is used by the SWIG-only GetDiscrete
// declarations) — TODO confirm in the .i file.
#ifndef BUILDING_SWIG_INTERFACE
typedef bool BOOLASOBJ;
#endif

// Kinds of parent value a node can require; passed to Node::CheckParent().
// NOTE(review): the suffixes appear to name the statistics that must be
// available (M = mean, V = var, E = ex, matching the DSSet/DFlags fields),
// and the *V variants presumably denote the vector-valued (GetRealV /
// GetDiscreteV) counterparts — confirm against CheckParent().
enum partype_e {
  REAL_MV = 0,
  REAL_ME = 1,
  REAL_M = 2,
  REALV_MV = 3,
  REALV_ME = 4,
  REALV_M = 5,
  DISCRETE = 6,
  DISCRETEV = 7
};

// Abstract interface for a node's parent list. Node inherits it virtually,
// and the node classes below each mix in one implementation of it:
// NullParNode (no parents), UniParNode (a single parent pointer),
// BiParNode (two parent slots) or NParNode (a growable vector).
class NodeBase
{
public:
  virtual ~NodeBase() { }

  // Position of ptr among this node's parents, or -1 if it is not a parent
  // (this is what the NullParNode / UniParNode implementations return).
  virtual int ParIdentity(const Node *ptr) = 0;
  // Number of parents currently stored.
  virtual size_t NumParents() = 0;
  // The i-th parent; the implementations in this header return 0 for an
  // out-of-range index.
  virtual Node *GetParent(size_t i) = 0;
  // Removes ptr from the parent list. NOTE(review): the meaning of the int
  // result is defined out of view — confirm before relying on it.
  virtual int RemoveParent(const Node *ptr) = 0;

protected:
  // Stores ptr in the parent list.
  virtual void ReallyAddParent(Node *ptr) = 0;
  // Replaces oldptr by newptr in the parent list and returns a bool; the
  // exact meaning of the result differs per implementation (see UniParNode).
  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr) = 0;
};

// Common base class of every node in a network. A Node records its owning
// Net, its Label, its children and some bookkeeping flags (persist,
// timetype, dying); parent storage is delegated to whichever NodeBase
// mix-in the concrete subclass inherits. The Get*/Grad* virtuals below are
// defaults meaning "this node does not provide that kind of value".
class Node : public virtual NodeBase
{
public:
  friend Net::Net(NetLoader *loader);

  virtual ~Node() { }

  void NotifyDeath(Node *ptr, int verbose = 0);
  virtual void NotifyTimeType(int tt, int verbose = 0);
  void ReplacePtr(Node *oldptr, Node *newptr);

  // Appends ptr to the child list (no duplicate check is made).
  void AddChild(Node *ptr)
  {
    children.push_back(ptr);
  }
protected:
  Node(Net *ptr, Label label);
#ifndef BUILDING_SWIG_INTERFACE
  Node(Net *ptr, NetLoader *loader, bool isproxy = 0);
#endif

  void AddParent(Node *ptr, bool really=true);

public:
  // Default scalar provider: the request cannot be satisfied.
  virtual bool GetReal(DSSet &val, DFlags req) { return false; }
  virtual void GradReal(DSSet &val, const Node *ptr) {}

  // Default vector provider: no vector (vec = 0); the scalar statistics are
  // filled in by GetReal() and its result is returned.
  virtual bool GetRealV(DVH &val, DFlags req)
  {
    val.vec = 0;
    return GetReal(val.scalar, req);
  }
  virtual void GradRealV(DVSet &val, const Node *ptr) {}
#ifdef BUILDING_SWIG_INTERFACE
  virtual BOOLASOBJ GetDiscrete(DD *&val) { return false; }
#else
  virtual bool GetDiscrete(DD *&val) { return false; }
#endif
  virtual void GradDiscrete(DD &val, const Node *ptr) {}

  // Default vector discrete provider: falls back to the scalar GetDiscrete().
  virtual bool GetDiscreteV(VDDH &val)
  {
    val.vec = 0;
    return GetDiscrete(val.scalar);
  }
  virtual void GradDiscreteV(VDD &val, const Node *ptr) {}

  // Invalidation hook; the default ignores ptr and propagates to children.
  virtual void Outdate(const Node *ptr)
  {
    OutdateChild();
  }
  void CheckParent(size_t parnum, partype_e partype);

  // Shorthands that query parent i. NOTE(review): no null check —
  // GetParent() yields 0 for an invalid index.
  bool ParReal(int i, DSSet &val, const DFlags req)
  {
    Node *par = GetParent(i);
    return par->GetReal(val, req);
  }
  bool ParRealV(int i, DVH &val, const DFlags req)
  {
    Node *par = GetParent(i);
    return par->GetRealV(val, req);
  }
  bool ParDiscrete(int i, DD *&val)
  {
    Node *par = GetParent(i);
    return par->GetDiscrete(val);
  }
  bool ParDiscreteV(int i, VDDH &val)
  {
    Node *par = GetParent(i);
    return par->GetDiscreteV(val);
  }
  void ChildGradReal(DSSet &val);
  void ChildGradRealV(DVSet &val);
  void ChildGradDiscrete(DD &val);
  void ChildGradDiscreteV(VDD &val);

  Label GetLabel() const { return label; }
  // Human-readable identifier used in error messages: "<type> node <label>".
  string GetIdent() const { return GetType() + " node " + GetLabel(); }
  Net *GetNet() const { return net; }
  virtual string GetType() const = 0;
  int TimeType() { return timetype; }

  int GetDying() { return dying; }
  void Die(int verbose = 0);
  void OutdateChild();

  virtual void Save(NetSaver *saver);

  size_t NumChildren() { return children.size(); }
  // The i-th child, or 0 when i is out of range.
  Node *GetChild(size_t i)
  {
    if (i >= children.size())
      return 0;
    return children[i];
  }

  int GetPersist() { return persist; }
  void SetPersist(int p) { persist = p; }

protected:
  vector<Node *> children;
  Net *net;
  Label label;
  int persist, timetype;
  bool dying;
};

// Parent-list mix-in for nodes that never have parents: every query reports
// "no parent", and adding or replacing a parent is silently ignored.
class NullParNode : public virtual NodeBase
{
public:
  virtual int ParIdentity(const Node *ptr)
  {
    return -1;
  }
  virtual size_t NumParents()
  {
    return 0;
  }
  virtual Node *GetParent(size_t i)
  {
    return 0;
  }
  virtual int RemoveParent(const Node *ptr)
  {
    return 0;
  }
protected:
  virtual void ReallyAddParent(Node *ptr)
  {
  }
  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr)
  {
    return false;
  }
};

// Parent-list mix-in for nodes with (at most) one parent, stored as a single
// pointer; a null pointer means "no parent".
class UniParNode : public virtual NodeBase
{
private:
  Node *parent;
public:
  UniParNode(Node *p) : parent(p) {}
  virtual int ParIdentity(const Node *ptr)
  {
    if (ptr == parent)
      return 0;
    return -1;
  }
  virtual size_t NumParents()
  {
    return parent ? 1 : 0;
  }
  virtual Node *GetParent(size_t i)
  {
    if (i != 0)
      return 0;
    return parent;
  }
  virtual int RemoveParent(const Node *ptr);
protected:
  virtual void ReallyAddParent(Node *ptr) { parent = ptr; }
  // Replaces the parent if it equals oldptr.
  // NOTE(review): the result is "newptr is non-null" rather than "a
  // replacement happened", so replacing with 0 reports false even though
  // the parent was cleared. Kept as-is; confirm what callers expect.
  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr)
  {
    if (parent != oldptr)
      return false;
    parent = newptr;
    return parent != 0;
  }
};

// Parent-list mix-in for nodes with two parent slots. NumParents() counts
// filled slots from the front: a null first slot means zero parents.
class BiParNode : public virtual NodeBase
{
private:
  Node *parents[2];
public:
  BiParNode(Node *p1, Node *p2)
  {
    parents[0] = p1;
    parents[1] = p2;
  }
  virtual int ParIdentity(const Node *ptr);
  virtual size_t NumParents()
  {
    if (!parents[0])
      return 0;
    if (!parents[1])
      return 1;
    return 2;
  }
  virtual Node *GetParent(size_t i)
  {
    if (i >= 2)
      return 0;
    return parents[i];
  }
  virtual int RemoveParent(const Node *ptr);
protected:
  virtual void ReallyAddParent(Node *ptr);
  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr);
};

// Parent-list mix-in for nodes with an arbitrary number of parents, kept in
// a vector. The constructor accepts up to five initial parents and skips
// any that are null.
class NParNode : public virtual NodeBase
{
private:
  vector<Node *> parents;
  map<const Node *, int> parent_inds;
public:
  NParNode(Node *p1=0, Node *p2=0, Node *p3=0, Node *p4=0, Node *p5=0)
  {
    Node *const initial[] = { p1, p2, p3, p4, p5 };
    const size_t count = sizeof(initial) / sizeof(initial[0]);
    for (size_t k = 0; k < count; ++k) {
      if (initial[k])
        parents.push_back(initial[k]);
    }
  }
  virtual int ParIdentity(const Node *ptr);
  virtual size_t NumParents() { return parents.size(); }
  virtual Node *GetParent(size_t i)
  {
    if (i >= parents.size())
      return 0;
    return parents[i];
  }
  virtual int RemoveParent(const Node *ptr);
protected:
  virtual void ReallyAddParent(Node *ptr) { parents.push_back(ptr); }
  virtual bool ParReplacePtr(const Node *oldptr, Node *newptr);
};

// Parentless node holding a fixed scalar value cval.
class Constant : public Node, public NullParNode
{
public:
  Constant(Net *net, Label label, double v) : Node(net, label), cval(v) {}
#ifndef BUILDING_SWIG_INTERFACE
  Constant(Net *net, NetLoader *loader);
#endif

  // Time type never changes for a constant.
  void NotifyTimeType(int tt, int verbose=0) {}

  // Serves mean = cval, var = 0 and ex = exp(cval) for whichever of them
  // are requested; every statistic is available, so all flags are cleared
  // before the final AllFalse() check.
  bool GetReal(DSSet &val, DFlags req)
  {
    if (req.mean)
      val.mean = cval;
    if (req.var)
      val.var = 0;
    if (req.ex)
      val.ex = exp(cval);
    req.mean = false;
    req.var = false;
    req.ex = false;
    return req.AllFalse();
  }
  // A constant receives no gradient.
  void GradReal(DSSet &val, const Node *ptr) {}
  void Save(NetSaver *saver);
  string GetType() const { return "Constant"; }
private:
  double cval;
};

// Parentless node holding a fixed vector value (stored in myval).
class ConstantV : public Node, public NullParNode
{
public:
  ConstantV(Net *net, Label label, DV v);
#ifndef BUILDING_SWIG_INTERFACE
  ConstantV(Net *net, NetLoader *loader);
#endif

  // Time type never changes for a constant.
  void NotifyTimeType(int tt, int verbose=0) {}

  // Exposes the stored vector through val.vec; all requested statistics are
  // thereby served, so the request flags are cleared.
  bool GetRealV(DVH &val, DFlags req)
  {
    req.ex = false;
    req.var = false;
    req.mean = false;
    val.vec = &myval;
    return req.AllFalse();
  }
  void Save(NetSaver *saver);
  string GetType() const { return "ConstantV"; }
private:
  DVSet myval;
};

class Function : public Node
{
public:
  void Outdate(const Node *ptr) 
  {
    uptodate = DFlags(false,false,false); 
    OutdateChild();
  }
  virtual void Save(NetSaver *saver);

protected:
  Function(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);
#ifndef BUILDING_SWIG_INTERFACE
  Function(Net *ptr, NetLoader *loader);
#endif

  DFlags uptodate;
};

// Two-parent product node (Function + BiParNode). The cached scalars mean
// and var start at zero; the arithmetic lives in the out-of-line
// GetReal/GradReal.
class Prod : public Function, public BiParNode
{
public:
  Prod(Net *ptr, Label label, Node *n1, Node *n2) :
    Function(ptr, label, n1, n2), BiParNode(n1, n2), mean(0.0), var(0.0)
  {
  }
#ifndef BUILDING_SWIG_INTERFACE
  Prod(Net *ptr, NetLoader *loader);
#endif

  bool GetReal(DSSet &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "Prod"; }
private:
  double mean, var;
};

// Two-parent sum node (Function + BiParNode); the arithmetic lives in the
// out-of-line GetReal/GradReal, with the cached result in myval.
class Sum2 : public Function, public BiParNode
{
public:
  Sum2(Net *ptr, Label label, Node *n1, Node *n2) : 
    Function(ptr, label, n1, n2), BiParNode(n1, n2) {
    persist = 4 | 8; // Sum2 needs at least one child and cuts off if
                     // there is only one parent
  }
#ifndef BUILDING_SWIG_INTERFACE
  Sum2(Net *ptr, NetLoader *loader);
#endif

  bool GetReal(DSSet &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "Sum2"; }
private:
  DSSet myval;
};

// Sum node over an arbitrary number of parents (Function + NParNode).
// Parents are attached afterwards with AddParent(); per-parent values are
// cached in parentval and the total in myval. keepupdated (off by default)
// is set through SetKeepUpdated().
class SumN : public Function, public NParNode
{
public:
  SumN(Net *net, Label label) :
    Function(net, label), keepupdated(false)
  {
    // SumN needs at least one child and cuts off if there is only one parent.
    persist = 4 | 8;
  }
#ifndef BUILDING_SWIG_INTERFACE
  SumN(Net *net, NetLoader *loader);
#endif

  bool AddParent(Node *n);
  bool GetReal(DSSet &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "SumN"; }
  void Outdate(const Node *ptr);
  void SetKeepUpdated(const bool _keepupdated);
private:
  DSSet myval;
  vector<DSSet> parentval;
  bool keepupdated;
};

// Single-parent function node that hands on its parent's scalar value and
// delegates gradient handling to ChildGradReal().
class Relay : public Function, public UniParNode
{
public:
  Relay(Net *ptr, Label label, Node *n) :
    Function(ptr, label, n), UniParNode(n) {}
#ifndef BUILDING_SWIG_INTERFACE
  Relay(Net *ptr, NetLoader *loader);
#endif

  // Whatever parent 0 reports for the requested statistics.
  bool GetReal(DSSet &val, DFlags req)
  {
    const bool served = ParReal(0, val, req);
    return served;
  }
  // ptr is not consulted; the work is done by ChildGradReal().
  void GradReal(DSSet &val, const Node *ptr)
  {
    ChildGradReal(val);
  }
  void Save(NetSaver *saver);
  string GetType() const { return "Relay"; }
};

class Variable : public Node
{
public:
  virtual double Cost() = 0;
  virtual void Update() {
    if (!clamped) { 
      MyUpdate(); 
      OutdateChild();
    }
  }
  virtual void PartialUpdate(IntV *indices) {
    if (!clamped) {
      MyPartialUpdate(indices);
      OutdateChild();
    }
  }
  void Clamp(double val)
  {
    if (!MyClamp(val)) {
      ostringstream msg;
      msg << GetIdent() << ": Double clamp not allowed";
      throw TypeException(msg.str());
    }
    clamped = true; costuptodate = false;
    OutdateChild();
  }
  void Clamp(double mean, double var)
  {
    if (!MyClamp(mean, var)) {
      ostringstream msg;
      msg << GetIdent() << ": Double double clamp not allowed";
      throw TypeException(msg.str());
    }
    clamped = true; costuptodate = false;
    OutdateChild();
  }
  void Clamp(const DV &val)
  {
    if (!MyClamp(val)) {
      ostringstream msg;
      msg << GetIdent() << ": DV clamp not allowed";
      throw TypeException(msg.str());
    }
    clamped = true; costuptodate = false;
    OutdateChild();
  }
  void Clamp(const DV &mean, const DV &var)
  {
    if (!MyClamp(mean, var)) {
      ostringstream msg;
      msg << GetIdent() << ": Double DV clamp not allowed";
      throw TypeException(msg.str());
    }
    clamped = true; costuptodate = false;
    OutdateChild();
  }
  void Clamp(const DD &val) {
    if (!MyClamp(val)) {
      ostringstream msg;
      msg << GetIdent() << ": DD clamp not allowed";
      throw TypeException(msg.str());
    }
    clamped = true; costuptodate = false;
    OutdateChild();
  }
  void Clamp(int val) {
    if (!MyClamp(val)) {
      ostringstream msg;
      msg << GetIdent() << ": Int clamp not allowed";
      throw TypeException(msg.str());
    }
    clamped = true; costuptodate = false;
    OutdateChild();
  }
  void Clamp(const VDD &val) {
    if (!MyClamp(val)) {
      ostringstream msg;
      msg << GetIdent() << ": VDD clamp not allowed";
      throw TypeException(msg.str());
    }
    clamped = true; costuptodate = false;
    OutdateChild();
  }
  void Unclamp() {if (clamped) {clamped = false; MyUpdate(); OutdateChild();}}
  void SaveState();
  void SaveStep();
  void RepeatStep(double alpha);
  void SaveRepeatedState(double alpha);
  void ClearStateAndStep();
  virtual void Outdate(const Node *ptr) {costuptodate = false;}
  virtual void Save(NetSaver *saver);
  int GetHookeFlags() { return hookeflags; }
  void SetHookeFlags(int h) { hookeflags = h; }
  bool IsClamped() { return clamped; }

  // These two methods are ment for copying things from one network
  // to another similar one

  // The allocation of the DV instance is left to user
  // so it can be done in the jurisdiction of Python's GC.
  // The DV is resized, so initially it can be of size zero, for example.
  virtual void GetState(DV *state, size_t t = 0) {
    ostringstream msg;
    msg << "GetState not supported by " << GetType();
    throw TypeException(msg.str());
  }

  virtual void SetState(DV *state, size_t t = 0) {
    ostringstream msg;
    msg << "SetState not supported by " << GetType();
    throw TypeException(msg.str());
  }

protected:
  Variable(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);
#ifndef BUILDING_SWIG_INTERFACE
  Variable(Net *ptr, NetLoader *loader);
#endif

  virtual bool MyClamp(double val) {return false;}
  virtual bool MyClamp(double mean, double var) {return false;}
  virtual bool MyClamp(const DV &val) {return false;}
  virtual bool MyClamp(const DV &mean, const DV &var) {return false;}
  virtual bool MyClamp(const DD &val) {return false;}
  virtual bool MyClamp(int val) {return false;}
  virtual bool MyClamp(const VDD &val) {return false;}
  virtual void MyUpdate() = 0;
  virtual bool MySaveState() {return false;}
  virtual bool MySaveStep() {return false;}
  virtual bool MySaveRepeatedState(double alpha) {return false;}
  virtual void MyRepeatStep(double alpha) {}
  virtual bool MyClearStateAndStep() {return false; }

  virtual void MyPartialUpdate(IntV *indices) {
    ostringstream msg;
    msg << "Partial updates not supported by " << GetType();
    throw StructureException(msg.str());
  }
    
  bool clamped, costuptodate;
  int hookeflags;
};

// Scalar Gaussian variable with parents m and v (held by BiParNode).
// Optionally keeps heap-allocated copies of a saved state and step
// (sstate / sstep) for the SaveState/SaveStep/RepeatStep machinery; the
// destructor releases them, which makes the object non-copyable.
class Gaussian : public Variable, public BiParNode
{
public:
  Gaussian(Net *net, Label label, Node *m, Node *v);
#ifndef BUILDING_SWIG_INTERFACE
  Gaussian(Net *net, NetLoader *loader);
#endif
  // delete on a null pointer is a no-op, so no guards are needed.
  ~Gaussian() {
    delete sstate;
    delete sstep;
  }

  double Cost();
  bool GetReal(DSSet &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "Gaussian"; }

  void GetState(DV *state, size_t t);
  void SetState(DV *state, size_t t);

protected:
  virtual bool MyClamp(double m);
  virtual bool MyClamp(double m, double v);
  virtual void MyUpdate();
  bool MySaveState();
  bool MySaveStep();
  bool MySaveRepeatedState(double alpha);
  void MyRepeatStep(double alpha);
  bool MyClearStateAndStep();

  void MyPartialUpdate(IntV *indices);

  DSSet myval;
  DSSet *sstate, *sstep;
  double cost;
  bool exuptodate;

private:
  // The destructor owns sstate/sstep, so an implicit member-wise copy would
  // delete them twice. Declared but intentionally not defined (C++98-safe
  // non-copyable idiom).
  Gaussian(const Gaussian &);
  Gaussian &operator=(const Gaussian &);
};


// Scalar variable with a rectified-Gaussian posterior approximation and
// parents m and v (held by BiParNode). myval holds the posterior
// parameters; expectations caches the moments derived from them.
class RectifiedGaussian : public Variable, public BiParNode
{
public:
  RectifiedGaussian(Net *net, Label label, Node *m, Node *v);
#ifndef BUILDING_SWIG_INTERFACE
  RectifiedGaussian(Net *net, NetLoader *loader);
#endif

  double Cost();
  // NOTE(review): presumably serves the cached expectations rather than
  // the raw posterior parameters (cf. GetMyval) — confirm.
  bool GetReal(DSSet &val, DFlags req);

  /* Returns the actual posterior parameters. */
  bool GetMyval(DSSet &val);

  void GradReal(DSSet &val, const Node *ptr);
  string GetType() const;
  void Save(NetSaver *saver);

  void GetState(DV *state, size_t t);
  void SetState(DV *state, size_t t);

protected:
  virtual void MyUpdate();

  void MyPartialUpdate(IntV *indices);

  // Refreshes the cached `expectations` member.
  void UpdateExpectations();

  /* Parameters of the rectified Gaussian posterior approximation.
     For debug purposes. */
  DSSet myval;

  /* Expectations (stored to gain speed).
     Note that the posterior mean or variance parameter is not
     the same as the mean or variance, because the posterior is
     approximated with a rectified Gaussian. */
  DSSet expectations;

  double cost;
};

// Vector counterpart of RectifiedGaussian: a rectified-Gaussian posterior
// per element, with parents m and v (held by BiParNode).
class RectifiedGaussianV : public Variable, public BiParNode
{
public:
  RectifiedGaussianV(Net *net, Label label, Node *m, Node *v);
#ifndef BUILDING_SWIG_INTERFACE
  RectifiedGaussianV(Net *net, NetLoader *loader);
#endif

  double Cost();
  bool GetRealV(DVH &val, DFlags req);
  // Vector analogue of RectifiedGaussian::GetMyval (posterior parameters).
  bool GetMyvalV(DVH &val);

  void GradReal(DSSet &val, const Node *ptr);
  void GradRealV(DVSet &val, const Node *ptr);

  string GetType() const;
  void Save(NetSaver *saver);

protected:
  void MyUpdate();

  // Refreshes the cached `expectations` member.
  void UpdateExpectations();

  // Posterior parameters.
  DVSet myval;
  // Cached expectations computed from myval.
  DVSet expectations;

  double cost;
};

// Scalar variable with parents m and v (held by BiParNode) that offers both
// a plain value (GetReal) and a rectified value (GetRectReal). Its posterior
// is stored as two weighted parts, posval/posweight and negval/negweight,
// with cached moments and expectations.
// NOTE(review): the pos/neg split presumably refers to the positive and
// negative half-lines — confirm in the implementation.
class GaussRect : public Variable, public BiParNode
{
public:
  GaussRect(Net *net, Label label, Node *m, Node *v);
#ifndef BUILDING_SWIG_INTERFACE
  GaussRect(Net *net, NetLoader *loader);
#endif

  double Cost();
  bool GetReal(DSSet &val, DFlags req);
  bool GetRectReal(DSSet &val, DFlags req);

  void GradReal(DSSet &val, const Node *ptr);

  string GetType() const;
  void Save(NetSaver *saver);

protected:
  void MyUpdate();

  // Refresh posmoments/negmoments and expts/rectexpts respectively.
  void UpdateMoments();
  void UpdateExpectations();

  // Gathers the children's gradients w.r.t. the plain (norm) and the
  // rectified (rect) value.
  void ChildGradients(DSSet &norm, DSSet &rect);

  DSSet posval;
  DSSet negval;

  double posweight;
  double negweight;

  vector<double> posmoments;
  vector<double> negmoments;

  // Cached expectations of the plain and the rectified value.
  DSSet expts;
  DSSet rectexpts;

  double cost;
};

// Vector counterpart of GaussRect: per-element two-part posterior
// (posval/posweights and negval/negweights) with parents m and v. Its
// internals are exposed for unit tests through the friend GaussRectVState.
class GaussRectV : public Variable, public BiParNode
{
public:
  friend class GaussRectVState;

  GaussRectV(Net *net, Label label, Node *m, Node *v);
#ifndef BUILDING_SWIG_INTERFACE
  GaussRectV(Net *net, NetLoader *loader);
#endif

  double Cost();
  bool GetRealV(DVH &val, DFlags req);
  bool GetRectRealV(DVH &val, DFlags req);

  void GradReal(DSSet &val, const Node *ptr);
  void GradRealV(DVSet &val, const Node *ptr);

  string GetType() const;
  void Save(NetSaver *saver);

  void GetState(DV *state, size_t t);
  void SetState(DV *state, size_t t);

protected:
  void MyUpdate();

  void MyPartialUpdate(IntV *indices);

  // Refresh posmoments/negmoments and expts/rectexpts respectively.
  void UpdateMoments();
  void UpdateExpectations();

  // Gathers the children's gradients w.r.t. the plain (norm) and the
  // rectified (rect) value.
  void ChildGradients(DVSet &norm, DVSet &rect);

  DVSet posval;
  DVSet negval;

  DV posweights;
  DV negweights;

  vector<DV> posmoments;
  vector<DV> negmoments;

  // Cached expectations of the plain and the rectified value.
  DVSet expts;
  DVSet rectexpts;

  double cost;
};

// Making the internals of GaussRectV public does not seem tempting,
// but writing unit tests without access to them is impossible.
// Hence GaussRectVState (a friend of GaussRectV) provides access to
// the internals of GaussRectV without cluttering the interface of
// GaussRectV.

// Test-only accessor for the internal state of a GaussRectV (it is declared
// a friend there). The returned references alias the node's own members.
class GaussRectVState
{
public:
  GaussRectVState(GaussRectV *n);
  DVSet &GetPosVal();
  DVSet &GetNegVal();
  DV &GetPosWeights();
  DV &GetNegWeights();
  // NOTE(review): presumably the i-th entry of posmoments / negmoments;
  // whether i is range-checked is decided out of view.
  DV &GetPosMoment(int i);
  DV &GetNegMoment(int i);

private:
  GaussRectV *node;
};



// Scalar mixture-of-Gaussians variable. The constructor takes the discrete
// parent d; each mixture component's mean/variance parents are attached
// with AddComponent(m, v) and tracked in `means` / `vars`.
class MoG : public Variable, public NParNode
{
public:
  MoG(Net *net, Label label, Node *d);
#ifndef BUILDING_SWIG_INTERFACE
  MoG(Net *net, NetLoader *loader);
#endif
  
  double Cost();
  bool GetReal(DSSet &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void GradDiscrete(DD &val, const Node *ptr);

  string GetType() const;
  void Save(NetSaver *saver);

  void AddComponent(Node *m, Node *v);
  size_t NumComponents();

protected:
  void MyUpdate();

  // Per-component posterior parameters (heap-allocated; ownership is
  // handled out of view).
  vector<DSSet*> myval;

  vector<Node*> means;
  vector<Node*> vars;

private:
  // Lookups of ptr within `means` / `vars`.
  bool IsMeanParent(const Node *ptr);
  bool IsVarParent(const Node *ptr);
  int WhichMeanParent(const Node *ptr);
  int WhichVarParent(const Node *ptr);
  int WhichParent(const Node *ptr, const vector<Node*> &parents);

  // Updates expts.
  void ComputeExpectations();

  DSSet expts;

  size_t numComponents;

  double cost;
};


// Vector counterpart of MoG: the discrete parent d is given to the
// constructor, and every component's mean/variance parents must be added
// with AddComponent(m, v).
class MoGV : public Variable, public NParNode
{
public:
  MoGV(Net *net, Label label, Node *d);
#ifndef BUILDING_SWIG_INTERFACE
  MoGV(Net *net, NetLoader *loader);
#endif

  double Cost();
  bool GetRealV(DVH &val, DFlags req);
  // Posterior parameters of component k.
  void GetMyvalV(DVH &val, int k);

  void GradReal(DSSet &val, const Node *ptr);
  void GradRealV(DVSet &val, const Node *ptr);
  void GradDiscreteV(VDD &val, const Node *ptr);

  string GetType() const;
  void Save(NetSaver *saver);
  
  /* Parents MUST be added with this method. */
  void AddComponent(Node *m, Node *v);
  size_t NumComponents();

protected:
  void MyUpdate();
  bool MyClamp(const DV &m);

  /* Posterior parameters (weights are got from Categorical) */
  vector<DVSet*> myval;

  vector<Node*> means;
  vector<Node*> vars;

private:
  // Lookups of ptr within `means` / `vars`.
  bool IsMeanParent(const Node *ptr);
  bool IsVarParent(const Node *ptr);
  int WhichMeanParent(const Node *ptr);
  int WhichVarParent(const Node *ptr);
  int WhichParent(const Node *ptr, const vector<Node*> &parents);

  /* Updates expts. */
  void ComputeExpectations();

  /* Expectations calculated from the posterior. */
  DVSet expts;

  /* Number of mixture components. */
  size_t numComponents;

  double cost;
};


// Dirichlet-distributed variable whose single parent is a ConstantV
// (presumably holding the prior parameters — TODO confirm). The posterior
// parameters live in myval and the derived expectations in expts.
class Dirichlet : public Variable, public NParNode
{
public:
  Dirichlet(Net *net, Label label, ConstantV *n);
#ifndef BUILDING_SWIG_INTERFACE
  Dirichlet(Net *net, NetLoader *loader);
#endif
  
  double Cost();
  /* Returns the expectations of the individual components;
     <log c_i> is returned in the ex field. */
  bool GetRealV(DVH &val, DFlags req);
  
  string GetType() const;
  void Save(NetSaver *saver);

protected:
  void MyUpdate();

private:
  /* Updates expts. */
  void ComputeExpectations();

  /* Posterior parameters. */
  DV myval;

  /* Expectations calculated from the posterior. */
  DVSet expts;

  /* Number of components. */
  size_t numComponents;

  double cost;
};

// Scalar discrete variable whose parent is a Dirichlet node. It provides
// its distribution through GetDiscrete() and, via GradRealV(), presumably
// passes gradients back to the Dirichlet parent — TODO confirm.
class DiscreteDirichlet : public Variable, public NParNode
{
public:
  DiscreteDirichlet(Net *net, Label label, Dirichlet *n);
#ifndef BUILDING_SWIG_INTERFACE
  DiscreteDirichlet(Net *net, NetLoader *loader);
#endif

  double Cost();
  bool GetDiscrete(DD *&val);
  void GradRealV(DVSet &val, const Node *ptr);

  string GetType() const;
  void Save(NetSaver *saver);

protected:
  void MyUpdate();
  // Accepts a DD clamp.
  bool MyClamp(const DD &v);

  DD myval;

  double cost;
};


/* A vector-valued (VDD) discrete variable with a Dirichlet prior on its weights */

// Vector counterpart of DiscreteDirichlet: its value is a VDD served through
// GetDiscreteV(), and its parent is a Dirichlet node.
class DiscreteDirichletV : public Variable, public NParNode
{
public:
  DiscreteDirichletV(Net *net, Label label, Dirichlet *n);
#ifndef BUILDING_SWIG_INTERFACE
  DiscreteDirichletV(Net *net, NetLoader *loader);
#endif

  double Cost();
  bool GetDiscreteV(VDDH &val);
  void GradRealV(DVSet &val, const Node *ptr);

  string GetType() const;
  void Save(NetSaver *saver);

protected:
  void MyUpdate();
  // Accepts a VDD clamp.
  bool MyClamp(const VDD &v);

  VDD myval;

  double cost;
};


// Single-parent function node (Function + UniParNode) providing a scalar
// value via GetReal(); the rectification itself is implemented out of line.
class Rectification : public Function, public UniParNode
{
public:
  Rectification(Net *net, Label label, Node *n);
#ifndef BUILDING_SWIG_INTERFACE
  Rectification(Net *net, NetLoader *loader);
#endif
  bool GetReal(DSSet &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const;
};

// Vector counterpart of Rectification: a single-parent function node
// providing its value via GetRealV(); implemented out of line.
class RectificationV : public Function, public UniParNode
{
public:
  RectificationV(Net *net, Label label, Node *n);
#ifndef BUILDING_SWIG_INTERFACE
  RectificationV(Net *net, NetLoader *loader);
#endif
  bool GetRealV(DVH &val, DFlags req);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const;
};

// Vector counterpart of Prod: a two-parent function node whose result is
// cached in myval; the arithmetic lives in the out-of-line methods.
class ProdV : public Function, public BiParNode
{
public:
  ProdV(Net *ptr, Label label, Node *n1, Node *n2) : 
    Function(ptr, label, n1, n2), BiParNode(n1, n2) {}
#ifndef BUILDING_SWIG_INTERFACE
  ProdV(Net *ptr, NetLoader *loader);
#endif

  void GradReal(DSSet &val, const Node *ptr);
  bool GetRealV(DVH &val, DFlags req);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "ProdV"; }
private:
  DVSet myval;
};


// Vector counterpart of Sum2: a two-parent sum node with the cached result in
// myval; the arithmetic lives in the out-of-line methods.
class Sum2V : public Function, public BiParNode
{
public:
  Sum2V(Net *ptr, Label label, Node *n1, Node *n2) :
    Function(ptr, label, n1, n2), BiParNode(n1, n2) {
    persist = 4 | 8; // Sum2V needs at least one child and cuts off if
                     // there is only one parent
  }
#ifndef BUILDING_SWIG_INTERFACE
  Sum2V(Net *ptr, NetLoader *loader);
#endif

  void GradReal(DSSet &val, const Node *ptr);
  bool GetRealV(DVH &val, DFlags req);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "Sum2V"; }
private:
  DVSet myval;
};

// Vector counterpart of SumN: a sum over an arbitrary number of parents
// attached with AddParent(). Per-parent values are cached in parentval and
// the total in myval; keepupdated (off by default) is set via
// SetKeepUpdated().
class SumNV : public Function, public NParNode
{
public:
  SumNV(Net *net, Label label) :
    Function(net, label), keepupdated(false)
  {
    // SumNV needs at least one child and cuts off if there is only one parent.
    persist = 4 | 8;
  }
#ifndef BUILDING_SWIG_INTERFACE
  SumNV(Net *net, NetLoader *loader);
#endif

  bool AddParent(Node *n);
  bool GetRealV(DVH &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "SumNV"; }
  void Outdate(const Node *ptr);
  void SetKeepUpdated(const bool _keepupdated);
private:
  void UpdateFromScratch(DFlags req);
  DVSet myval;
  vector<DVSet> parentval;
  bool keepupdated;
};

// Two-parent vector function node with an adjustable delay length
// (lendelay, 1 by default, accessed via Get/SetDelayLength). The delay
// itself is implemented in the out-of-line methods.
class DelayV : public Function, public BiParNode
{
public:
  DelayV(Net *ptr, Label label, Node *n1, Node *n2) :
    Function(ptr, label, n1, n2), BiParNode(n1, n2), lendelay(1)
  {
  }
#ifndef BUILDING_SWIG_INTERFACE
  DelayV(Net *ptr, NetLoader *loader);
#endif

  void GradReal(DSSet &val, const Node *ptr);
  bool GetRealV(DVH &val, DFlags req);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "DelayV"; }

  int GetDelayLength();
  void SetDelayLength(int len);

private:
  DVSet myval;

  int lendelay;
};

// Vector Gaussian variable with parents m and v (held by BiParNode).
// Optionally keeps heap-allocated copies of a saved state and step
// (sstate / sstep); the destructor releases them, which makes the object
// non-copyable.
class GaussianV : public Variable, public BiParNode
{
public:
  GaussianV(Net *_net, Label label, Node *m, Node *v);
#ifndef BUILDING_SWIG_INTERFACE
  GaussianV(Net *_net, NetLoader *loader);
#endif
  // delete on a null pointer is a no-op, so no guards are needed.
  ~GaussianV() {
    delete sstate;
    delete sstep;
  }

  double Cost();

  void GradReal(DSSet &val, const Node *ptr);
  bool GetRealV(DVH &val, DFlags req);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "GaussianV"; }

  void GetState(DV *state, size_t t);
  void SetState(DV *state, size_t t);

protected:
  bool MyClamp(double m);
  bool MyClamp(double m, double v);
  bool MyClamp(const DV &m);
  bool MyClamp(const DV &m, const DV &v);
  void MyUpdate();

  void MyPartialUpdate(IntV *indices);

  bool MySaveState();
  bool MySaveStep();
  bool MySaveRepeatedState(double alpha);
  void MyRepeatStep(double alpha);
  bool MyClearStateAndStep();

  DVSet myval;
  double cost;
  bool exuptodate;
private:
  DVSet *sstate, *sstep;

  // The destructor owns sstate/sstep, so an implicit member-wise copy would
  // delete them twice. Declared but intentionally not defined (C++98-safe
  // non-copyable idiom).
  GaussianV(const GaussianV &);
  GaussianV &operator=(const GaussianV &);
};

// GaussianV variant that additionally tracks a list of missing entries
// (`missing`, an IntV), settable via SetMissing() or SparseClampDV().
// It overrides Update() and the cost/gradient methods out of line.
class SparseGaussV : public GaussianV
{
public:
  SparseGaussV(Net *_net, Label label, Node *m, Node *v) : 
    GaussianV(_net, label, m, v) {}
#ifndef BUILDING_SWIG_INTERFACE
  SparseGaussV(Net *_net, NetLoader *loader);
#endif

  double Cost();

  void Update();
  void GradReal(DSSet &val, const Node *ptr);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "SparseGaussV"; }
  // NOTE(review): presumably clamps to `mean` with the indices in `mis`
  // marked missing — confirm in the implementation.
  void SparseClampDV(const DV &mean, const IntV &mis);
  // Returns a mutable reference to the stored missing-index list.
  IntV& GetMissing() { return missing; }
  void SetMissing(IntV &mis);

private:
  IntV missing;
};

class DelayGaussV : public Variable, public NParNode
{
public:
  DelayGaussV(Net *_net, Label label, Node *m, Node *v, Node *a,
	      Node *m0, Node *v0);
#ifndef BUILDING_SWIG_INTERFACE
  DelayGaussV(Net *_net, NetLoader *loader);
#endif

  double Cost();

  void GradReal(DSSet &val, const Node *ptr);
  bool GetRealV(DVH &val, DFlags req);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "DelayGaussV"; }
protected:
  bool MyClamp(double m)
  {
    fill(myval.mean.begin(), myval.mean.end(), m);
    fill(myval.var.begin(), myval.var.end(), 0);
    exuptodate = false;
    return true;
  }
  bool MyClamp(const DV &m)
  {
    if (m.size() == myval.mean.size())
      copy(m.begin(), m.end(), myval.mean.begin());
    else {
      ostringstream msg;
      msg << "DelayGaussV::MyClamp: wrong vector size " << m.size() << " != "
	  << myval.mean.size();
      throw TypeException(msg.str());
    }
    fill(myval.var.begin(), myval.var.end(), 0);
    return true;
  }
  bool MyClamp(const DV &m, const DV &v)
  {
    if (m.size() == myval.mean.size() && v.size() == myval.var.size()) {
      copy(m.begin(), m.end(), myval.mean.begin());
      copy(v.begin(), v.end(), myval.var.begin());
    } else {
      ostringstream msg;
      msg << "DelayGaussV::MyClamp: wrong vector size " << m.size() << " != "
	  << myval.mean.size();
      throw TypeException(msg.str());
    }
    return true;
  }
  void MyUpdate();
  bool MySaveState();
  bool MySaveStep();
  bool MySaveRepeatedState(double alpha);
  void MyRepeatStep(double alpha);
  bool MyClearStateAndStep();

private:
  DVSet myval;
  DVSet *sstate, *sstep;
  double cost;
  bool exuptodate;
};

class GaussNonlin : public Variable, public BiParNode
// nonlinearity: myval2 = exp(-myval1*myval1)
// Scalar Gaussian variable (parents m, v) whose value is passed through the
// nonlinearity above; myval1 is the value before it, myval2 after it, and
// meanuptodate/varuptodate track whether myval2 is current.
{
public:
  GaussNonlin(Net *_net, Label label, Node *m, Node *v) :
    Variable(_net, label, m, v), BiParNode(m, v)
  {
    sstate = 0; sstep = 0;
    cost = 0;
    // Start with myval2 marked stale so the cache flags are never read
    // uninitialised, whatever MyUpdate() does with them.
    meanuptodate = false;
    varuptodate = false;

    CheckParent(0, REAL_MV);
    CheckParent(1, REAL_ME);
    MyUpdate();
  }
#ifndef BUILDING_SWIG_INTERFACE
  GaussNonlin(Net *_net, NetLoader *loader);
#endif
  // delete on a null pointer is a no-op, so no guards are needed.
  ~GaussNonlin() {
    delete sstate;
    delete sstep;
  }

  double Cost();
  bool GetReal(DSSet &val, DFlags req);
  void GradReal(DSSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "GaussNonlin"; }

protected:
  // Clamp the pre-nonlinearity value to m with zero variance; myval2 must
  // then be recomputed.
  bool MyClamp(double m) {
    myval1.mean = m;
    myval1.var = 0;
    meanuptodate = false;
    varuptodate = false;
    return true;
  }

  void MyUpdate();
  void UpdateMean();
  void UpdateVar();
  bool MySaveState();
  bool MySaveStep();
  bool MySaveRepeatedState(double alpha);
  void MyRepeatStep(double alpha);
  bool MyClearStateAndStep();

private:
  DSSet myval1, myval2;  // 1 before nonlinearity, 2 after
  DSSet *sstate, *sstep;
  double cost;
  bool meanuptodate, varuptodate;  // refer to myval2

  // Owns sstate/sstep: a member-wise copy would delete them twice.
  // Declared but intentionally not defined (C++98-safe non-copyable idiom).
  GaussNonlin(const GaussNonlin &);
  GaussNonlin &operator=(const GaussNonlin &);
};

class GaussNonlinV : public Variable, public BiParNode
// nonlinearity: myval2 = exp(-myval1*myval1)
// Vector counterpart of GaussNonlin: myval1 holds the per-element values
// before the nonlinearity and myval2 after it; meanuptodate/varuptodate
// track whether myval2 is current.
{
public:
  GaussNonlinV(Net *_net, Label label, Node *m, Node *v);
#ifndef BUILDING_SWIG_INTERFACE
  GaussNonlinV(Net *_net, NetLoader *loader);
#endif
  // delete on a null pointer is a no-op, so no guards are needed.
  ~GaussNonlinV() {
    delete sstate;
    delete sstep;
  }

  double Cost();
  void GradReal(DSSet &val, const Node *ptr);
  bool GetRealV(DVH &val, DFlags req);
  void GradRealV(DVSet &val, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "GaussNonlinV"; }

protected:
  // Clamp every pre-nonlinearity element to m with zero variance.
  bool MyClamp(double m)
  {
    fill(myval1.mean.begin(), myval1.mean.end(), m);
    fill(myval1.var.begin(), myval1.var.end(), 0);
    meanuptodate = false;
    varuptodate = false;
    return true;
  }
  // Clamp the pre-nonlinearity mean to the vector m with zero variance.
  // Throws TypeException if m has the wrong length.
  bool MyClamp(const DV &m)
  {
    if (m.size() != myval1.mean.size()) {
      ostringstream msg;
      // Report this class's name (the message previously said "GaussianV").
      msg << "GaussNonlinV::MyClamp: wrong vector size " << m.size() << " != "
          << myval1.mean.size();
      throw TypeException(msg.str());
    }
    copy(m.begin(), m.end(), myval1.mean.begin());
    fill(myval1.var.begin(), myval1.var.end(), 0);
    meanuptodate = false;
    varuptodate = false;
    return true;
  }
  void MyUpdate();
  void UpdateMean();
  void UpdateVar();
  bool MySaveState();
  bool MySaveStep();
  bool MySaveRepeatedState(double alpha);
  void MyRepeatStep(double alpha);
  bool MyClearStateAndStep();

private:
  DVSet myval1, myval2;  // 1 before nonlinearity, 2 after
  DVSet *sstate, *sstep;
  double cost;
  bool meanuptodate, varuptodate;  // refer to myval2

  // Owns sstate/sstep: a member-wise copy would delete them twice.
  // Declared but intentionally not defined (C++98-safe non-copyable idiom).
  GaussNonlinV(const GaussNonlinV &);
  GaussNonlinV &operator=(const GaussNonlinV &);
};


// Scalar discrete variable. Each parent (checked to be REAL_ME) corresponds
// to one category; parents can be given to the constructor or appended
// later with AddParent(). It can be clamped to a DD or to a category index.
class Discrete : public Variable, public NParNode
{
public:
  Discrete(Net *_net, Label label, Node *n=0) :
    Variable(_net, label, n), NParNode(n)
  {
    cost = 0; exsum = 0;
    // Initialise unconditionally: without an initial parent this flag used
    // to stay indeterminate until a later AddParent()/MyUpdate() read it.
    exuptodate = false;
    if (n) {
      CheckParent(0, REAL_ME);
      MyUpdate();
    }
  }
#ifndef BUILDING_SWIG_INTERFACE
  Discrete(Net *_net, NetLoader *loader);
#endif

  // Appends parent n (which must be REAL_ME) and recomputes the posterior.
  bool AddParent(Node *n) {
    Node::AddParent(n);
    CheckParent(NumParents()-1, REAL_ME);
    MyUpdate();
    return true;
  }
  double Cost();
  void GradReal(DSSet &val, const Node *ptr);
#ifdef BUILDING_SWIG_INTERFACE
  BOOLASOBJ GetDiscrete(DD *&val);
#else
  bool GetDiscrete(DD *&val);
#endif
  void Save(NetSaver *saver);
  string GetType() const { return "Discrete"; }
protected:
  bool MyClamp(double m) { return false; }
  bool MyClamp(const DD &m) { myval = m; return true; }
  // Clamp to category n: a one-hot distribution over NumParents() entries.
  // Throws TypeException when n is not a valid category index.
  bool MyClamp(int n) {
    if (n < 0) {
      // Previously unchecked: a negative n indexed myval out of range.
      throw TypeException("Negative value for clamping a Discrete");
    }
    if (n >= (int)NumParents()) {
      throw TypeException("Too large value for clamping a Discrete");
    }
    myval.Resize(NumParents());
    for (size_t j=NumParents(); j>0; j--) {
      myval[j-1] = 0;
    }
    myval[n] = 1;
    return true;
  }
  void MyUpdate();
  void UpdateExpSum();
private:
  DD myval;
  double cost, exsum;
  bool exuptodate;
};


// Vector-valued discrete variable.  Parents are accepted when they can
// provide a real vector value (see AddParent); exsum is a per-element cache
// refreshed by UpdateExpSum() and guarded by exuptodate.
class DiscreteV : public Variable, public NParNode
{
public:
  DiscreteV(Net *_net, Label label, Node *n=0);
#ifndef BUILDING_SWIG_INTERFACE
  DiscreteV(Net *_net, NetLoader *loader);
#endif

  // Append parent n after checking that it answers GetRealV with the
  // requested DFlags(true, false, true).  Throws StructureException
  // otherwise.  The cached exsum no longer covers all parents, so it is
  // marked stale before MyUpdate().
  bool AddParent(Node *n) {
    DVH tmp;
    if (! n->GetRealV(tmp, DFlags(true, false, true))) {
      ostringstream msg;
      msg << "Wrong type of parents in " << GetType() << " Node "
          << label << '\n';
      msg << " Parent " << n->GetLabel() << ":" << n->GetType();
      throw StructureException(msg.str());
    }
    Node::AddParent(n);
    exuptodate = false;
    MyUpdate();
    return true;
  }
  double Cost();
  void GradReal(DSSet &val, const Node *ptr);
  void GradRealV(DVSet &val, const Node *ptr);
  bool GetDiscreteV(VDDH &val);
  void Save(NetSaver *saver);
  string GetType() const { return "DiscreteV"; }
protected:
  // Scalar clamping is not supported for a discrete node.
  bool MyClamp(double m) { return false; }
  bool MyClamp(const VDD &m) { myval = m; return true; }
  void MyUpdate();
  void UpdateExpSum();
private:
  VDD myval;
  double cost;
  DV exsum;
  bool exuptodate;  // true when exsum is current
};


class Memory : public Variable, public UniParNode
{
public:
  Memory(Net *_net, Label label, Node *n) : 
    Variable(_net, label, n), UniParNode(n)
  {
    if (n->TimeType()) {
      ostringstream msg;
      msg << GetIdent() << ": parent must be independent of time";
      throw StructureException(msg.str());
    }
    timetype = 2;
    oldcost = 0; cost = 0;
  }
#ifndef BUILDING_SWIG_INTERFACE
  Memory(Net * net, NetLoader *loader);
#endif

  void NotifyTimeType(int tt, int verbose=0)
  {
    if (GetParent(0)->TimeType()) {
      ostringstream msg;
      msg << GetIdent() << ": parent must be independent of time";
      throw StructureException(msg.str());
    }
  }
  double Cost();
  void MyUpdate();
  bool GetReal(DSSet &val, DFlags req) {return ParReal(0, val, req);}
  void GradReal(DSSet &grad, const Node *ptr);
  void Save(NetSaver *saver);
  string GetType() const { return "Memory"; }
  void Outdate(const Node *ptr) { costuptodate = false; OutdateChild(); }

  DSSet oldval;
  double oldcost;
  double cost;
};

// Abstract base for the on-line delay nodes (OLDelayS, OLDelayD).
// Subclasses must implement StepTime() and ResetTime().  What a time step
// does is defined in their out-of-line implementations (presumably it
// shifts the stored previous-step value -- confirm).
class OnLineDelay : public Node
{
public:
  // Save the node through the given NetSaver; subclasses override this.
  virtual void Save(NetSaver *saver);
  // Advance the delay by one time step (pure virtual).
  virtual void StepTime() = 0;
  // Reset the delay's time state (pure virtual).
  virtual void ResetTime() = 0;

protected:
  // Only constructible through subclasses.  n1/n2 are the (optional)
  // parents that the subclasses forward here.
  OnLineDelay(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);
#ifndef BUILDING_SWIG_INTERFACE
  // Construct from a NetLoader (hidden from the SWIG interface).
  OnLineDelay(Net *ptr, NetLoader *loader);
#endif
};

// Scalar on-line delay with two parents, both required to pass
// CheckParent(..., REAL_M).  At construction it records the mean of parent 0
// in oldmean.  oldexp is a cache guarded by exuptodate.
class OLDelayS : public OnLineDelay, public BiParNode
{
public:
  // NOTE(review): n1/n2 default to 0, yet the body immediately checks and
  // reads parent 0; constructing without parents presumably fails in
  // CheckParent -- confirm.
  OLDelayS(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0) :
    OnLineDelay(ptr, label, n1, n2), BiParNode(n1, n2)
  {
    for (int idx = 0; idx < 2; idx++)
      CheckParent(idx, REAL_M);

    DSSet initial;
    ParReal(0, initial, DFlags(true));
    oldmean = initial.mean;
    exuptodate = false;  // oldexp has not been computed yet
  }
#ifndef BUILDING_SWIG_INTERFACE
  OLDelayS(Net *ptr, NetLoader *loader);
#endif
  virtual void Save(NetSaver *saver);
  virtual void StepTime();
  virtual void ResetTime();
  virtual bool GetReal(DSSet &val, DFlags req);
  string GetType() const
  {
    return "OLDelayS";
  }

private:
  double oldmean;
  double oldexp;
  bool exuptodate;  // true when oldexp is current
};

// Discrete on-line delay with two parents, both required to pass
// CheckParent(..., DISCRETE).  At construction it copies parent 0's discrete
// value into oldval.
class OLDelayD : public OnLineDelay, public BiParNode
{
public:
  // Throws StructureException if parent 0 does not yield a discrete value.
  // Previously `tmp` was uninitialised and dereferenced unconditionally,
  // which is undefined behaviour if ParDiscrete did not set it.
  OLDelayD(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0) :
    OnLineDelay(ptr, label, n1, n2), BiParNode(n1, n2)
  {
    CheckParent(0, DISCRETE);
    CheckParent(1, DISCRETE);

    DD *tmp = 0;
    ParDiscrete(0, tmp);
    if (!tmp) {
      ostringstream msg;
      msg << GetIdent() << ": could not get discrete value of parent 0";
      throw StructureException(msg.str());
    }
    oldval = *tmp;
  }
#ifndef BUILDING_SWIG_INTERFACE
  OLDelayD(Net *ptr, NetLoader *loader);
#endif
  virtual void Save(NetSaver *saver);
  virtual void StepTime();
  virtual void ResetTime();
#ifdef BUILDING_SWIG_INTERFACE
  virtual BOOLASOBJ GetDiscrete(DD *&val);
#else
  virtual bool GetDiscrete(DD *&val);
#endif
  string GetType() const { return "OLDelayD"; }

private:
  DD oldval;
};

// Stand-in node created from its own label plus the label (rlabel) of the
// node it refers to.  The reference is presumably resolved/validated by
// CheckRef() -- confirm in the out-of-line definitions.
class Proxy : public Node, public UniParNode
{
public:
  Proxy(Net *ptr, Label label, Label rlabel);
#ifndef BUILDING_SWIG_INTERFACE
  Proxy(Net *ptr, NetLoader *loader);
#endif
  void Save(NetSaver *saver);
  string GetType() const
  {
    return "Proxy";
  }

  // Value queries (defined out of line).
  bool GetReal(DSSet &val, DFlags req);
  bool GetRealV(DVH &val, DFlags req);
#ifdef BUILDING_SWIG_INTERFACE
  BOOLASOBJ GetDiscrete(DD *&val);
#else
  bool GetDiscrete(DD *&val);
#endif
  bool GetDiscreteV(VDDH &val);

  // Gradient messages: the sender (ptr) is ignored and each request is
  // delegated to the ChildGrad* helper of the matching type.
  void GradReal(DSSet &val, const Node *ptr)
  {
    ChildGradReal(val);
  }
  void GradRealV(DVSet &val, const Node *ptr)
  {
    ChildGradRealV(val);
  }
  void GradDiscrete(DD &val, const Node *ptr)
  {
    ChildGradDiscrete(val);
  }
  void GradDiscreteV(VDD &val, const Node *ptr)
  {
    ChildGradDiscreteV(val);
  }
  bool CheckRef();

private:
  string reflabel;     // label of the referenced node
  bool req_discrete;
  bool req_discretev;
  DFlags real_flags;
  DFlags realv_flags;
};

// Scalar evidence node attached to a single parent p, with a decaying
// weight.  It starts with alpha = 1e-10 and zero decay, value and cost.
class Evidence : public Variable, public Decayer, public UniParNode
{
public:
  Evidence(Net *ptr, Label label, Node *p) :
    Variable(ptr, label, p), Decayer(ptr), UniParNode(p),
    cost(0), myval(0), alpha(1e-10), decay(0)
  {
  }
#ifndef BUILDING_SWIG_INTERFACE
  Evidence(Net *ptr, NetLoader *loader);
#endif
  void Save(NetSaver *saver);
  string GetType() const
  {
    return "Evidence";
  }

  void GradReal(DSSet &val, const Node *ptr);
  double Cost();
  // Set the decay rate so that the current alpha would be used up over
  // `iters` iterations: decay = alpha / iters.
  // NOTE(review): iters == 0 divides by zero (infinite decay for alpha > 0).
  void SetDecayTime(double iters)
  {
    decay = alpha / iters;
  }
  virtual bool DoDecay(string hook);

private:
  // Empty: this node does no work on update.
  void MyUpdate() {}
  bool MyClamp(double mean, double var);

  double cost;
  double myval;
  double alpha;
  double decay;
};

// Vector-valued counterpart of Evidence: myval, alpha and decay are held per
// element as DV vectors instead of scalars.  All non-trivial members are
// defined out of line.
class EvidenceV : public Variable, public Decayer, public UniParNode
{
public:
  // p is the single parent (UniParNode).
  EvidenceV(Net *ptr, Label label, Node *p);
#ifndef BUILDING_SWIG_INTERFACE
  EvidenceV(Net *ptr, NetLoader *loader);
#endif
  void Save(NetSaver *saver);
  string GetType() const { return "EvidenceV"; }

  void GradRealV(DVSet &val, const Node *ptr);
  double Cost();
  // Set per-element decay times; presumably decay[i] = alpha[i] / iters[i]
  // by analogy with the scalar Evidence::SetDecayTime -- confirm.
  void SetDecayTime(const DV &iters);
  // Decayer hook (semantics of `hook` and the return value come from Decayer).
  virtual bool DoDecay(string hook);

private:
  // Empty: this node does no work on update.
  void MyUpdate() {}
  // Clamp with a scalar or a per-element (mean, var) pair.
  bool MyClamp(double mean, double var);
  bool MyClamp(const DV &mean, const DV &var);

  double cost;
  DV myval;   // per-element counterpart of Evidence::myval
  DV alpha;   // per-element counterpart of Evidence::alpha
  DV decay;   // per-element counterpart of Evidence::decay
};



#endif // NODE_H
