//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: Node.cc 7 2006-10-26 10:26:41Z ah $

#include "config.h"
#ifdef WITH_PYTHON
#include <Python.h>
#define __PYTHON_H_INCLUDED__
#endif

#include "Templates.h"
#include "Net.h"
#include "Node.h"
#include "Saver.h"
#include <algorithm>
//#include <assert.h>
#include "SpecFun.h"

// Step-size limits for the Newton iteration in VarNewton(): the loop
// stops once both the mean step and the variance step are below MINSTEP
// in magnitude, and steps larger than MAXSTEP are truncated to MAXSTEP.
const double MINSTEP = 1e-4;
const double MAXSTEP = 4;
// Tolerance used by VarNewton() when checking that the cost did not grow.
const double EPSILON = 1.5e-8;  // ~ sqrt(eps) = 1.4901e-08
// NOTE(review): not referenced in this part of the file.
const double NL_EPSILON = 1e-10;
const double PI = 3.14159265358979323846;
// 0.5 * log(2*pi); added to the cost of a clamped Gaussian.
const double _5LOG2PI = 0.5 * log(2*PI);

// Threshold on mean / sqrt(var): RectifiedGaussian::UpdateExpectations()
// uses the Erfcx-based formulas above it and an approximation otherwise.
const double RECTLIMIT = -30;
// Set to 1 to compile in the additional Erfc terms of the rectified
// Gaussian cost and gradients.
#define RECTIFIED_BETTER_APPROX 0

// NOTE(review): the two limits below are not referenced in this part of
// the file.
const double CATEGLIMIT = 1e-40;
const double GAUSSRECTLIMIT = 1e-40;

// Signum: -1 for negative input, 1 for positive input, 0 otherwise
// (which includes zero and NaN, since both comparisons are then false).
inline double sign(double d)
{
  return (d < 0) ? -1.0 : ((d > 0) ? 1.0 : 0.0);
}


// abstract class Node

// Create a node and register it with the owning net.  When the requested
// label is already taken in the net, a replacement label is obtained
// from Net::GetNextLabel().
Node::Node(Net *ptr, Label mylabel)
{
  persist = 0;
  dying = false;
  net = ptr;

  if (net->GetNode(mylabel)) {
    // Label collision: ask the net for the next free variant
    label = net->GetNextLabel(mylabel);
  } else {
    label = mylabel;
  }

  net->AddNode(this, label);
  timetype = 0;
}

// Mark this node as dying and tell the net, every parent and every child
// about it.  A non-zero verbose prints a message to cout and is passed
// on to the notified nodes.
void Node::Die(int verbose)
{
  if (verbose) {
    cout << "Node " << GetLabel() << " of type " << GetType()
         << " is dying" << endl;
  }

  // Set the flag before notifying anyone: NotifyDeath() on a dying node
  // returns immediately.
  dying = true;

  net->NotifyDeath(this);

  for (size_t p = 0; p < NumParents(); ++p) {
    GetParent(p)->NotifyDeath(this, verbose);
  }
  for (size_t c = 0; c < children.size(); ++c) {
    children[c]->NotifyDeath(this, verbose);
  }
}

// Called when node ptr is dying.  Removes ptr from this node's parents
// and children and then, depending on the bits of persist, possibly
// lets this node die as well:
//   1 - die if a parent was removed
//   2 - die if no parents are left
//   4 - die if no children are left
//   8 - if exactly one parent is left, hand all children over to that
//       parent and die
// The tests form a single if / else-if chain, so only the first matching
// rule is applied.  If this node survives the loss of a parent it is
// outdated via Outdate(ptr).
void Node::NotifyDeath(Node *ptr, int verbose)
{
  int par, child;
  NodeIterator it;

  // A node that is already dying ignores further notifications
  if (dying) return;

  // par = number of parent slots that referred to ptr
  par = RemoveParent(ptr);

  // child = number of child entries that referred to ptr
  it = remove(children.begin(), children.end(), ptr);
  child = children.end() - it;
  children.erase(it, children.end());

  if (par || child)
    if ((1 & persist) && par)
      Die(verbose);
    else if ((2 & persist) && NumParents() == 0)
      Die(verbose);
    else if ((4 & persist) && children.empty())
      Die(verbose);
    else if ((8 & persist) && NumParents() == 1) {
      // Splice this node out: each child gets the remaining parent in
      // place of this node.
      while (children.size()) {
	children[0]->ReplacePtr(this, GetParent(0));
	GetParent(0)->AddChild(children[0]);
	children.erase(children.begin());
      }
      Die(verbose);
    }

  if (par && !dying)
    Outdate(ptr);
}

// Substitute newptr for every reference to oldptr held by this node,
// both among the parents (via ParReplacePtr) and among the children.
// If a parent was replaced the node is outdated with respect to newptr.
void Node::ReplacePtr(Node *oldptr, Node *newptr)
{
  const bool parent_replaced = ParReplacePtr(oldptr, newptr);
  if (parent_replaced) {
    Outdate(newptr);
  }

  for (size_t k = 0; k < children.size(); ++k) {
    if (children[k] == oldptr) {
      children[k] = newptr;
    }
  }
}

// Assign time type tt to this node and spread it to all parents and
// children.  Nothing happens if the node already has a non-zero time
// type or if tt is zero; that check also ends the recursion.
void Node::NotifyTimeType(int tt, int verbose)
{
  if (timetype != 0 || tt == 0) {
    return;
  }

  timetype = tt;

  if (verbose) {
    cout << "Node " << GetLabel() << " of type " << GetType()
         << ": time type = " << tt << endl;
  }

  for (size_t p = 0; p < NumParents(); ++p) {
    GetParent(p)->NotifyTimeType(tt, verbose);
  }
  for (size_t c = 0; c < children.size(); ++c) {
    children[c]->NotifyTimeType(tt, verbose);
  }
}

// Connect ptr as a parent of this node.
//   ptr    - the new parent; must not be dying
//   really - when true the parent is also stored through ReallyAddParent()
// Throws StructureException if ptr is dying.  Time types are reconciled
// afterwards: whichever of the two nodes lacks a time type while the
// other has one is notified with time type 1.
void Node::AddParent(Node *ptr, bool really) {
  if (ptr->GetDying()) {
    ostringstream err;
    err << "Parent " << ptr->GetLabel() << " for Node "
        << GetLabel() << " is dying";
    throw StructureException(err.str());
  }

  if (really) {
    ReallyAddParent(ptr);
  }
  ptr->AddChild(this);

  // Parent is time dependent, this node is not yet
  if (ptr->TimeType() && !timetype) {
    NotifyTimeType(1);
  }

  // This node is time dependent, the parent is not yet
  if (!ptr->TimeType() && timetype) {
    if (net->GetDebugLevel() > -1) {
      cerr << "Warning: changing parent " << ptr->GetLabel()
           << " time type due to" << endl
           << "addition of a new child " << label << endl;
    }
    ptr->NotifyTimeType(1);
  }
}

// Clear the single parent slot if it refers to ptr.
// Returns the number of parent references removed (0 or 1).
int UniParNode::RemoveParent(const Node *ptr)
{
  if (parent != ptr)
    return 0;

  parent = 0;
  return 1;
}

// Replace every parent reference to oldptr with newptr.
// Returns true if at least one parent slot was changed.
bool BiParNode::ParReplacePtr(const Node *oldptr, Node *newptr)
{
  bool replaced = false;

  for (size_t slot = 0; slot < NumParents(); ++slot) {
    if (parents[slot] != oldptr)
      continue;
    parents[slot] = newptr;
    replaced = true;
  }

  return replaced;
}

// Index (0 or 1) of the parent slot holding ptr, or -1 if neither slot
// does.  Slot 0 is checked first.
int BiParNode::ParIdentity(const Node *ptr)
{
  for (int slot = 0; slot < 2; ++slot) {
    if (parents[slot] == ptr)
      return slot;
  }
  return -1;
}

// Store ptr in the first parent slot when that slot is empty, otherwise
// in the second slot (overwriting whatever it holds).
void BiParNode::ReallyAddParent(Node *ptr)
{
  if (parents[0] == 0) {
    parents[0] = ptr;
  } else {
    parents[1] = ptr;
  }
}

// Remove every parent reference to ptr and return how many were removed.
// Slot 1 is handled first; when slot 0 is removed the remaining parent
// is moved down so that a single parent always sits in slot 0.
int BiParNode::RemoveParent(const Node *ptr)
{
  int removed = 0;

  if (parents[1] == ptr) {
    parents[1] = 0;
    ++removed;
  }

  if (parents[0] == ptr) {
    // Shift the second parent into the first slot
    parents[0] = parents[1];
    parents[1] = 0;
    ++removed;
  }

  return removed;
}

// Replace every parent reference to oldptr with newptr, keeping the
// pointer-to-index map parent_inds in step with the parents vector.
// Returns true if at least one parent was replaced.
bool NParNode::ParReplacePtr(const Node *oldptr, Node *newptr)
{
  bool replaced = false;

  for (size_t pos = 0; pos < NumParents(); ++pos) {
    if (parents[pos] != oldptr)
      continue;

    parents[pos] = newptr;
    parent_inds.erase(oldptr);
    parent_inds[newptr] = pos;
    replaced = true;
  }

  return replaced;
}

int NParNode::ParIdentity(const Node *ptr)
{
  map<const Node *, int>::iterator p = parent_inds.find(ptr);

  // The node was found
  if (p != parent_inds.end())
    return p->second;
  else {
    return -1;
  }
}

int NParNode::RemoveParent(const Node *ptr)
{
  int par;
  NodeIterator it;

  it = remove(parents.begin(), parents.end(), ptr);
  par = parents.end() - it;
  parents.erase(it, parents.end());

  parent_inds.clear();
  for (size_t i=0; i < parents.size(); i++)
    parent_inds[parents[i]] = i;

  return par;
}

void Node::CheckParent(size_t parnum, partype_e partype)
{
  DVH tmp_dvh;
  DSSet tmp_dss;
  DD *tmp_dd;
  VDDH tmp_vddh;
  Node *ptr = GetParent(parnum);
  ostringstream msg;

  switch(partype) {
  case REAL_MV:
    if (!ParReal(parnum, tmp_dss, DFlags(true, true))) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a scalar parent giving mean and variance)";
      throw StructureException(msg.str());
    }
    break;
  case REAL_ME:
    if (!ParReal(parnum, tmp_dss, DFlags(true, false, true))) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a scalar parent giving mean and exp)";
      throw StructureException(msg.str());
    }
    break;
  case REAL_M:
    if (!ParReal(parnum, tmp_dss, DFlags(true))) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a scalar parent giving mean)";
      throw StructureException(msg.str());
    }
    break;
  case REALV_MV:
    if (!ParRealV(parnum, tmp_dvh, DFlags(true, true))) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a vector parent giving mean and variance)";
      throw StructureException(msg.str());
    }
    break;
  case REALV_ME:
    if (!ParRealV(parnum, tmp_dvh, DFlags(true, false, true))) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a vector parent giving mean and exp)";
      throw StructureException(msg.str());
    }
    break;
  case REALV_M:
    if (!ParRealV(parnum, tmp_dvh, DFlags(true))) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a vector parent giving mean)";
      throw StructureException(msg.str());
    }
    break;
  case DISCRETE:
    if (!ParDiscrete(parnum, tmp_dd)) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a discrete parent)";
      throw StructureException(msg.str());
    }
    break;
  case DISCRETEV:
    if (!ParDiscreteV(parnum, tmp_vddh)) {
      msg << "Wrong type of parents in " << GetType() << " Node "
	  << label << endl;
      msg << " Parent #" << parnum << " " << ptr->GetLabel() << ":"
	  << ptr->GetType() << endl;
      msg << " (Expected a discrete vector parent)";
      throw StructureException(msg.str());
    }
    break;
  }
}

// Let every child add its scalar gradient contribution to val.
void Node::ChildGradReal(DSSet &val)
{
  for (size_t c = 0; c < children.size(); ++c) {
    children[c]->GradReal(val, this);
  }
}

// Let every child add its vector gradient contribution to val.
void Node::ChildGradRealV(DVSet &val)
{
  for (size_t c = 0; c < children.size(); ++c) {
    children[c]->GradRealV(val, this);
  }
}

// Let every child add its discrete gradient contribution to val.
void Node::ChildGradDiscrete(DD &val)
{
  for (size_t c = 0; c < children.size(); ++c) {
    children[c]->GradDiscrete(val, this);
  }
}

// Let every child add its discrete vector gradient contribution to val.
void Node::ChildGradDiscreteV(VDD &val)
{
  for (size_t c = 0; c < children.size(); ++c) {
    children[c]->GradDiscreteV(val, this);
  }
}

// Call Outdate(this) on every child of this node.
void Node::OutdateChild()
{
  for (size_t c = 0; c < children.size(); ++c) {
    children[c]->Outdate(this);
  }
}

// Write the state common to all nodes to saver: the label, the labels of
// all parents and children (each as an enumerated container) and the
// persist, timetype and dying fields.
void Node::Save(NetSaver *saver)
{
  saver->SetNamedLabel("label", label);

  // Parent labels, in parent-index order
  saver->StartEnumCont(NumParents(), "parents");
  for (size_t p = 0; p < NumParents(); ++p) {
    saver->SetLabel(GetParent(p)->GetLabel());
  }
  saver->FinishEnumCont("parents");

  // Child labels, in the order of the children vector
  saver->StartEnumCont(children.size(), "children");
  for (size_t c = 0; c < children.size(); ++c) {
    saver->SetLabel(children[c]->GetLabel());
  }
  saver->FinishEnumCont("children");

  saver->SetNamedInt("persist", persist);
  saver->SetNamedInt("timetype", timetype);
  saver->SetNamedBool("dying", dying);
}

// class Constant : public Node

// Save the constant value as "cval", followed by the common Node state.
void Constant::Save(NetSaver *saver)
{
  saver->SetNamedDouble("cval", cval);
  Node::Save(saver);
}

// class ConstantV : public Node

// Build a constant vector node from v.  The mean is a copy of v, ex holds
// exp() of each element, and var is merely resized to the same length
// (its elements are whatever resize() produces).
ConstantV::ConstantV(Net *net, Label label, DV v) : Node(net, label)
{
  const size_t len = v.size();

  myval.mean = v;
  myval.var.resize(len);
  myval.ex.resize(len);

  for (size_t k = 0; k < len; ++k)
    myval.ex[k] = exp(v[k]);
}


// Save the stored vector value as "myval", followed by the common Node
// state.
void ConstantV::Save(NetSaver *saver)
{
  saver->SetNamedDVSet("myval", myval);
  Node::Save(saver);
}

// abstract class Function : public Node

// Construct a function node with up to two parents.  Null parent
// pointers are skipped.  The parents are registered with really == false,
// i.e. without calling ReallyAddParent() from here.  All cached values
// start out as not up to date.
Function::Function(Net *ptr, Label label, Node *n1, Node *n2) :
  Node(ptr, label)
{
  if (n1)
    AddParent(n1, false);
  if (n2)
    AddParent(n2, false);
  uptodate = DFlags(false,false,false);
  persist = 1 | 4; // Functions usually need all parents and at least one child
}

// Save the uptodate flags (only when the saver asks for function values)
// and then the common Node state.
void Function::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDFlags("uptodate", uptodate);
  }
  Node::Save(saver);
}

// class Prod : public Function : public Node

// Provide the mean and/or variance of the product of the two parents.
// Cached values are recomputed only for requested fields that are not
// up to date.  Returns false if a parent cannot provide what is needed
// or if a field other than mean and var was requested.
bool Prod::GetReal(DSSet &val, DFlags req)
{
  const bool refresh_mean = req.mean && !uptodate.mean;
  const bool refresh_var = req.var && !uptodate.var;

  if (refresh_mean || refresh_var) {
    DSSet a, b;
    // The parent means are always needed, the variances only for var
    if (!ParReal(0, a, DFlags(true, refresh_var)))
      return false;
    if (!ParReal(1, b, DFlags(true, refresh_var)))
      return false;

    if (refresh_mean) {
      mean = a.mean * b.mean;
      uptodate.mean = true;
    }
    if (refresh_var) {
      var = (Sqr(a.mean) + a.var) * b.var + a.var * Sqr(b.mean);
      uptodate.var = true;
    }
  }

  // Hand out what was asked for and tick it off in req
  if (req.mean) {
    val.mean = mean;
    req.mean = false;
  }
  if (req.var) {
    val.var = var;
    req.var = false;
  }
  return req.AllFalse();
}

// Add to val the gradient of the children's cost with respect to the
// mean and variance of parent ptr.  The gradient with respect to this
// node is collected from the children and pushed through the product.
void Prod::GradReal(DSSet &val, const Node *ptr)
{
  const int self_idx = ParIdentity(ptr);
  const int other_idx = 1 - self_idx;
  DSSet grad, self, other;

  ChildGradReal(grad);
  ParReal(self_idx, self, DFlags(true));
  ParReal(other_idx, other, DFlags(true, true));

  val.mean += grad.mean * other.mean + 2 * grad.var * self.mean * other.var;
  val.var += grad.var * (Sqr(other.mean) + other.var);
}

// Save the cached mean and variance (only when the saver asks for
// function values) and then the Function state.
void Prod::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDouble("mean", mean);
    saver->SetNamedDouble("var", var);
  }
  Function::Save(saver);
}


// class Sum2 : public Function : public Node

// Provide mean, variance and/or exp-expectation of the sum of the two
// parents.  Means and variances add; the ex fields multiply.  Only the
// requested fields that are out of date are fetched from the parents and
// recomputed.  Returns false if a parent cannot provide a needed field.
bool Sum2::GetReal(DSSet &val, DFlags req)
{
  const bool refresh_mean = req.mean && !uptodate.mean;
  const bool refresh_var = req.var && !uptodate.var;
  const bool refresh_ex = req.ex && !uptodate.ex;

  if (refresh_mean || refresh_var || refresh_ex) {
    DSSet a, b;
    if (!ParReal(0, a, DFlags(refresh_mean, refresh_var, refresh_ex)))
      return false;
    if (!ParReal(1, b, DFlags(refresh_mean, refresh_var, refresh_ex)))
      return false;

    if (refresh_mean) {
      myval.mean = a.mean + b.mean;
      uptodate.mean = true;
    }
    if (refresh_var) {
      myval.var = a.var + b.var;
      uptodate.var = true;
    }
    if (refresh_ex) {
      myval.ex = a.ex * b.ex;
      uptodate.ex = true;
    }
  }

  // Hand out what was asked for and tick it off in req
  if (req.mean) {
    val.mean = myval.mean;
    req.mean = false;
  }
  if (req.var) {
    val.var = myval.var;
    req.var = false;
  }
  if (req.ex) {
    val.ex = myval.ex;
    req.ex = false;
  }
  return req.AllFalse();
}

// Add to val the gradient with respect to parent ptr.  Mean and variance
// gradients pass through unchanged; the ex gradient is scaled by the ex
// value of the other parent, which is fetched only when that gradient is
// non-zero.
void Sum2::GradReal(DSSet &val, const Node *ptr)
{
  const int self_idx = ParIdentity(ptr);
  DSSet grad;

  ChildGradReal(grad);

  val.mean += grad.mean;
  val.var += grad.var;

  if (grad.ex != 0) {
    DSSet other;
    ParReal(1 - self_idx, other, DFlags(false, false, true));
    val.ex += grad.ex * other.ex;
  }
}

// Save the cached value "myval" (only when the saver asks for function
// values) and then the Function state.
void Sum2::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDSSet("myval", myval);
  }
  Function::Save(saver);
}

// class SumN : public Function : public Node

// Provide the mean and/or variance of the sum of all parents.  Support
// for the ex field is disabled (see the commented-out lines), so a
// request that includes ex makes the function return false.  Returns
// false as well when a parent cannot provide a needed field; in that
// case myval may already have been partially accumulated.
bool SumN::GetReal(DSSet &val, DFlags req)
{
  bool needm = req.mean && !uptodate.mean;
  bool needv = req.var && !uptodate.var;
  bool neede = false; //req.ex && !uptodate.ex;

  if (needm || needv || neede) {
    // Recompute the needed sums from scratch
    if (needm) myval.mean = 0;
    if (needv) myval.var = 0;
    //    if (neede) myval.ex = 1;
    for (size_t i = 0; i<NumParents(); i++) {
      DSSet p;
      if (!ParReal(i, p, DFlags(needm, needv, neede)))
	return false;
      if (needm) {
	myval.mean += p.mean;
      }
      if (needv) {
	myval.var += p.var;
      }
      //      if (neede) {
      //	myval.ex *= p.ex;
      //      }
      // NOTE(review): the whole of p is stored, including fields that
      // were not requested from the parent in this call - confirm that
      // this cannot leave stale entries in parentval.
      if (keepupdated) {
	parentval[i] = p;
      }
    }
    if (needm) uptodate.mean = true;
    if (needv) uptodate.var = true;
    //    if (neede) uptodate.ex = true;
  }
  if (req.mean) {val.mean = myval.mean; req.mean = false;}
  if (req.var) {val.var = myval.var; req.var = false;}
  //  if (req.ex) {val.ex = myval.ex; req.ex = false;}
  return req.AllFalse();
}

// Add to val the gradient with respect to a parent.  For a sum the mean
// and variance gradients collected from the children pass through
// unchanged, so the identity of the parent (ptr) is not needed.  The ex
// gradient is not propagated; that code is commented out.
void SumN::GradReal(DSSet &val, const Node *ptr)
{
  //int ide = ParIdentity(ptr);
  DSSet grad;

  ChildGradReal(grad);
  val.mean += grad.mean;
  val.var += grad.var;
  //  if (grad.ex) {
  //    DSSet p, tempself;
  //    if (!uptodate.ex) {
  //      GetReal(tempself, DFlags(false,false,true));
  //    }
  //    ParReal(ide, p, DFlags(false, false, true));
  //    val.ex += grad.ex * myval.ex / p.ex;
  //  }
}


// Add n as a further term of the sum.  All cached values are invalidated
// and the children are outdated.  Always returns true.
bool SumN::AddParent(Node *n)
{
  // The parents are checked in GetReal
  Node::AddParent(n, true);
  int ide = ParIdentity(n);
  if (keepupdated) {
    // Make room for the cached value of the new parent.
    // NOTE(review): this assumes ide is the index of the last parent;
    // resize() would shrink parentval if ide were smaller - confirm.
    parentval.resize(ide+1,DSSet(0.0,0.0,1.0));
    parentval[ide] = DSSet(0.0,0.0,1.0); // is this done already above?
  }
  uptodate = DFlags(false,false,false);
  OutdateChild();
  return true;
}

// Switch incremental updating on or off.  When it is switched on, the
// sum and the per-parent cache parentval are recomputed immediately from
// the current parent values and the mean and variance are marked as up
// to date.  Switching it off only changes the flag.
void SumN::SetKeepUpdated(const bool _keepupdated)
{
  keepupdated = _keepupdated;
  if (!keepupdated)
    return;

  // AddParent() grows parentval only while keepupdated is set, so
  // parents added before this call may have no cache slot yet.  Make
  // sure every parent has one before the slots are written below.
  if (parentval.size() < NumParents())
    parentval.resize(NumParents(), DSSet(0.0,0.0,1.0));

  // update now
  myval.mean = 0.0;
  myval.var = 0.0;
  //    myval.ex = 1.0;
  for (size_t i = 0; i<NumParents(); i++) {
    DSSet p;
    ParReal(i, p, DFlags(true,true,false)); //.ex
    parentval[i] = p;
    myval.mean += p.mean;
    myval.var += p.var;
    //      myval.ex *= p.ex;
  }
  uptodate.mean = true;
  uptodate.var = true;
  //    uptodate.ex = true;
}

void SumN::Outdate(const Node *ptr) 
{
  if (uptodate.mean || uptodate.var) { //.ex
    if (keepupdated) {
      int ide = ParIdentity(ptr);
      DSSet p;
      ParReal(ide, p, DFlags(true, true, false)); //.ex
      myval.mean += p.mean - parentval[ide].mean;
      myval.var += p.var - parentval[ide].var;
      if (myval.var<0) {
	uptodate.var = false;
      }
      //      myval.ex *= p.ex / parentval[ide].ex;
      parentval[ide] = p;
    } else {
      uptodate = DFlags(false,false,false); 
    }
  }
  OutdateChild();
}


// Save the cached value "myval" (only when the saver asks for function
// values) and then the Function state.
void SumN::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDSSet("myval", myval);
  }
  Function::Save(saver);
}

// class Relay : public Function : public Node

// A Relay saves nothing of its own here; only the Function state.
void Relay::Save(NetSaver *saver)
{
  Function::Save(saver);
}


// abstract class Variable : public Node

// Construct a variable with up to two parents and register it with the
// net as a variable.  Null parent pointers are skipped.  The parents are
// registered with really == false, i.e. without calling
// ReallyAddParent() from here.  The variable starts unclamped with its
// cost marked as not up to date.
Variable::Variable(Net *ptr, Label label, Node *n1, Node *n2) : 
  Node(ptr, label)
{
  persist = 1; // Usually variables need all parents to survive
  hookeflags = 0;
  net->AddVariable(this, label);
  clamped = false;
  costuptodate = false;
  if (n1)
    AddParent(n1, false);
  if (n2)
    AddParent(n2, false);
}

// Save the clamped, costuptodate and hookeflags fields, followed by the
// common Node state.
void Variable::Save(NetSaver *saver)
{
  saver->SetNamedBool("clamped", clamped);
  saver->SetNamedBool("costuptodate", costuptodate);
  saver->SetNamedInt("hookeflags", hookeflags);
  Node::Save(saver);
}

// Ask the concrete variable to save its state.  Clamped variables are
// skipped.  If MySaveState() reports failure a note is printed to cerr,
// but only when the net's debug level is above 5.
void Variable::SaveState() {
  if (clamped)
    return;
  if (MySaveState())
    return;
  if (net->GetDebugLevel() > 5)
    cerr << "SaveState not supported" << endl;
}

// Ask the concrete variable to save its step.  Clamped variables are
// skipped.  If MySaveStep() reports failure a note is printed to cerr,
// but only when the net's debug level is above 5.
void Variable::SaveStep() {
  if (clamped)
    return;
  if (MySaveStep())
    return;
  if (net->GetDebugLevel() > 5)
    cerr << "SaveStep not supported" << endl;
}

// Repeat the saved step scaled by alpha and outdate the children.
// Clamped variables are left untouched.
void Variable::RepeatStep(double alpha) {
  if (clamped)
    return;
  MyRepeatStep(alpha);
  OutdateChild();
}

// Ask the concrete variable to save the state reached by repeating the
// saved step with factor alpha.  Clamped variables are skipped.  If
// MySaveRepeatedState() reports failure a note is printed to cerr, but
// only when the net's debug level is above 5.
void Variable::SaveRepeatedState(double alpha) {
  if (clamped)
    return;
  if (MySaveRepeatedState(alpha))
    return;
  if (net->GetDebugLevel() > 5)
    cerr << label << ": SaveRepeatedState not supported" << endl;
}

// Ask the concrete variable to discard its saved state and step.
// Clamped variables are skipped.  If MyClearStateAndStep() reports
// failure a note is printed to cerr, but only when the net's debug level
// is above 5.
void Variable::ClearStateAndStep() {
  if (clamped)
    return;
  if (MyClearStateAndStep())
    return;
  if (net->GetDebugLevel() > 5)
    cerr << "ClearStateAndStep not supported" << endl;
}

// class Gaussian : public Variable : public Node

// Construct a Gaussian variable with mean parent m and variance parent v.
// Parent 0 must provide mean and variance, parent 1 mean and exp;
// CheckParent() throws StructureException otherwise.  The posterior is
// initialised to mean 0 and variance 1.
Gaussian::Gaussian(Net *net, Label label, Node *m, Node *v) : 
  Variable(net, label, m, v), BiParNode(m, v)
{
  sstate = 0; sstep = 0;
  cost = 0;
    
  CheckParent(0, REAL_MV);
  CheckParent(1, REAL_ME);

  // The parent values are fetched, but only the commented-out
  // initialisation below would make use of them.
  DSSet p0, p1;
  ParReal(0, p0, DFlags(true));
  ParReal(1, p1, DFlags(false, false, true));

//  myval.mean = p0.mean;
//  myval.var = 1/p1.ex;

  myval.mean = 0.0;
  myval.var = 1.0;

  exuptodate = false;
  costuptodate = false;
}

// Write the posterior into *state as the two-element vector
// [mean, variance].  Only t == 0 is accepted (checked with BBASSERT2).
void Gaussian::GetState(DV *state, size_t t = 0)
{
  BBASSERT2(t == 0);

  DV &out = *state;
  out.resize(2);
  out[0] = myval.mean;
  out[1] = myval.var;
}

// Set the posterior from the two-element vector [mean, variance] in
// *state.  Only t == 0 is accepted.  The cached cost and exp-expectation
// are invalidated and the children are outdated.
void Gaussian::SetState(DV *state, size_t t = 0)
{
  BBASSERT2(t == 0);
  BBASSERT2(state->size() == 2);

  const DV &in = *state;
  myval.mean = in[0];
  myval.var = in[1];

  exuptodate = false;
  costuptodate = false;

  OutdateChild();
}

double Gaussian::Cost()
{
  if (!clamped && children.empty())
    return 0;
  if (!costuptodate) {
    DSSet p0, p1;
    ParReal(0, p0, DFlags(true, true));
    ParReal(1, p1, DFlags(true, false, true));
    if (clamped) {
      //assert(myval.var == 0);
      cost = ((Sqr(myval.mean - p0.mean) + p0.var + myval.var)
	      * p1.ex - p1.mean) / 2 + _5LOG2PI;
    }
    else
      cost = ((Sqr(myval.mean - p0.mean) + p0.var + myval.var) *
	      p1.ex - p1.mean - log(myval.var) - 1) / 2;
    costuptodate = true;
  }
  return cost;
}

// Provide the posterior mean, variance and/or exp-expectation.  The ex
// value exp(mean + var/2) is cached and recomputed only when requested
// while out of date.  Returns whether every requested field was served.
bool Gaussian::GetReal(DSSet &val, DFlags req)
{
  if (req.ex && !exuptodate) {
    myval.ex = exp(myval.mean+myval.var/2);
    exuptodate = true;
  }

  if (req.mean) {
    val.mean = myval.mean;
    req.mean = false;
  }
  if (req.var) {
    val.var = myval.var;
    req.var = false;
  }
  if (req.ex) {
    val.ex = myval.ex;
    req.ex = false;
  }
  return req.AllFalse();
}

// Add to val the gradient of this variable's cost with respect to parent
// ptr.  An unclamped variable without children contributes nothing.
// Parent identity 0 is treated as the mean parent; any non-zero identity
// is treated as the variance parent.
void Gaussian::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  if (ParIdentity(ptr) != 0) { // ParMean(1)
    DSSet mpar;
    ParReal(0, mpar, DFlags(true, true));
    val.mean -= 0.5;
    // The same expression applies whether or not the variable is clamped
    val.ex += (Sqr(myval.mean - mpar.mean) + mpar.var + myval.var) / 2;
    return;
  }

  // ParMean(0)
  DSSet mpar, vpar;
  ParReal(0, mpar, DFlags(true));
  ParReal(1, vpar, DFlags(false, false, true));
  val.mean += (mpar.mean - myval.mean) * vpar.ex;
  val.var += vpar.ex / 2;
}

// Save the posterior, the cached cost and the exuptodate flag, then the
// saved state and step if they exist, and finally the Variable state.
void Gaussian::Save(NetSaver *saver)
{
  saver->SetNamedDSSet("myval", myval);
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedBool("exuptodate", exuptodate);
  if (sstate)
    saver->SetNamedDSSet("sstate", *sstate);
  if (sstep)
    saver->SetNamedDSSet("sstep", *sstep);
  Variable::Save(saver);
}


// Update a Gaussian posterior (mean, var) in place.
//   mean, var - current posterior mean and variance; overwritten with
//               the new values
//   gme, gva  - coefficients of the linear and quadratic terms of the
//               cost (see the objective written out below)
//   gex       - coefficient of the exp(mean + var/2) term
//   label     - printed in front of diagnostic messages on cerr
// When gex is zero the minimum is available in closed form.  Otherwise a
// Newton-type iteration is run until both steps are at most MINSTEP in
// magnitude or 100 rounds have been done; a message is printed to cerr
// if the round limit is hit or if the final cost exceeds the initial
// cost by more than EPSILON.
void VarNewton(double &mean, double &var, double gme, double gva,
	       double gex, Label label)
{
  if (!gex) {
    var = 0.5 / gva;
    mean -= var * gme;
  }
  else {
    // Solve the minimum of gex * exp(mean + var/2) + gme * mean +
    // gva * [(mean - current_mean)^2 + var] - 0.5 * log(var)
    double oldm = mean, oldv = var, mstep, vstep, coef;
    // Cost at the starting point, where mean == oldm
    double newc, oldc = gex * exp(mean + var/2) + gme * mean +
      gva * var - 0.5 * log(var);

    int i = 0;
    do {
      mstep = -(gme + 2 * gva * (mean - oldm) + gex * exp(mean+var/2)) /
	(2 * gva + gex * exp(mean+var/2));
      // Only steps in the positive direction are limited by MAXSTEP
      mean += (mstep > MAXSTEP) ? MAXSTEP : mstep;
      vstep = 1 / (2 * gva + gex * exp(mean + var/2)) - var;
      // Damp the variance step when coef is positive
      coef = var * (0.5 - gva * var);
      if (coef > 0)
	vstep /= 1 + coef;
      var += (vstep > MAXSTEP) ? MAXSTEP : vstep;
      if (++i >= 100) {
	cerr << label << " VarNewton: M=" << oldm << "; V=" << oldv << "; GEX="
	     << gex << "; GVA=" << gva << "; GME=" << gme
	     << ": mstep = " << mstep << ", vstep = " << vstep << '\n';
	// Zero the steps so that the loop condition below ends the loop
	mstep = vstep = 0;
      }
    }
    while (fabs(mstep) > MINSTEP || fabs(vstep) > MINSTEP);
    newc = gex * exp(mean + var/2) + gme * mean +
      gva * (Sqr(mean - oldm) + var) - 0.5 * log(var);
    if (newc > oldc + EPSILON)
      cerr << label << " VarNewton: M=" << oldm << "; V=" << oldv << "; GEX="
	   << gex << "; GVA=" << gva << "; GME=" << gme
	   << ": diff = " << newc - oldc << '\n';
  }
}

// A scalar variable has nothing to update partially: the indices are
// ignored and a full update is done.
void Gaussian::MyPartialUpdate(IntV *indices)
{
  MyUpdate();
}

void Gaussian::MyUpdate()
{
  if (NumChildren() == 0) {
    return;
  }

  DSSet grad, p0, p1;
  ChildGradReal(grad);
  ParReal(0, p0, DFlags(true));
  ParReal(1, p1, DFlags(false, false, true));
  VarNewton(myval.mean, myval.var, (myval.mean - p0.mean) * p1.ex + grad.mean,
	    p1.ex/2 + grad.var, grad.ex, label);
  exuptodate = false; costuptodate = false;
}

// Fix the value to m with zero variance.  The cached exp-expectation is
// invalidated.  Always returns true.
bool Gaussian::MyClamp(double m)
{
  myval.mean = m;
  myval.var = 0;
  exuptodate = false;
  return true;
}

// Fix the value to mean m and variance v.  The cached exp-expectation is
// invalidated.  Always returns true.
bool Gaussian::MyClamp(double m, double v)
{
  myval.mean = m;
  myval.var = v;
  exuptodate = false;
  return true;
}

// Remember the current posterior in sstate (allocated on first use).
// The variance is stored as its logarithm.  Always returns true.
bool Gaussian::MySaveState()
{
  if (sstate == 0) {
    sstate = new DSSet;
  }

  sstate->var  = log(myval.var);
  sstate->mean = myval.mean;

  return true;
}

// Record in sstep (allocated on first use) the step from the saved state
// to the current posterior.  Returns false if no state has been saved.
// For hookeflags 0 and 1 the mean step is a difference; for 2 and 3 it
// is the ratio current/saved (or -1 when the saved mean is zero).  The
// variance step is always a difference of logarithms.  For any other
// hookeflags value sstep is left unchanged.
bool Gaussian::MySaveStep()
{
  if (!sstate)
    return false;
  if (!sstep)
    sstep = new DSSet;

  switch (hookeflags) {
  case 0:
  case 1:
    sstep->mean = myval.mean - sstate->mean;
    break;
  case 2:
  case 3:
    sstep->mean = (sstate->mean == 0) ? -1.0 : myval.mean / sstate->mean;
    break;
  default:
    return true;
  }

  sstep->var = log(myval.var) - sstate->var;
  return true;
}

// Move the saved state forward by alpha times the saved step, i.e. to
// the point that MyRepeatStep(alpha) would reach.  Returns false (after
// a message on cerr) if no state or step has been saved.
//
// The handling of hookeflags mirrors MySaveStep() and MyRepeatStep():
//   0 - additive mean step, log-variance step
//   1 - additive mean step, variance untouched
//   2 - multiplicative mean step (only if the saved ratio is positive),
//       log-variance step
//   3 - multiplicative mean step (only if the saved ratio is positive),
//       variance untouched
// Other values leave the saved state unchanged.
bool Gaussian::MySaveRepeatedState(double alpha)
{
  if (!sstate || !sstep) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return false;
  }
  switch (hookeflags) {
  case 0:
    sstate->mean = sstate->mean + alpha * sstep->mean;
    sstate->var = sstate->var + alpha * sstep->var;
    break;
  case 1:
    sstate->mean = sstate->mean + alpha * sstep->mean;
    break;
  case 2:
    // sstep->mean is a ratio here; raise it to the power alpha
    if (sstep->mean > 0)
      sstate->mean = sstate->mean * exp(alpha * log(sstep->mean));
    sstate->var = sstate->var + alpha * sstep->var;
    break;
  case 3:
    if (sstep->mean > 0)
      sstate->mean = sstate->mean * exp(alpha * log(sstep->mean));
    break;
  }
  return true;
}

// Set the posterior to the saved state advanced by alpha times the saved
// step.  Prints a message to cerr and does nothing else if no state or
// step has been saved.  The meaning of hookeflags:
//   0 - additive mean step, log-variance step
//   1 - additive mean step, variance left as it is
//   2 - multiplicative mean step, log-variance step
//   3 - multiplicative mean step, variance left as it is
// A multiplicative mean step is applied only when the saved ratio is
// positive.  The cached cost and exp-expectation are invalidated.
void Gaussian::MyRepeatStep(double alpha)
{
  if (!sstate || !sstep) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return;
  }
  switch (hookeflags) {
  case 0:
    myval.mean = sstate->mean + alpha * sstep->mean;
    // sstate->var and sstep->var are in the log domain
    myval.var = exp(sstate->var + alpha * sstep->var);
    break;
  case 1:
    myval.mean = sstate->mean + alpha * sstep->mean;
    break;
  case 2:
    // sstep->mean is a ratio; exp(alpha*log(r)) is r to the power alpha
    if (sstep->mean > 0)
      myval.mean = sstate->mean * exp(alpha * log(sstep->mean));
    myval.var = exp(sstate->var + alpha * sstep->var);
    break;
  case 3:
    if (sstep->mean > 0)
      myval.mean = sstate->mean * exp(alpha * log(sstep->mean));
    break;
  }
  exuptodate = false; costuptodate = false;
}

// Free the saved state and step, if any, and reset both pointers to
// null.  Always returns true.
bool Gaussian::MyClearStateAndStep()
{
  // delete on a null pointer has no effect, so no checks are needed
  delete sstate;
  sstate = 0;

  delete sstep;
  sstep = 0;

  return true;
}

/* class RectifiedGaussian */

// Construct a rectified Gaussian variable with mean parent m and
// variance parent v.  Parent 0 must provide mean and variance, parent 1
// mean and exp; CheckParent() throws StructureException otherwise.
// The posterior parameters start at mean 0 and variance 1 and are then
// updated once.
RectifiedGaussian::RectifiedGaussian(Net *net, Label label, Node *m, Node *v) :
  Variable(net, label, m, v), BiParNode(m, v)
{
  cost = 0;
  CheckParent(0, REAL_MV);
  // The variance parent is parent 1 (Cost() and MyUpdate() request its
  // exp from ParReal(1, ...)), so that is the one to validate here.
  CheckParent(1, REAL_ME);

  myval.mean = 0;
  myval.var = 1;

  UpdateExpectations();

  MyUpdate();
}

// Write the posterior parameters into *state as the two-element vector
// [myval.mean, myval.var].  Only t == 0 is accepted.
void RectifiedGaussian::GetState(DV *state, size_t t = 0)
{
  BBASSERT2(t == 0);

  DV &out = *state;
  out.resize(2);
  out[0] = myval.mean;
  out[1] = myval.var;
}


// Set the posterior parameters from the two-element vector
// [mean, variance] in *state.  Only t == 0 is accepted.  The
// expectations are recomputed, the cached cost is invalidated and the
// children are outdated.
void RectifiedGaussian::SetState(DV *state, size_t t = 0)
{
  BBASSERT2(t == 0 && state->size() == 2);

  const DV &in = *state;
  myval.mean = in[0];
  myval.var = in[1];

  UpdateExpectations();
  costuptodate = false;
  OutdateChild();
}

// Save the posterior parameters, the derived expectations and the cached
// cost, followed by the Variable state.
void RectifiedGaussian::Save(NetSaver *saver)
{
  saver->SetNamedDSSet("myval", myval);
  saver->SetNamedDSSet("expectations", expectations);
  saver->SetNamedDouble("cost", cost);
  Variable::Save(saver);
}

// Name of this node type, as used in messages.
string RectifiedGaussian::GetType() const 
{ 
  return "RectifiedGaussian"; 
}

// Return the cost contribution of this variable, recomputing it when
// the cached value is out of date.  A variable without children
// contributes nothing.  The cost is the sum of a prior part (C_p), which
// depends on the parents and on the expectations of this variable, and a
// posterior part (C_q), which depends on myval and the expectations.
double RectifiedGaussian::Cost()
{
  if (children.empty()) {
    return 0;
  }

  if (!costuptodate) {
    DSSet mpar, vpar;
    ParReal(0, mpar, DFlags(true, true, false));
    ParReal(1, vpar, DFlags(true, false, true));

    /* C_p */
    cost = 0.5 * (vpar.ex * (Sqr(expectations.mean - mpar.mean)
			     + expectations.var + mpar.var)
		  - vpar.mean + log(PI/2)); // + log(2) if m != const0
#if RECTIFIED_BETTER_APPROX
    /* Extra Erfc term, compiled in only on request */
    cost += log(Erfc(-1/sqrt(2.0) * mpar.mean * 
		     exp(vpar.mean / 2 + vpar.var / 8)));
#endif
    /* C_q */
    cost += -1/(2*myval.var) * (expectations.var 
				+ Sqr(expectations.mean - myval.mean))
      + 0.5 * log(2/(PI*myval.var)) 
      - log(Erfc(-myval.mean / sqrt(2*myval.var)));

    costuptodate = true;
  }

  return cost;
}

// Copy the posterior parameters myval.mean and myval.var into val
// (as opposed to the expectations that GetReal() hands out).
// Always returns true.
bool RectifiedGaussian::GetMyval(DSSet &val)
{
  val.mean = myval.mean;
  val.var = myval.var;
  return true;
}

// Provide the mean and/or variance stored in expectations.  The
// exp-expectation is not available: a request that includes ex returns
// false (any requested mean and variance are still filled in).
bool RectifiedGaussian::GetReal(DSSet &val, DFlags req)
{
  if (req.mean)
    val.mean = expectations.mean;
  if (req.var)
    val.var = expectations.var;

  /* Clients hoping for ex are out of luck. */
  return !req.ex;
}

// Add to val the gradient of this variable's cost with respect to parent
// ptr.  An unclamped variable without children contributes nothing.
// ptr must be one of the two parents; anything else trips BBASSERT2.
void RectifiedGaussian::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  const int ide = ParIdentity(ptr);

  if (ide == 1) { // Variance parent
    DSSet p0;
    ParReal(0, p0, DFlags(true, true));
    val.mean -= 0.5;
    val.ex += 0.5 * (Sqr(expectations.mean - p0.mean) 
                     + expectations.var + p0.var);
  } else if (ide == 0) { // Mean parent
    DSSet p0, p1;
    ParReal(0, p0, DFlags(true));
    ParReal(1, p1, DFlags(false, false, true));
    // The cost contains 0.5 * p1.ex * Sqr(expectations.mean - p0.mean),
    // whose derivative with respect to p0.mean is
    // p1.ex * (p0.mean - expectations.mean).
    val.mean += p1.ex * (p0.mean - expectations.mean);
    val.var += 0.5 * p1.ex;
#if RECTIFIED_BETTER_APPROX
    val.mean += sqrt(2 / PI)
      * exp(-0.5 * exp(p1.mean + p1.var / 4) * Sqr(p0.mean) 
            + p1.mean / 2 + p1.var / 8)
      / Erfc(-1 / sqrt(2.0) * exp(p1.mean / 2 + p1.var / 8) * p0.mean);
#endif
  } else { 
    BBASSERT2(false);
  }
}

// A scalar variable has nothing to update partially: the indices are
// ignored and a full update is done.
void RectifiedGaussian::MyPartialUpdate(IntV *indices)
{
  MyUpdate();
}

// Set the posterior parameters from the children's gradient and the
// prior given by the parents, then refresh the expectations and
// invalidate the cached cost.
void RectifiedGaussian::MyUpdate()
{
  /* From the gradient we can deduce the likelihood and since the
     likelihood is known to be Gaussian and the prior is Rectified
     Gaussian the Rectified Gaussian posterior approximation matches
     exactly to the correct posterior and the parameters can be set
     according to the correct posterior. */

  DSSet grad, mpar, vpar;

  ChildGradReal(grad);
  ParReal(0, mpar, DFlags(true));
  ParReal(1, vpar, DFlags(false, false, true));

  /* Inverse variance of the likelihood implied by the gradient.  The
     likelihood mean is expectations.mean - grad.mean / lik_invvar, which
     is folded into the expression for myval.mean below. */
  const double lik_invvar = 2*grad.var;

  myval.var = 1 / (lik_invvar + vpar.ex);
  myval.mean = myval.var * (lik_invvar * expectations.mean - grad.mean
                            + vpar.ex * mpar.mean);

  /* myvals have changed, expectations needs to be updated */
  UpdateExpectations();

  costuptodate = false;
}

// Recompute expectations.mean and expectations.var from the posterior
// parameters in myval.  While mean/sqrt(var) is above RECTLIMIT the
// formulas based on Erfcx are used; otherwise an exponential
// approximation takes over.
void RectifiedGaussian::UpdateExpectations()
{
  const double m = myval.mean;
  const double v = myval.var;

  if ((m / sqrt(v)) > RECTLIMIT) {
    const double scale = sqrt(2*v/PI) / Erfcx(-m / sqrt(2*v));
    expectations.mean = m + scale;
    expectations.var = Sqr(m) + v + scale * m
      - Sqr(expectations.mean);
  } else { /* use exponential approximation */
    expectations.mean = -v / m;
    expectations.var = Sqr(expectations.mean);
  }
}

/* class RectifiedGaussianV */

// Construct a vector-valued rectified Gaussian variable with mean parent
// m and variance parent v.  All vectors get net->Time() elements.
// Parent 0 must provide vector mean and variance, parent 1 vector mean
// and exp; CheckParent() throws StructureException otherwise.  Every
// element of the posterior starts at mean 0 and variance 1.
RectifiedGaussianV::RectifiedGaussianV(Net *net, Label label, 
				       Node *m, Node *v) :
  Variable(net, label, m, v), BiParNode(m, v)
{
  cost = 0;

  myval.mean.resize(net->Time()); 
  myval.var.resize(net->Time());

  expectations.mean.resize(net->Time());
  expectations.var.resize(net->Time());

  CheckParent(0, REALV_MV);
  CheckParent(1, REALV_ME);

  for (size_t i = 0; i < net->Time(); i++) {
    myval.mean[i] = 0.0;
    myval.var[i] = 1.0;
  }

  UpdateExpectations();

  // MyUpdate() returns immediately while the node has no children
  MyUpdate();
}

// Save the posterior parameters, the derived expectations and the cached
// cost, followed by the Variable state.
void RectifiedGaussianV::Save(NetSaver *saver)
{
  saver->SetNamedDVSet("myval", myval);
  saver->SetNamedDVSet("expectations", expectations);
  saver->SetNamedDouble("cost", cost);
  Variable::Save(saver);
}

// Name of this node type, as used in messages.
string RectifiedGaussianV::GetType() const 
{ 
  return "RectifiedGaussianV"; 
}

// Return the cost contribution of this variable, recomputing it when
// the cached value is out of date.  A variable without children
// contributes nothing.  The cost is the sum over all net->Time()
// elements of a prior part (C_p) and a posterior part (C_q).
double RectifiedGaussianV::Cost()
{
  if (children.empty()) {
    return 0;
  }

  if (!costuptodate) {
    DVH mpar, vpar;
    ParRealV(0, mpar, DFlags(true, true));
    ParRealV(1, vpar, DFlags(true, false, true));

    double c = 0;

    for (size_t i = 0; i < net->Time(); i++) {
      /* C_p */
      c += 0.5 * (vpar.Exp(i) * (Sqr(expectations.mean[i] - mpar.Mean(i))
				 + expectations.var[i] + mpar.Var(i))
		  - vpar.Mean(i) + log(PI/2)); // + log(2)
#if RECTIFIED_BETTER_APPROX
      /* Extra Erfc term, compiled in only on request */
      c += log(Erfc(-1/sqrt(2.0) * mpar.Mean(i) * 
		    exp(vpar.Mean(i) / 2 + vpar.Var(i) / 8)));
#endif
      /* C_q */
      c += -1/(2*myval.var[i]) *
	(expectations.var[i] + Sqr(expectations.mean[i] - myval.mean[i])) +
	0.5 * log(2/(PI*myval.var[i])) - 
	log(Erfc(-myval.mean[i] / sqrt(2*myval.var[i])));
    }

    cost = c;
    costuptodate = true;
  }

  return cost;
}

// Point val at the expectations of this variable.  The exp-expectation
// is not available: a request that includes ex returns false (val.vec is
// set regardless).
bool RectifiedGaussianV::GetRealV(DVH &val, DFlags req)
{
  val.vec = &expectations;
  return !req.ex;
}

// Point val at the posterior parameters myval (as opposed to the
// expectations that GetRealV() hands out).  Always returns true.
bool RectifiedGaussianV::GetMyvalV(DVH &val)
{
  val.vec = &myval;
  return true;
}

// Add to val the element-wise gradient of this variable's cost with
// respect to the vector parent ptr.  An unclamped variable without
// children contributes nothing.  The gradient vectors that are written
// are first resized to net->Time() elements.  ptr must be one of the two
// parents; anything else trips BBASSERT2.
void RectifiedGaussianV::GradRealV(DVSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  if (ParIdentity(ptr) == 1) { // Variance parent
    DVH p0;
    ParRealV(0, p0, DFlags(true, true));
    // Only mean and ex are touched for this parent; val.var is not resized
    val.mean.resize(net->Time());
    val.ex.resize(net->Time());

    for (size_t i = 0; i < net->Time(); i++) {
      val.mean[i] -= 0.5;
      val.ex[i] += (Sqr(expectations.mean[i] - p0.Mean(i)) + p0.Var(i) +
		    expectations.var[i]) / 2;
    }
  } else if (ParIdentity(ptr) == 0) { // Mean parent
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    val.mean.resize(net->Time());
    val.var.resize(net->Time());

    for (size_t i = 0; i < net->Time(); i++) {
      val.mean[i] += (p0.Mean(i) - expectations.mean[i]) * p1.Exp(i);
      val.var[i] += p1.Exp(i) / 2;
#if RECTIFIED_BETTER_APPROX
      val.mean[i] += sqrt(2 / PI)
	* exp(-0.5 * exp(p1.Mean(i) + p1.Var(i) / 4) * Sqr(p0.Mean(i)) 
	      + p1.Mean(i) / 2 + p1.Var(i) / 8)
	/ Erfc(-1/sqrt(2.0) * exp(p1.Mean(i)/2 + p1.Var(i)/8) * p0.Mean(i));
#endif
    }
  } else {
    BBASSERT2(0);
  }
}

// Scalar counterpart of GradRealV: the per-time contributions are summed
// into the single values of val.  Nothing is added when the node is
// unclamped and childless; a ptr that is neither parent trips BBASSERT2.
void RectifiedGaussianV::GradReal(DSSet &val, const Node *ptr)
{
  if (children.empty() && !clamped)
    return;

  if (ParIdentity(ptr) == 1) {
    // Variance parent
    DVH meanpar;
    ParRealV(0, meanpar, DFlags(true, true));

    for (size_t t = 0; t < net->Time(); ++t) {
      val.ex += (Sqr(expectations.mean[t] - meanpar.Mean(t))
                 + meanpar.Var(t) + expectations.var[t]) / 2;
    }

    // One -0.5 per time index, applied in a single step.
    val.mean -= 0.5 * net->Time();
    return;
  }

  if (ParIdentity(ptr) == 0) {
    // Mean parent
    DVH meanpar, varpar;
    ParRealV(0, meanpar, DFlags(true));
    ParRealV(1, varpar, DFlags(false, false, true));

    for (size_t t = 0; t < net->Time(); ++t) {
      val.mean += (meanpar.Mean(t) - expectations.mean[t]) * varpar.Exp(t);
      val.var += varpar.Exp(t) / 2;
#if RECTIFIED_BETTER_APPROX
      val.mean += sqrt(2 / PI)
        * exp(-0.5 * exp(varpar.Mean(t) + varpar.Var(t) / 4) * Sqr(meanpar.Mean(t))
              + varpar.Mean(t) / 2 + varpar.Var(t) / 8)
        / Erfc(-1/sqrt(2.0) * exp(varpar.Mean(t)/2 + varpar.Var(t)/8) * meanpar.Mean(t));
#endif
    }
    return;
  }

  // ptr is neither of the two parents.
  BBASSERT2(0);
}

void RectifiedGaussianV::MyUpdate()
{
  if (NumChildren() < 1) {
    return;
  }

  DVSet grad;
  ChildGradRealV(grad);

  DVH mpar, vpar;
  ParRealV(0, mpar, DFlags(true));
  ParRealV(1, vpar, DFlags(false, false, true));

  double gm = 0;
  double gv = 0;
  bool hasgradient = grad.mean.size() > 0;

  for (size_t i = 0; i < net->Time(); i++) {
    if (hasgradient) {
      gm = grad.mean[i];
      gv = grad.var[i];
    }

    myval.var[i] = 1 / (2*gv + vpar.Exp(i));
    myval.mean[i] = myval.var[i] * (2*gv * expectations.mean[i] - gm
				    + vpar.Exp(i) * mpar.Mean(i));
  }

  UpdateExpectations();

  costuptodate = false;
}

// Recomputes expectations.mean / expectations.var from myval for every time
// index.  While mean / sqrt(var) is above RECTLIMIT the Erfcx-based formula
// is used; otherwise (including when the comparison fails) a simpler
// expression that needs no Erfcx evaluation is used.
void RectifiedGaussianV::UpdateExpectations()
{
  for (size_t t = 0; t < net->Time(); ++t) {
    const bool regular = (myval.mean[t] / sqrt(myval.var[t])) > RECTLIMIT;

    if (!regular) {
      expectations.mean[t] = - myval.var[t] / myval.mean[t];
      expectations.var[t] = Sqr(expectations.mean[t]);
      continue;
    }

    const double shift = sqrt(2 * myval.var[t] / PI)
      / Erfcx(-myval.mean[t] / sqrt(2 * myval.var[t]));
    expectations.mean[t] = myval.mean[t] + shift;
    expectations.var[t] = Sqr(myval.mean[t]) + myval.var[t]
      + shift * myval.mean[t] - Sqr(expectations.mean[t]);
  }
}


/* Class GaussRectV */

// Builds a GaussRectV node with mean parent m and variance parent v.
// All per-time containers are sized to net->Time(); both components start
// at mean 0.1, variance 1.0 and weight 1.0.  After the parent type checks
// the moments and expectations are computed from these initial values.
GaussRectV::GaussRectV(Net *net, Label label, Node *m, Node *v) :
  Variable(net, label, m, v), BiParNode(m, v)
{
  const size_t T = net->Time();

  cost = 0.0;

  posval.mean.resize(T);
  posval.var.resize(T);

  negval.mean.resize(T);
  negval.var.resize(T);

  posweights.resize(T);
  negweights.resize(T);

  for (size_t t = 0; t < T; ++t) {
    posval.mean[t] = 0.1;
    posval.var[t] = 1.0;
    negval.mean[t] = 0.1;
    negval.var[t] = 1.0;
    posweights[t] = 1.0;
    negweights[t] = 1.0;
  }

  // Three moment vectors (orders 0, 1, 2) per component.
  posmoments.resize(3);
  negmoments.resize(3);

  for (int order = 0; order < 3; ++order) {
    posmoments[order].resize(T);
    negmoments[order].resize(T);
  }

  expts.mean.resize(T);
  expts.var.resize(T);

  rectexpts.mean.resize(T);
  rectexpts.var.resize(T);

  CheckParent(0, REALV_MV);
  CheckParent(1, REALV_ME);

  // A call to MyUpdate() was left commented out at this point.
  UpdateMoments();
  UpdateExpectations();
}

// Copies the six posterior parameters of time index t into *state, in the
// order: posval mean, posval var, negval mean, negval var, posweight,
// negweight.  *state is resized to six entries.
void GaussRectV::GetState(DV *state, size_t t)
{
  BBASSERT2(t < net->Time());

  state->resize(6);
  DV &out = *state;

  out[0] = posval.mean[t];
  out[1] = posval.var[t];
  out[2] = negval.mean[t];
  out[3] = negval.var[t];
  out[4] = posweights[t];
  out[5] = negweights[t];
}

// Overwrites the six posterior parameters of time index t from *state
// (same layout as GetState), then recomputes the moments and expectations
// of all time indices, invalidates the cached cost and calls OutdateChild().
void GaussRectV::SetState(DV *state, size_t t)
{
  BBASSERT2(t < net->Time());
  BBASSERT2(state->size() == 6);

  const DV &in = *state;

  posval.mean[t] = in[0];
  posval.var[t] = in[1];
  negval.mean[t] = in[2];
  negval.var[t] = in[3];
  posweights[t] = in[4];
  negweights[t] = in[5];

  // Only one index changed, but the refresh covers every index.
  UpdateMoments();
  UpdateExpectations();
  costuptodate = false;

  OutdateChild();
}

// Writes this node's own state to the saver -- both component value sets,
// both weight vectors and the cached cost -- and then lets Variable::Save
// store the base-class part.
void GaussRectV::Save(NetSaver *saver)
{
  // Component parameters.
  saver->SetNamedDVSet("posval", posval);
  saver->SetNamedDVSet("negval", negval);

  // Component weights.
  saver->SetNamedDV("posweights", posweights);
  saver->SetNamedDV("negweights", negweights);

  // Cached cost value.
  saver->SetNamedDouble("cost", cost);

  this->Variable::Save(saver);
}

// Reports the class name of this node type as a string.
string GaussRectV::GetType() const
{
  return string("GaussRectV");
}

// Recomputes posmoments[0..2] and negmoments[0..2] for every time index
// from posval/posweights and negval/negweights.
//
// Assuming Erfc is the complementary error function, moment k of the
// positive part is posweights[t] times the integral of x^k against the
// Gaussian N(posval.mean, posval.var) over x > 0, and moment k of the
// negative part is the analogous integral of N(negval.mean, negval.var)
// over x < 0, weighted by negweights[t].
// Every intermediate and result is checked with BBASSERT2(finite(...)),
// except sn and the reused a/b of the negative part.
void GaussRectV::UpdateMoments()
{
  for (size_t t = 0; t < net->Time(); t++) {
    // sp / 2 is the Gaussian mass above zero (given Erfc = erfc);
    // b is the boundary term sqrt(2 var / pi) * exp(-mean^2 / (2 var)).
    double sp = Erfc(-posval.mean[t] / sqrt(2*posval.var[t]));
    double a = 0.5 * posweights[t];
    double b = sqrt(2*posval.var[t] / PI) 
      / exp(Sqr(posval.mean[t]) / (2 * posval.var[t]));

    BBASSERT2(finite(sp));
    BBASSERT2(finite(a));
    BBASSERT2(finite(b));

    posmoments[0][t] = a * sp;
    posmoments[1][t] = a * (sp * posval.mean[t] + b);
    posmoments[2][t] = a * (sp * (Sqr(posval.mean[t]) + posval.var[t])
			    + b * posval.mean[t]);


    BBASSERT2(finite(posmoments[0][t]));
    BBASSERT2(finite(posmoments[1][t]));
    BBASSERT2(finite(posmoments[2][t]));


    // Same quantities for the part below zero; a and b are reused and the
    // boundary term enters with the opposite sign.
    double sn = Erfc(negval.mean[t] / sqrt(2 * negval.var[t]));
    a = 0.5 * negweights[t];
    b = sqrt(2*negval.var[t] / PI) 
      / exp(Sqr(negval.mean[t]) / (2 * negval.var[t]));

    negmoments[0][t] = a * sn;
    negmoments[1][t] = a * (sn * negval.mean[t] - b);
    negmoments[2][t] = a * (sn * (Sqr(negval.mean[t]) + negval.var[t])
			    - b * negval.mean[t]);

    BBASSERT2(finite(negmoments[0][t]));
    BBASSERT2(finite(negmoments[1][t]));
    BBASSERT2(finite(negmoments[2][t]));


  }
}

void GaussRectV::UpdateExpectations()
{
  for (size_t t = 0; t < net->Time(); t++) {
    expts.mean[t] = posmoments[1][t] + negmoments[1][t];
    expts.var[t] = posmoments[2][t] + negmoments[2][t] - Sqr(expts.mean[t]);

    rectexpts.mean[t] = posmoments[1][t];
    rectexpts.var[t] = posmoments[2][t] - Sqr(rectexpts.mean[t]);
  }
}

// Returns this node's cost term, recomputing it only when the cached value
// is stale (costuptodate == false).  A node without children costs nothing.
//
// For every time index three parts are accumulated:
//   C_p   - term from the mean parent (parent 0), the variance parent
//           (parent 1) and expts;
//   C_q^+ - term from posval / posweights / posmoments;
//   C_q^- - term from negval / negweights / negmoments.
// The running sum is asserted to be finite after each accumulation, and the
// cached value once more before returning.
double GaussRectV::Cost()
{
  if (children.empty()) {
    return 0;
  }

  if (!costuptodate) {
    DVH mpar, vpar;
    ParRealV(0, mpar, DFlags(true, true));
    ParRealV(1, vpar, DFlags(true, false, true));

    double c = 0;
    double S;

    for (size_t i = 0; i < net->Time(); i++) {
      /* C_p */
      c += 0.5 * (vpar.Exp(i) * (Sqr(expts.mean[i] - mpar.Mean(i)) 
				 + expts.var[i] + mpar.Var(i))
		  - vpar.Mean(i) + log(2*PI));
      // Disabled debug output, kept for diagnosing a non-finite C_p:
//       if (!finite(c)) {
// 	cout << "expts: mean = " << expts.mean[i] << ", "
// 	     << "var = " << expts.var[i] << endl;
// 	cout << "mpar: mean = " << mpar.Mean(i) << ", "
// 	     << "var = " << mpar.Var(i) << endl;
// 	cout << "vpar: mean = " << vpar.Mean(i) << ", "
// 	     << "exp = " << vpar.Exp(i) << endl;
//       }
      BBASSERT2(finite(c));

      /* C_q^+ */
      S = Erfc(-posval.mean[i] / sqrt(2 * posval.var[i]));
      // For weights at or below EPSILON the weight * log(weight) term is
      // left out, so log() is never evaluated on a (near-)zero weight.
      if (posweights[i] > EPSILON) {
	c += S/2 * (posweights[i] * log(posweights[i])
		    - posweights[i]/2 * log(2*PI*posval.var[i]));
	BBASSERT2(finite(c));
      } else {
	c -= S/4 * posweights[i] * log(2*PI*posval.var[i]);
	BBASSERT2(finite(c));
      }
      c += - Sqr(posval.mean[i]) / (2 * posval.var[i]) * posmoments[0][i]
	+ posval.mean[i] / posval.var[i] * posmoments[1][i]
	- posmoments[2][i] / (2*posval.var[i]);
      BBASSERT2(finite(c));

      /* C_q^- */
      S = Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));
      // Same EPSILON guard as for the positive part.
      if (negweights[i] > EPSILON) {
	c += S/2 * (negweights[i] * log(negweights[i])
		    - negweights[i]/2 * log(2*PI*negval.var[i]));
	BBASSERT2(finite(c));
      } else {
	c -= S/4 * negweights[i] * log(2*PI*negval.var[i]);
	BBASSERT2(finite(c));
      }
      c += - Sqr(negval.mean[i]) / (2 * negval.var[i]) * negmoments[0][i]
	+ negval.mean[i] / negval.var[i] * negmoments[1][i]
	- negmoments[2][i] / (2*negval.var[i]);
      BBASSERT2(finite(c));
    }

    cost = c;
    costuptodate = true;
  }

  BBASSERT2(finite(cost));

  return cost;
}

// Points the handle at expts.  Fails only when the caller requested the
// ex quantity (req.ex).
bool GaussRectV::GetRealV(DVH &val, DFlags req)
{
  val.vec = &expts;

  if (req.ex) {
    return false;
  }
  return true;
}

// Points the handle at rectexpts.  Fails only when the caller requested
// the ex quantity (req.ex).
bool GaussRectV::GetRectRealV(DVH &val, DFlags req)
{
  val.vec = &rectexpts;

  if (req.ex) {
    return false;
  }
  return true;
}

// Accumulates into val this node's per-time gradient contribution for the
// parent ptr, based on expts.  Nothing is added when the node is unclamped
// and childless.
//   - variance parent (identity 1): val.mean and val.ex are resized to the
//     time length and updated;
//   - mean parent (identity 0): val.mean and val.var are resized and updated.
// Any other ptr trips BBASSERT2.
void GaussRectV::GradRealV(DVSet &val, const Node *ptr)
{
  if (children.empty() && !clamped) {
    return;
  }

  if (ParIdentity(ptr) == 1) {
    // Variance parent
    DVH meanpar;
    ParRealV(0, meanpar, DFlags(true, true));
    val.mean.resize(net->Time());
    val.ex.resize(net->Time());

    for (size_t t = 0; t < net->Time(); ++t) {
      val.mean[t] -= 0.5;
      val.ex[t] += (Sqr(expts.mean[t] - meanpar.Mean(t)) + meanpar.Var(t)
                    + expts.var[t]) / 2;
    }
    return;
  }

  if (ParIdentity(ptr) == 0) {
    // Mean parent
    DVH meanpar, varpar;
    ParRealV(0, meanpar, DFlags(true));
    ParRealV(1, varpar, DFlags(false, false, true));
    val.mean.resize(net->Time());
    val.var.resize(net->Time());

    for (size_t t = 0; t < net->Time(); ++t) {
      val.mean[t] += (meanpar.Mean(t) - expts.mean[t]) * varpar.Exp(t);
      val.var[t] += varpar.Exp(t) / 2;
    }
    return;
  }

  // ptr is neither of the two parents.
  BBASSERT2(false);
}

// Scalar counterpart of GradRealV: the per-time contributions are summed
// into the single values of val.  Nothing is added when the node is
// unclamped and childless; a ptr that is neither parent trips BBASSERT2.
void GaussRectV::GradReal(DSSet &val, const Node *ptr)
{
  if (children.empty() && !clamped)
    return;

  if (ParIdentity(ptr) == 1) {
    // Variance parent
    DVH meanpar;
    ParRealV(0, meanpar, DFlags(true, true));

    for (size_t t = 0; t < net->Time(); ++t) {
      val.ex += (Sqr(expts.mean[t] - meanpar.Mean(t)) + meanpar.Var(t)
                 + expts.var[t]) / 2;
    }

    // One -0.5 per time index, applied in a single step.
    val.mean -= 0.5 * net->Time();
    return;
  }

  if (ParIdentity(ptr) == 0) {
    // Mean parent
    DVH meanpar, varpar;
    ParRealV(0, meanpar, DFlags(true));
    ParRealV(1, varpar, DFlags(false, false, true));

    for (size_t t = 0; t < net->Time(); ++t) {
      val.mean += (meanpar.Mean(t) - expts.mean[t]) * varpar.Exp(t);
      val.var += varpar.Exp(t) / 2;
    }
    return;
  }

  // ptr is neither of the two parents.
  BBASSERT2(false);
}

// Performs the same posterior update as MyUpdate(), but only for the time
// indices listed in *indices.  Afterwards the moments and expectations of
// all time indices are recomputed and the cached cost is invalidated.
//
// Returns without changes when the node has no children or when the
// rectification child delivered no gradient (rg.mean is empty).
// NOTE(review): entries of *indices are used as array indices without a
// range check -- confirm callers only pass values below net->Time().
void GaussRectV::MyPartialUpdate(IntV *indices)
{
  if (NumChildren() < 1) {
    return;
  }

  DVSet ng; // gradient from direct children
  DVSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);

  if (rg.mean.size() == 0) {
    return;
  }

  DVH mpar, vpar;
  ParRealV(0, mpar, DFlags(true));
  ParRealV(1, vpar, DFlags(false, false, true));

  double x, vx, ivx, z;
  bool limitexceded = false;

  for (size_t j = 0; j < indices->size(); j++) {
    int i = (*indices)[j];

    // negval: parents only when there is no direct-child gradient,
    // otherwise parents combined with that gradient.
    if (ng.mean.size() == 0) {
      negval.var[i] = 1 / vpar.Exp(i);
      negval.mean[i] = mpar.Mean(i);
    } else {
      negval.var[i] = 1 / (vpar.Exp(i) + 2 * ng.var[i]);
      negval.mean[i] = negval.var[i] * 
	(vpar.Exp(i) * mpar.Mean(i)
	 + 2 * ng.var[i] * expts.mean[i] - ng.mean[i]);
    }

    // Rectification gradient expressed as inverse variance ivx,
    // variance vx and location x.
    ivx = 2 * rg.var[i];
    vx = 1 / ivx;
    x = rectexpts.mean[i] - vx * rg.mean[i];

    // posval: negval combined with the rectification term.
    posval.var[i] = 1 / (ivx + 1 / negval.var[i]);
    posval.mean[i] = posval.var[i] * 
      (ivx * x + negval.mean[i] / negval.var[i]);

    // Unnormalised component weights.
    posweights[i] = NormPdf(x, negval.mean[i], vx + negval.var[i]);
    negweights[i] = NormPdf(x, 0, vx);

    // z is the normaliser of the two weighted parts.
    z = 0.5 * posweights[i] * Erfc(-posval.mean[i] / sqrt(2 * posval.var[i]))
      + 0.5 * negweights[i] * Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));

    // Invert z, capping the scale at 1 / GAUSSRECTLIMIT when z is too small.
    if (z > GAUSSRECTLIMIT) {
      z = 1/z;
    } else {
      limitexceded = true;
      z = 1 / GAUSSRECTLIMIT;
    }
    
    // Disabled debug output for a non-finite z:
//     if (!finite(z)) {
//       cout << "posval: mean = " << posval.mean[i] << ", "
// 	   << "var = " << posval.var[i] << endl;
//       cout << "negval: mean = " << negval.mean[i] << ", "
// 	   << "var = " << negval.var[i] << endl;
//     }


    BBASSERT2(finite(z));

    posweights[i] *= z;
    negweights[i] *= z;

    BBASSERT2(finite(posweights[i]));
    BBASSERT2(finite(negweights[i]));
  }

  // Reported once per call, not once per index.
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }

  UpdateMoments();
  UpdateExpectations();

  costuptodate = false;
}

// Recomputes posval, negval, posweights and negweights for every time index
// from the parents, the gradient of the direct children (ng) and the
// gradient arriving through the rectification child (rg).  Afterwards the
// moments and expectations are recomputed and the cached cost is
// invalidated.
//
// Returns without changes when the node has no children or when the
// rectification child delivered no gradient (rg.mean is empty).
void GaussRectV::MyUpdate()
{
  if (NumChildren() < 1) {
    return;
  }

  DVSet ng; // gradient from direct children
  DVSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);

  if (rg.mean.size() == 0) {
    return;
  }

  DVH mpar, vpar;
  ParRealV(0, mpar, DFlags(true));
  ParRealV(1, vpar, DFlags(false, false, true));

  double x, vx, ivx, z;
  bool limitexceded = false;

  for (size_t i = 0; i < net->Time(); i++) {
    // negval: parents only when there is no direct-child gradient,
    // otherwise parents combined with that gradient.
    if (ng.mean.size() == 0) {
      negval.var[i] = 1 / vpar.Exp(i);
      negval.mean[i] = mpar.Mean(i);
    } else {
      negval.var[i] = 1 / (vpar.Exp(i) + 2 * ng.var[i]);
      negval.mean[i] = negval.var[i] * 
	(vpar.Exp(i) * mpar.Mean(i)
	 + 2 * ng.var[i] * expts.mean[i] - ng.mean[i]);
    }

    // Rectification gradient expressed as inverse variance ivx,
    // variance vx and location x.
    ivx = 2 * rg.var[i];
    vx = 1 / ivx;
    x = rectexpts.mean[i] - vx * rg.mean[i];

    // posval: negval combined with the rectification term.
    posval.var[i] = 1 / (ivx + 1 / negval.var[i]);
    posval.mean[i] = posval.var[i] * 
      (ivx * x + negval.mean[i] / negval.var[i]);

    // Unnormalised component weights.
    posweights[i] = NormPdf(x, negval.mean[i], vx + negval.var[i]);
    negweights[i] = NormPdf(x, 0, vx);

    // z is the normaliser of the two weighted parts.
    z = 0.5 * posweights[i] * Erfc(-posval.mean[i] / sqrt(2 * posval.var[i]))
      + 0.5 * negweights[i] * Erfc(negval.mean[i] / sqrt(2 * negval.var[i]));

    // Invert z, capping the scale at 1 / GAUSSRECTLIMIT when z is too small.
    if (z > GAUSSRECTLIMIT) {
      z = 1/z;
    } else {
      limitexceded = true;
      z = 1 / GAUSSRECTLIMIT;
    }
    
    // Disabled debug output for a non-finite z:
//     if (!finite(z)) {
//       cout << "posval: mean = " << posval.mean[i] << ", "
// 	   << "var = " << posval.var[i] << endl;
//       cout << "negval: mean = " << negval.mean[i] << ", "
// 	   << "var = " << negval.var[i] << endl;
//     }


    BBASSERT2(finite(z));

    posweights[i] *= z;
    negweights[i] *= z;

    BBASSERT2(finite(posweights[i]));
    BBASSERT2(finite(negweights[i]));
  }

  // Reported once per call, not once per index.
  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }

  UpdateMoments();
  UpdateExpectations();

  costuptodate = false;
}

// Collects the children's gradients in two groups: children that are
// RectificationV nodes accumulate into rect, all others into norm.
void GaussRectV::ChildGradients(DVSet &norm, DVSet &rect)
{
  for (size_t c = 0; c < children.size(); ++c) {
    RectificationV *rectifier = dynamic_cast<RectificationV *>(children[c]);

    if (rectifier != 0) {
      // The rectification node.
      rectifier->GradRealV(rect, this);
      continue;
    }

    // Any other child.
    children[c]->GradRealV(norm, this);
  }
}


/* Class GaussRect */

// Builds a scalar GaussRect node with mean parent m and variance parent v.
// Both components start at mean 0.1, variance 1.0 and weight 1.0; after
// the parent type checks the moments and expectations are computed from
// these initial values.
GaussRect::GaussRect(Net *net, Label label, Node *m, Node *v) :
  Variable(net, label, m, v), BiParNode(m, v)
{
  cost = 0.0;

  // Positive component.
  posval.mean = 0.1;
  posval.var = 1.0;
  posweight = 1.0;

  // Negative component.
  negval.mean = 0.1;
  negval.var = 1.0;
  negweight = 1.0;

  // Moments of order 0, 1 and 2 for each component.
  posmoments.resize(3);
  negmoments.resize(3);

  CheckParent(0, REAL_MV);
  CheckParent(1, REAL_ME);

  UpdateMoments();
  UpdateExpectations();
}

// Writes this node's own state to the saver -- both component value sets,
// both weights and the cached cost -- and then lets Variable::Save store
// the base-class part.
void GaussRect::Save(NetSaver *saver)
{
  // Component parameters.
  saver->SetNamedDSSet("posval", posval);
  saver->SetNamedDSSet("negval", negval);

  // Component weights.
  saver->SetNamedDouble("posweight", posweight);
  saver->SetNamedDouble("negweight", negweight);

  // Cached cost value.
  saver->SetNamedDouble("cost", cost);

  this->Variable::Save(saver);
}

// Reports the class name of this node type as a string.
string GaussRect::GetType() const
{
  return string("GaussRect");
}

// Recomputes posmoments[0..2] and negmoments[0..2] from posval/posweight
// and negval/negweight.  Scalar counterpart of GaussRectV::UpdateMoments
// without the finiteness assertions.
void GaussRect::UpdateMoments()
{
  // Positive component.
  const double tailpos = Erfc(-posval.mean / sqrt(2*posval.var));
  const double halfwpos = 0.5 * posweight;
  const double edgepos = sqrt(2*posval.var / PI)
    / exp(Sqr(posval.mean) / (2 * posval.var));

  posmoments[0] = halfwpos * tailpos;
  posmoments[1] = halfwpos * (tailpos * posval.mean + edgepos);
  posmoments[2] = halfwpos * (tailpos * (Sqr(posval.mean) + posval.var)
                              + edgepos * posval.mean);

  // Negative component; the boundary term enters with the opposite sign.
  const double tailneg = Erfc(negval.mean / sqrt(2 * negval.var));
  const double halfwneg = 0.5 * negweight;
  const double edgeneg = sqrt(2*negval.var / PI)
    / exp(Sqr(negval.mean) / (2 * negval.var));

  negmoments[0] = halfwneg * tailneg;
  negmoments[1] = halfwneg * (tailneg * negval.mean - edgeneg);
  negmoments[2] = halfwneg * (tailneg * (Sqr(negval.mean) + negval.var)
                              - edgeneg * negval.mean);
}

// Derives mean and variance from the stored moments: expts combines the
// positive and negative parts, rectexpts uses the positive part only.
// Variance is second moment minus squared mean.
void GaussRect::UpdateExpectations()
{
  // Whole variable.
  expts.mean = posmoments[1] + negmoments[1];
  expts.var = posmoments[2] + negmoments[2] - Sqr(expts.mean);

  // Rectified variable.
  rectexpts.mean = posmoments[1];
  rectexpts.var = posmoments[2] - Sqr(rectexpts.mean);
}

// Returns this node's cost term, recomputing it only when the cached value
// is stale (costuptodate == false).  A node without children costs nothing.
//
// Three parts are accumulated:
//   C_p   - term from the mean parent (parent 0), the variance parent
//           (parent 1) and expts;
//   C_q^+ - term from posval / posweight / posmoments;
//   C_q^- - term from negval / negweight / negmoments.
// The freshly computed value is asserted to be finite.
double GaussRect::Cost()
{
  if (children.empty()) {
    return 0;
  }

  if (!costuptodate) {
    DSSet mpar, vpar;
    ParReal(0, mpar, DFlags(true, true));
    ParReal(1, vpar, DFlags(true, false, true));

    double c = 0;
    double S;

    /* C_p */
    c += 0.5 * (vpar.ex * (Sqr(expts.mean - mpar.mean) 
			   + expts.var + mpar.var)
		- vpar.mean + log(2*PI));

    /* C_q^+ */
    S = Erfc(-posval.mean / sqrt(2 * posval.var));
    // For a weight at or below EPSILON the weight * log(weight) term is
    // left out, so log() is never evaluated on a (near-)zero weight.
    if (posweight > EPSILON) {
      c += S/2 * (posweight * log(posweight)
		  - posweight/2 * log(2*PI*posval.var));
    } else {
      c -= S/4 * posweight * log(2*PI*posval.var);
    }
    c += - Sqr(posval.mean) / (2 * posval.var) * posmoments[0]
      + posval.mean / posval.var * posmoments[1]
      - posmoments[2] / (2*posval.var);
    
    /* C_q^- */
    S = Erfc(negval.mean / sqrt(2 * negval.var));
    // Same EPSILON guard as for the positive part.
    if (negweight > EPSILON) {
      c += S/2 * (negweight * log(negweight)
		  - negweight/2 * log(2*PI*negval.var));
    } else {
      c -= S/4 * negweight * log(2*PI*negval.var);
    }
    c += - Sqr(negval.mean) / (2 * negval.var) * negmoments[0]
      + negval.mean / negval.var * negmoments[1]
      - negmoments[2] / (2*negval.var);

    cost = c;
    costuptodate = true;

    BBASSERT2(finite(cost));
  }

  return cost;
}

// Copies expts into val.  Fails only when the caller requested the ex
// quantity (req.ex).
bool GaussRect::GetReal(DSSet &val, DFlags req)
{
  val = expts;

  if (req.ex) {
    return false;
  }
  return true;
}

// Copies rectexpts into val.  Fails only when the caller requested the ex
// quantity (req.ex).
bool GaussRect::GetRectReal(DSSet &val, DFlags req)
{
  val = rectexpts;

  if (req.ex) {
    return false;
  }
  return true;
}

// Accumulates into val this node's gradient contribution for the parent
// ptr, based on expts.  Nothing is added when the node is unclamped and
// childless.  The variance parent (identity 1) receives val.ex and
// val.mean, the mean parent (identity 0) val.mean and val.var; any other
// ptr trips BBASSERT2.
void GaussRect::GradReal(DSSet &val, const Node *ptr)
{
  if (children.empty() && !clamped)
    return;

  if (ParIdentity(ptr) == 1) {
    // Variance parent
    DSSet meanpar;
    ParReal(0, meanpar, DFlags(true, true));

    val.ex += (Sqr(expts.mean - meanpar.mean) + meanpar.var + expts.var) / 2;
    val.mean -= 0.5;
    return;
  }

  if (ParIdentity(ptr) == 0) {
    // Mean parent
    DSSet meanpar, varpar;
    ParReal(0, meanpar, DFlags(true));
    ParReal(1, varpar, DFlags(false, false, true));

    val.mean += (meanpar.mean - expts.mean) * varpar.ex;
    val.var += varpar.ex / 2;
    return;
  }

  // ptr is neither of the two parents.
  BBASSERT2(false);
}

// Recomputes posval, negval, posweight and negweight from the parents, the
// gradient of the direct children (ng) and the gradient arriving through
// the rectification child (rg).  Afterwards the moments and expectations
// are recomputed and the cached cost is invalidated.  Does nothing for a
// node without children.
//
// NOTE(review): unlike GaussRectV::MyUpdate there is no early return when
// the rectification child supplied no gradient.  If rg.var can remain zero
// here, ivx is 0 and vx = 1 / ivx is not finite -- confirm that a
// Rectification child is always present.
void GaussRect::MyUpdate()
{
  if (NumChildren() < 1) {
    return;
  }

  DSSet ng; // gradient from direct children
  DSSet rg; // gradient from children below the rectification
  ChildGradients(ng, rg);

  DSSet mpar, vpar;
  ParReal(0, mpar, DFlags(true));
  ParReal(1, vpar, DFlags(false, false, true));

  double x, vx, ivx, z;
  bool limitexceded = false;

  // negval: parents combined with the direct-child gradient.
  negval.var = 1 / (vpar.ex + 2 * ng.var);
  negval.mean = negval.var * 
    (vpar.ex * mpar.mean + 2 * ng.var * expts.mean - ng.mean);

  // Rectification gradient expressed as inverse variance ivx,
  // variance vx and location x.
  ivx = 2 * rg.var;
  vx = 1 / ivx;
  x = rectexpts.mean - vx * rg.mean;

  // posval: negval combined with the rectification term.
  posval.var = 1 / (ivx + 1 / negval.var);
  posval.mean = posval.var * (ivx * x + negval.mean / negval.var);

  // Unnormalised component weights.
  posweight = NormPdf(x, negval.mean, vx + negval.var);
  negweight = NormPdf(x, 0, vx);

  // z is the normaliser of the two weighted parts.
  z = 0.5 * posweight * Erfc(-posval.mean / sqrt(2 * posval.var))
    + 0.5 * negweight * Erfc(negval.mean / sqrt(2 * negval.var));

  // Invert z, capping the scale at 1 / GAUSSRECTLIMIT when z is too small.
  if (z > GAUSSRECTLIMIT) {
    z = 1/z;
  } else {
    limitexceded = true;
    z = 1 / GAUSSRECTLIMIT;
  }
  
  BBASSERT2(finite(z));

  posweight *= z;
  negweight *= z;

  BBASSERT2(finite(posweight));
  BBASSERT2(finite(negweight));

  if (limitexceded) {
    cout << "Warning: Limit exceded in " << GetLabel() << endl;
  }

  UpdateMoments();
  UpdateExpectations();

  costuptodate = false;
}

// Collects the children's gradients in two groups: children that are
// Rectification nodes accumulate into rect, all others into norm.
void GaussRect::ChildGradients(DSSet &norm, DSSet &rect)
{
  for (size_t c = 0; c < children.size(); ++c) {
    Rectification *rectifier = dynamic_cast<Rectification *>(children[c]);

    if (rectifier != 0) {
      // The rectification node.
      rectifier->GradReal(rect, this);
      continue;
    }

    // Any other child.
    children[c]->GradReal(norm, this);
  }
}


/* class GaussRectVState */

// Remembers the GaussRectV node whose internal state the accessors expose.
GaussRectVState::GaussRectVState(GaussRectV *n)
{
  this->node = n;
}

// Mutable access to the node's posval set.
DVSet &GaussRectVState::GetPosVal()
{
  DVSet &values = node->posval;
  return values;
}

// Mutable access to the node's negval set.
DVSet &GaussRectVState::GetNegVal()
{
  DVSet &values = node->negval;
  return values;
}

// Mutable access to the node's posweights vector.
DV &GaussRectVState::GetPosWeights()
{
  DV &weights = node->posweights;
  return weights;
}

// Mutable access to the node's negweights vector.
DV &GaussRectVState::GetNegWeights()
{
  DV &weights = node->negweights;
  return weights;
}

// Mutable access to entry i of the node's posmoments; i is not range-checked.
DV &GaussRectVState::GetPosMoment(int i)
{
  DV &moment = node->posmoments[i];
  return moment;
}

// Mutable access to entry i of the node's negmoments; i is not range-checked.
DV &GaussRectVState::GetNegMoment(int i)
{
  DV &moment = node->negmoments[i];
  return moment;
}


  
// Normalises *w in place so that its entries sum to one (before
// truncation): every entry is divided by the total, and quotients that are
// not above CATEGLIMIT are replaced by zero.  The total is asserted to be
// non-zero.
inline void ScaleWeights(DV *w)
{
  DV &weights = *w;

  double sum = 0.0;
  for (size_t k = 0; k < weights.size(); ++k) {
    sum += weights[k];
  }
  BBASSERT(sum != 0.0);

  for (size_t k = 0; k < weights.size(); ++k) {
    const double scaled = weights[k] / sum;
    if (scaled > CATEGLIMIT) {
      weights[k] = scaled;
    } else {
      weights[k] = 0.0;
    }
  }
}


/* class MoGV */

// Builds a MoGV node whose only initial parent is the discrete node d.
// The component count is read from d and expts is sized to the time
// length.  Mean and variance parents are attached later through
// AddComponent().
MoGV::MoGV(Net *net, Label label, Node *d)
  : Variable(net, label, d), NParNode(d)
{
  cost = 0;
  CheckParent(0, DISCRETEV);
  numComponents = NumComponents();

  const size_t T = net->Time();
  expts.mean.resize(T);
  expts.var.resize(T);
}

double MoGV::Cost()
{
  if (!clamped && children.empty())
    return 0;

  if (!costuptodate) {
    DVH mpar, vpar;
    DFlags mflags(true, true);
    DFlags vflags(true, false, true);
    VDDH dpar;
    ParDiscreteV(0, dpar);

    double c = 0;
    cost = 0;
    double log2pi = log(2*PI);
    
    for (size_t k = 0; k < numComponents; k++) {
      means[k]->GetRealV(mpar, mflags);
      vars[k]->GetRealV(vpar, vflags);

      for (size_t i = 0; i < net->Time(); i++) {
	/* C_p */
	c = vpar.Exp(i) * (Sqr(myval[k]->mean[i] - mpar.Mean(i)) 
			   + mpar.Var(i) + myval[k]->var[i])
	  - vpar.Mean(i) + log2pi;
	if (!clamped) {
	  /* C_q */
	  c -= log(myval[k]->var[i]) + log2pi + 1;
	}
	cost += 0.5 * dpar[i][k] * c;
      }
    }

    costuptodate = true;
  }

  return cost;
}


// Points the handle at expts.  Fails only when the caller requested the
// ex quantity (req.ex).
bool MoGV::GetRealV(DVH &val, DFlags req)
{
  val.vec = &expts;

  if (req.ex) {
    return false;
  }
  return true;
}

// Points the handle at the value set stored for component k; k is not
// range-checked.
void MoGV::GetMyvalV(DVH &val, int k)
{
  DVSet *component = myval[k];
  val.vec = component;
}

// True when ptr is one of the registered mean parents.
bool MoGV::IsMeanParent(const Node *ptr)
{
  const int index = WhichMeanParent(ptr);
  return index != -1;
}

// True when ptr is one of the registered variance parents.
bool MoGV::IsVarParent(const Node *ptr)
{
  const int index = WhichVarParent(ptr);
  return index != -1;
}

int MoGV::WhichParent(const Node *ptr, const vector<Node*> &parents)
{
  for (size_t i = 0; i < parents.size(); i++) {
    if (ptr == parents[i]) {
      return i;
    }
  }

  return -1;
}

// Component index whose mean parent is ptr, or -1 when there is none.
int MoGV::WhichMeanParent(const Node *ptr)
{
  const int index = WhichParent(ptr, means);
  return index;
}

// Component index whose variance parent is ptr, or -1 when there is none.
int MoGV::WhichVarParent(const Node *ptr)
{
  const int index = WhichParent(ptr, vars);
  return index;
}

// Accumulates into the scalar val the gradient contribution for ptr, which
// must be the mean or the variance parent of one component k.  Each time
// index contributes weighted by dpar[i][k].  Nothing is added when the
// node is unclamped and childless; a ptr that is neither kind of parent
// trips BBASSERT2.
void MoGV::GradReal(DSSet &val, const Node *ptr)
{
  if (children.empty() && !clamped) {
    return;
  }

  const int meanidx = WhichMeanParent(ptr);
  if (meanidx != -1) {
    // ptr is the mean parent of component meanidx.
    DVH meanpar, varpar;
    VDDH dpar;
    means[meanidx]->GetRealV(meanpar, DFlags(true));
    vars[meanidx]->GetRealV(varpar, DFlags(false, false, true));
    ParDiscreteV(0, dpar);

    for (size_t i = 0; i < net->Time(); ++i) {
      val.mean += dpar[i][meanidx] *
        (meanpar.Mean(i) - myval[meanidx]->mean[i]) * varpar.Exp(i);
      val.var += dpar[i][meanidx] * varpar.Exp(i) / 2;
    }
    return;
  }

  const int varidx = WhichVarParent(ptr);
  if (varidx != -1) {
    // ptr is the variance parent of component varidx.
    DVH meanpar;
    VDDH dpar;
    means[varidx]->GetRealV(meanpar, DFlags(true, true));
    ParDiscreteV(0, dpar);

    for (size_t i = 0; i < net->Time(); ++i) {
      val.mean -= dpar[i][varidx] * 0.5;
      val.ex += dpar[i][varidx] *
        (Sqr(myval[varidx]->mean[i] - meanpar.Mean(i)) + meanpar.Var(i) +
         myval[varidx]->var[i]) / 2;
    }
    return;
  }

  // ptr is not a mean or variance parent of any component.
  BBASSERT2(false);
}

// Per-time counterpart of GradReal: accumulates into val the gradient
// contribution for ptr, which must be the mean or the variance parent of
// one component k.  The affected vectors of val are resized to the time
// length first.  Nothing is added when the node is unclamped and
// childless; a ptr that is neither kind of parent trips BBASSERT2.
void MoGV::GradRealV(DVSet &val, const Node *ptr)
{
  if (children.empty() && !clamped) {
    return;
  }

  const int meanidx = WhichMeanParent(ptr);
  if (meanidx != -1) {
    // ptr is the mean parent of component meanidx.
    DVH meanpar, varpar;
    means[meanidx]->GetRealV(meanpar, DFlags(true));
    vars[meanidx]->GetRealV(varpar, DFlags(false, false, true));
    VDDH dpar;
    ParDiscreteV(0, dpar);

    val.mean.resize(net->Time());
    val.var.resize(net->Time());

    for (size_t i = 0; i < net->Time(); ++i) {
      val.mean[i] += dpar[i][meanidx] *
        (meanpar.Mean(i) - myval[meanidx]->mean[i]) * varpar.Exp(i);
      val.var[i] += dpar[i][meanidx] * varpar.Exp(i) / 2;
    }
    return;
  }

  const int varidx = WhichVarParent(ptr);
  if (varidx != -1) {
    // ptr is the variance parent of component varidx.
    DVH meanpar;
    means[varidx]->GetRealV(meanpar, DFlags(true, true));
    VDDH dpar;
    ParDiscreteV(0, dpar);

    val.mean.resize(net->Time());
    val.ex.resize(net->Time());

    for (size_t i = 0; i < net->Time(); ++i) {
      val.mean[i] -= dpar[i][varidx] * 0.5;
      val.ex[i] += dpar[i][varidx] *
        (Sqr(myval[varidx]->mean[i] - meanpar.Mean(i)) + meanpar.Var(i) +
         myval[varidx]->var[i]) / 2;
    }
    return;
  }

  // ptr is not a mean or variance parent of any component.
  BBASSERT2(false);
}

// Accumulates into val[i][k], for every time index i and component k, the
// component-dependent cost terms used by the discrete parent (parent 0):
// C_ps (from the component's mean/variance parents), C_qs (from the
// component's posterior variance) and C_px (from the children's gradient).
// val is resized to the time length with numComponents entries each.
// Nothing is done when the node is unclamped and childless; ptr must be
// parent 0.
//
// NOTE(review): a clamped node without children passes the guard; if
// ChildGradRealV then leaves grad empty, grad.var[i] / grad.mean[i] are
// indexed out of range -- confirm this combination cannot occur.
void MoGV::GradDiscreteV(VDD &val, const Node *ptr)
{
  //cout << "MoGV::GradDiscreteV begin" << endl;

  if (!clamped && children.empty()) {
    return;
  }

  BBASSERT2(ptr == GetParent(0));

  val.Resize(net->Time());
  val.ResizeDD(numComponents);

  DVSet grad;
  ChildGradRealV(grad);

  for (size_t k = 0; k < numComponents; k++) {
    DVH mpar, vpar;
    means[k]->GetRealV(mpar, DFlags(true, true));
    vars[k]->GetRealV(vpar, DFlags(true, false, true));
    DVSet &mv = *(myval[k]);

    for (size_t i = 0; i < net->Time(); i++) {
      // constants w.r.t. k are dropped because they only affect the scaling

      // C_ps 
      val[i][k] += 0.5 * (vpar.Exp(i) * (Sqr(mv.mean[i] - mpar.Mean(i)) 
					 + mpar.Var(i) + mv.var[i])
			  - vpar.Mean(i));
      
      // C_qs
      val[i][k] -= 0.5 * log(mv.var[i]);
      
      // C_px

      // Children's gradient expressed as inverse variance evx and
      // location x (neither depends on k).
      double evx = 2 * grad.var[i];
      double x = expts.mean[i] - grad.mean[i] / evx;
      
      val[i][k] += 0.5 * (evx * (Sqr(x - mv.mean[i]) + mv.var[i]));
    }
  }

  //cout << "MoGV::GradDiscreteV end" << endl;
}

// Reports the class name of this node type as a string.
string MoGV::GetType() const
{
  return string("MoGV");
}

// Writes this node's own state to the saver: expts, the cached cost, the
// labels of all mean and variance parents and the per-component value
// sets, each list as an enumerated container of numComponents entries.
// Variable::Save stores the base-class part last.
void MoGV::Save(NetSaver *saver)
{
  saver->SetNamedDVSet("expts", expts);
  saver->SetNamedDouble("cost", cost);

  // All three per-component lists must be complete.
  BBASSERT2(means.size() == numComponents);
  BBASSERT2(vars.size() == numComponents);
  BBASSERT2(myval.size() == numComponents);

  size_t k;

  saver->StartEnumCont(numComponents, "means");
  for (k = 0; k < numComponents; ++k) {
    saver->SetLabel(means[k]->GetLabel());
  }
  saver->FinishEnumCont("means");

  saver->StartEnumCont(numComponents, "vars");
  for (k = 0; k < numComponents; ++k) {
    saver->SetLabel(vars[k]->GetLabel());
  }
  saver->FinishEnumCont("vars");

  saver->StartEnumCont(numComponents, "myval");
  for (k = 0; k < numComponents; ++k) {
    saver->SetDVSet(*myval[k]);
  }
  saver->FinishEnumCont("myval");

  this->Variable::Save(saver);
}

// Registers one more mixture component with mean parent m and variance
// parent v.  Throws StructureException when the discrete parent does not
// allow another component.  A value set sized to the time length is
// allocated for the component and both nodes are added as parents.
// NOTE(review): the DVSet is allocated with new; where it is released is
// not visible in this function.
void MoGV::AddComponent(Node *m, Node *v)
{
  if (means.size() + 1 > NumComponents())
    throw StructureException("MoGV::AddComponent: too many components");

  BBASSERT2(means.size() == vars.size());
  BBASSERT2(NumParents() == 2 * means.size() + 1);

  means.push_back(m);
  vars.push_back(v);

  DVSet* slot = new DVSet();
  slot->mean.resize(net->Time());
  slot->var.resize(net->Time());
  myval.push_back(slot);

  BBASSERT2(myval.size() == means.size());

  AddParent(m, true);
  AddParent(v, true);
}

// Queries the discrete parent (parent 0) for its DDsize(), stores the
// result in numComponents and returns it.
size_t MoGV::NumComponents()
{
  Node* d = GetParent(0);
  BBASSERT2(d != NULL);

  VDDH dhandle;
  d->GetDiscreteV(dhandle);
  BBASSERT2(dhandle.vec != NULL);

  return numComponents = dhandle.vec->DDsize();
}

void MoGV::MyUpdate()
{
  //cout << "MoGV::MyUpdate() begin" << endl;

  if (NumChildren() == 0) {
    return;
  }

  size_t K = NumComponents();
  size_t N = net->Time();
    
  BBASSERT2(means.size() == K && vars.size() == K);

  DVH mpar, vpar;
  DFlags mflags(true);
  DFlags vflags(false, false, true);

  DVSet grad;
  ChildGradRealV(grad);
  
  BBASSERT2(grad.mean.size() > 0);

  for (size_t i = 0; i < N; i++) {
    for (size_t k = 0; k < K; k++) {
      means[k]->GetRealV(mpar, mflags);
      vars[k]->GetRealV(vpar, vflags);

      myval[k]->var[i] = 1 / (2 * grad.var[i] + vpar.Exp(i));
      myval[k]->mean[i] = myval[k]->var[i] * 
	(2 * grad.var[i] * expts.mean[i] - grad.mean[i]
	 + vpar.Exp(i) * mpar.Mean(i));
    }
  }

  ComputeExpectations();

  costuptodate = false;

  //cout << "MoGV::MyUpdate() end" << endl;
}

// Recomputes expts for every time index as the mixture of the component
// posteriors weighted by dpar[i][k]: the mean is the weighted sum of the
// component means, the variance the weighted second moment minus the
// squared mixture mean.
void MoGV::ComputeExpectations()
{
  VDDH dpar;
  ParDiscreteV(0, dpar);

  const size_t timelen = net->Time();
  const size_t ncomp = NumComponents();

  for (size_t t = 0; t < timelen; ++t) {
    double first = 0.0;   // weighted sum of means
    double second = 0.0;  // weighted sum of second moments

    for (size_t c = 0; c < ncomp; ++c) {
      first += dpar[t][c] * myval[c]->mean[t];
      second += dpar[t][c] * (myval[c]->var[t] + Sqr(myval[c]->mean[t]));
    }
    second -= Sqr(first);

    expts.mean[t] = first;
    expts.var[t] = second;
  }
}

// Clamps the node to the values in m: they are copied into expts.mean and
// expts.var is set to zero everywhere.  Throws TypeException when m does
// not have the same length as expts.mean.  Returns true on success.
bool MoGV::MyClamp(const DV &m)
{
  if (m.size() != expts.mean.size()) {
    ostringstream msg;
    msg << "MoGV::MyClamp: wrong vector size " << m.size() << " != "
        << expts.mean.size();
    throw TypeException(msg.str());
  }

  copy(m.begin(), m.end(), expts.mean.begin());
  fill(expts.var.begin(), expts.var.end(), 0.0);
  return true;
}


/* MoG */

// Builds a scalar MoG node whose only initial parent is the discrete node
// d.  The component count is read from d; mean and variance parents are
// attached later through AddComponent().
MoG::MoG(Net *net, Label label, Node *d)
  : Variable(net, label, d), NParNode(d)
{
  cost = 0;

  CheckParent(0, DISCRETE);

  numComponents = NumComponents();
}

double MoG::Cost()
{
  if (!clamped && children.empty())
    return 0;

  if (!costuptodate) {
    DSSet mpar, vpar;
    DFlags mflags(true, true);
    DFlags vflags(true, false, true);
    DD *dpar;
    ParDiscrete(0, dpar);

    double c = 0;
    cost = 0;
    double log2pi = log(2*PI);
    
    for (size_t k = 0; k < numComponents; k++) {
      means[k]->GetReal(mpar, mflags);
      vars[k]->GetReal(vpar, vflags);
      /* C_p */
      c = vpar.ex * (Sqr(myval[k]->mean - mpar.mean) 
		     + mpar.var + myval[k]->var)
	- vpar.mean + log2pi;
      if (!clamped) {
	/* C_q */
	c -= log(myval[k]->var) + log2pi + 1;
      }
      cost += 0.5 * dpar->Get(k) * c;
    }
  }

  costuptodate = true;
  return cost;
}

// Copies expts into val.  Fails only when the caller requested the ex
// quantity (req.ex).
bool MoG::GetReal(DSSet &val, DFlags req)
{
  val = expts;

  if (req.ex) {
    return false;
  }
  return true;
}

// True when ptr is one of the registered mean parents.
bool MoG::IsMeanParent(const Node *ptr)
{
  const int index = WhichMeanParent(ptr);
  return index != -1;
}

// True when ptr is one of the registered variance parents.
bool MoG::IsVarParent(const Node *ptr)
{
  const int index = WhichVarParent(ptr);
  return index != -1;
}

int MoG::WhichParent(const Node *ptr, const vector<Node*> &parents)
{
  for (size_t i = 0; i < parents.size(); i++) {
    if (ptr == parents[i]) {
      return i;
    }
  }

  return -1;
}

// Component index whose mean parent is ptr, or -1 when there is none.
int MoG::WhichMeanParent(const Node *ptr)
{
  const int index = WhichParent(ptr, means);
  return index;
}

// Component index whose variance parent is ptr, or -1 when there is none.
int MoG::WhichVarParent(const Node *ptr)
{
  const int index = WhichParent(ptr, vars);
  return index;
}

// Accumulates into val the gradient contribution for ptr, which must be
// the mean or the variance parent of one component k; the contribution is
// weighted by dpar->Get(k).  Nothing is added when the node is unclamped
// and childless; a ptr that is neither kind of parent trips BBASSERT2.
void MoG::GradReal(DSSet &val, const Node *ptr)
{
  if (children.empty() && !clamped) {
    return;
  }

  const int meanidx = WhichMeanParent(ptr);
  if (meanidx != -1) {
    // ptr is the mean parent of component meanidx.
    DSSet meanpar, varpar;
    DD *dpar;
    means[meanidx]->GetReal(meanpar, DFlags(true));
    vars[meanidx]->GetReal(varpar, DFlags(false, false, true));
    ParDiscrete(0, dpar);

    val.mean += dpar->Get(meanidx) *  (meanpar.mean - myval[meanidx]->mean) * varpar.ex;
    val.var += dpar->Get(meanidx) * varpar.ex / 2;
    return;
  }

  const int varidx = WhichVarParent(ptr);
  if (varidx != -1) {
    // ptr is the variance parent of component varidx.
    DSSet meanpar;
    DD *dpar;
    means[varidx]->GetReal(meanpar, DFlags(true, true));
    ParDiscrete(0, dpar);

    val.mean -= dpar->Get(varidx) * 0.5;
    val.ex += dpar->Get(varidx) *
      (Sqr(myval[varidx]->mean - meanpar.mean) + meanpar.var + myval[varidx]->var) / 2;
    return;
  }

  // ptr is not a mean or variance parent of any component.
  BBASSERT2(false);
}

// Stores in val, for every component k, the sum of the component-dependent
// cost terms used by the discrete parent (parent 0): C_ps (from the
// component's mean/variance parents), C_qs (from the component's posterior
// variance) and C_px (from the children's gradient).  val is resized to
// numComponents entries.  Nothing is done when the node is unclamped and
// childless; ptr must be parent 0.
void MoG::GradDiscrete(DD &val, const Node *ptr)
{
  if (children.empty() && !clamped) {
    return;
  }

  val.Resize(numComponents);

  BBASSERT2(ptr == GetParent(0));

  DSSet grad;
  ChildGradReal(grad);

  for (size_t k = 0; k < numComponents; ++k) {
    DSSet meanpar, varpar;
    means[k]->GetReal(meanpar, DFlags(true, true));
    vars[k]->GetReal(varpar, DFlags(true, false, true));
    DSSet &comp = *(myval[k]);

    // C_ps
    double total = 0.5 * (varpar.ex * (Sqr(comp.mean - meanpar.mean)
                                       + meanpar.var + comp.var)
                          - varpar.mean);

    // C_qs
    total -= 0.5 * log(comp.var);

    // C_px: children's gradient expressed as inverse variance and location.
    const double invvar = 2 * grad.var;
    const double loc = expts.mean - grad.mean / invvar;

    total += 0.5 * (invvar * (Sqr(loc - comp.mean) + comp.var));

    val.Set(k, total);
  }
}

// Reports the class name of this node type as a string.
string MoG::GetType() const
{
  return string("MoG");
}

// Writes the node to saver: the expectations, the cost, the labels of
// the mean and variance parents, the per-component values, and finally
// the Variable base-class data.
void MoG::Save(NetSaver *saver)
{
  saver->SetNamedDSSet("expts", expts);
  saver->SetNamedDouble("cost", cost);

  // All three per-component containers must hold numComponents items.
  BBASSERT2(means.size() == numComponents);
  BBASSERT2(vars.size() == numComponents);
  BBASSERT2(myval.size() == numComponents);

  size_t idx;

  saver->StartEnumCont(numComponents, "means");
  for (idx = 0; idx < numComponents; ++idx)
    saver->SetLabel(means[idx]->GetLabel());
  saver->FinishEnumCont("means");

  saver->StartEnumCont(numComponents, "vars");
  for (idx = 0; idx < numComponents; ++idx)
    saver->SetLabel(vars[idx]->GetLabel());
  saver->FinishEnumCont("vars");

  saver->StartEnumCont(numComponents, "myval");
  for (idx = 0; idx < numComponents; ++idx)
    saver->SetDSSet(*myval[idx]);
  saver->FinishEnumCont("myval");

  Variable::Save(saver);
}

// Appends one mixture component with mean parent m and variance parent
// v.  Throws StructureException when the component count would exceed
// the value returned by NumComponents().
void MoG::AddComponent(Node *m, Node *v)
{
  const size_t wanted = means.size() + 1;
  if (wanted > NumComponents())
    throw StructureException("MoG::AddComponent: too many components");

  BBASSERT2(means.size() == vars.size());
  // Parent 0 plus one mean and one variance parent per component so far.
  BBASSERT2(NumParents() == 2 * means.size() + 1);

  means.push_back(m);
  vars.push_back(v);
  myval.push_back(new DSSet());

  BBASSERT2(myval.size() == means.size());

  AddParent(m, true);
  AddParent(v, true);
}

size_t MoG::NumComponents()
{
  DD *dpar;
  Node* d = GetParent(0);
  BBASSERT2(d != NULL);
  d->GetDiscrete(dpar);
  numComponents = dpar->size();
  return numComponents;
}

void MoG::MyUpdate()
{
  //cout << "MoG::MyUpdate() begin" << endl;

  if (NumChildren() == 0) {
    return;
  }

  size_t K = NumComponents();
    
  BBASSERT2(means.size() == K && vars.size() == K);

  DSSet mpar, vpar;
  DFlags mflags(true);
  DFlags vflags(false, false, true);

  DSSet grad;
  ChildGradReal(grad);
  
  for (size_t k = 0; k < K; k++) {
    means[k]->GetReal(mpar, mflags);
    vars[k]->GetReal(vpar, vflags);

    myval[k]->var = 1 / (2 * grad.var + vpar.ex);
    myval[k]->mean = myval[k]->var * 
      (2 * grad.var * expts.mean - grad.mean + vpar.ex * mpar.mean);
    }

  ComputeExpectations();

  costuptodate = false;

  //cout << "MoG::MyUpdate() end" << endl;
}

void MoG::ComputeExpectations()
{
  DD *dpar;
  ParDiscrete(0, dpar);
  
  size_t K = NumComponents();

  double mean = 0.0;
  double var = 0.0;
  for (size_t k = 0; k < K; k++) {
    mean += dpar->Get(k) * myval[k]->mean;
    var += dpar->Get(k) * (myval[k]->var + Sqr(myval[k]->mean));
  }
  var -= Sqr(mean);

  expts.mean = mean;
  expts.var = var;
}


/* class DiscreteDirichletV */

// Builds the node with Dirichlet parent n.  The value holds net->Time()
// distributions, each with as many entries as the parent's mean vector,
// and every entry starts at 1/numComp.
DiscreteDirichletV::DiscreteDirichletV(Net *net, Label label, Dirichlet *n)
  : Variable(net, label, n), NParNode(n)
{
  DVH parentVal;
  n->GetRealV(parentVal, DFlags(true));
  BBASSERT2(parentVal.vec != 0);

  const size_t numComp = parentVal.vec->mean.size();
  BBASSERT2(numComp > 0);

  myval.Resize(net->Time());
  myval.ResizeDD(numComp);

  const double uniform = 1.0 / numComp;

  for (size_t t = 0; t < myval.size(); ++t)
    for (size_t comp = 0; comp < myval.DDsize(); ++comp)
      myval[t][comp] = uniform;

  cost = 0.0;
}

// Replaces the stored value with v.  Throws StructureException when
// either dimension of v differs from that of the stored value.
bool DiscreteDirichletV::MyClamp(const VDD &v)
{
  const bool sameShape =
    v.size() == myval.size() && v.DDsize() == myval.DDsize();

  if (!sameShape) {
    ostringstream msg;
    msg << "DiscreteDirichletV::MyClamp: wrong VDD size "
        << "(" << v.size() << ", " << v.DDsize() << ")"
        << " != "
        << "(" << myval.size() << ", " << myval.DDsize() << ")";
    throw StructureException(msg.str());
  }

  myval = v;
  return true;
}

double DiscreteDirichletV::Cost()
{
  if (!costuptodate) {
    double c = 0;

    DVH p;
    ParRealV(0, p, DFlags(false, false, true));
    BBASSERT2(p.vec != 0);
    BBASSERT2(p.vec->ex.size() == myval.DDsize());

    for (size_t i = 0; i < myval.size(); i++) {
      for (size_t k = 0; k < myval.DDsize(); k++) {
	double x = myval[i][k];
	if (x > EPSILON) {
	  c += x * (log(x) - p.Exp(k));
	}
      }
    }

    cost = c;
    costuptodate = true;
  }

  return cost;
}

// Points val.vec at this node's stored value (no copy) and returns true.
bool DiscreteDirichletV::GetDiscreteV(VDDH &val)
{
  val.vec = &myval;
  return true;
}

// Adds, for every component, the sum of that component's entries over
// all net->Time() distributions to val.mean.  ptr is not used.
void DiscreteDirichletV::GradRealV(DVSet &val, const Node *ptr)
{
  const size_t numComp = myval.DDsize();
  val.mean.resize(numComp);

  for (size_t comp = 0; comp < numComp; ++comp)
    for (size_t t = 0; t < net->Time(); ++t)
      val.mean[comp] += myval[t][comp];
}

// Writes the stored value and the cached cost to saver, then lets the
// Variable base class save its own data.
void DiscreteDirichletV::Save(NetSaver *saver)
{
  saver->SetNamedVDD("myval", myval);
  saver->SetNamedDouble("cost", cost);
  Variable::Save(saver);
}

// Returns the type name of this node class, "DiscreteDirichletV".
string DiscreteDirichletV::GetType() const
{
  return "DiscreteDirichletV";
}

void DiscreteDirichletV::MyUpdate()
{
  VDD grad;
  ChildGradDiscreteV(grad);

  BBASSERT2(grad.size() == myval.size());
  BBASSERT2(grad.DDsize() == myval.DDsize());

  DVH p;
  ParRealV(0, p, DFlags(false, false, true));
  BBASSERT2(p.vec != 0);
  BBASSERT2(p.vec->ex.size() == myval.DDsize());
  
  for (size_t i = 0; i < myval.size(); i++) {
    double mingrad = grad[i].Minimum();
    for (size_t k = 0; k < myval.DDsize(); k++) {
      myval[i][k] = exp(mingrad - grad[i][k] + p.Exp(k));
    }

    DV *w = myval.GetDDp(i)->GetDV();
    ScaleWeights(w);
  }

  costuptodate = false;
}

/* class DiscreteDirichlet */

// Builds the node with Dirichlet parent n.  The value has as many
// entries as the parent's mean vector, each starting at 1/numComp.
DiscreteDirichlet::DiscreteDirichlet(Net *net, Label label, Dirichlet *n)
  : Variable(net, label, n), NParNode(n)
{
  DVH parentVal;
  n->GetRealV(parentVal, DFlags(true));
  BBASSERT2(parentVal.vec != 0);

  const size_t numComp = parentVal.vec->mean.size();
  BBASSERT2(numComp > 0);

  myval.Resize(numComp);

  const double uniform = 1.0 / numComp;
  for (size_t comp = 0; comp < myval.size(); ++comp)
    myval.Set(comp, uniform);

  cost = 0.0;
}

// Replaces the stored value with v.  Throws StructureException when v
// does not have the same size as the stored value.
bool DiscreteDirichlet::MyClamp(const DD &v)
{
  if (v.size() != myval.size()) {
    ostringstream msg;
    msg << "DiscreteDirichlet::MyClamp: wrong DD size "
        << v.size() << " != " << myval.size();
    throw StructureException(msg.str());
  }

  myval = v;
  return true;
}

double DiscreteDirichlet::Cost()
{
  if (!costuptodate) {
    double c = 0;

    DVH p;
    ParRealV(0, p, DFlags(false, false, true));
    BBASSERT2(p.vec != 0);
    BBASSERT2(p.vec->ex.size() == myval.size());

    for (size_t k = 0; k < myval.size(); k++) {
      double x = myval[k];
      if (x > EPSILON) {
	c += x * (log(x) - p.Exp(k));
      }
    }

    cost = c;
    costuptodate = true;
  }

  return cost;
}

// Sets val to point at this node's stored value (no copy) and returns
// true.
bool DiscreteDirichlet::GetDiscrete(DD *&val)
{
  val = &myval;
  return true;
}

// Adds the stored value, entry by entry, to val.mean.  ptr is not used.
void DiscreteDirichlet::GradRealV(DVSet &val, const Node *ptr)
{
  const size_t numComp = myval.size();
  val.mean.resize(numComp);

  for (size_t comp = 0; comp < numComp; ++comp)
    val.mean[comp] += myval[comp];
}

// Writes the stored value and the cached cost to saver, then lets the
// Variable base class save its own data.
void DiscreteDirichlet::Save(NetSaver *saver)
{
  saver->SetNamedDD("myval", myval);
  saver->SetNamedDouble("cost", cost);
  Variable::Save(saver);
}

// Returns the type name of this node class, "DiscreteDirichlet".
string DiscreteDirichlet::GetType() const
{
  return "DiscreteDirichlet";
}

void DiscreteDirichlet::MyUpdate()
{
  DD grad;
  ChildGradDiscrete(grad);

  BBASSERT2(grad.size() == myval.size());

  DVH p;
  ParRealV(0, p, DFlags(false, false, true));
  BBASSERT2(p.vec != 0);
  BBASSERT2(p.vec->ex.size() == myval.size());
  
  double mingrad = grad.Minimum();
  for (size_t k = 0; k < myval.size(); k++) {
    myval.Set(k, exp(mingrad - grad[k] + p.Exp(k)));
  }

  DV *w = myval.GetDV();
  ScaleWeights(w);

  costuptodate = false;
}



/* class Dirichlet */

// Builds the node with constant-vector parent n.  The stored value is
// initialised to a copy of the parent's mean vector and the expectation
// vectors are sized to match and computed from it.
Dirichlet::Dirichlet(Net *net, Label label, ConstantV *n)
  : Variable(net, label, n), NParNode(n)
{
  DVH parentVal;
  n->GetRealV(parentVal, DFlags(true));
  BBASSERT2(parentVal.vec != 0);

  numComponents = parentVal.vec->mean.size();
  BBASSERT2(numComponents > 0);

  myval.resize(numComponents);
  expts.mean.resize(numComponents);
  expts.var.resize(numComponents);
  expts.ex.resize(numComponents);

  for (size_t comp = 0; comp < numComponents; ++comp)
    myval[comp] = parentVal.vec->mean[comp];

  ComputeExpectations();
  cost = 0.0;
}

double Dirichlet::Cost()
{
  if (!costuptodate) {
    DVH p;
    ParRealV(0, p, DFlags(true));
  
    BBASSERT2(p.vec != 0);
    BBASSERT2(p.vec->mean.size() == numComponents);

    double logZu = 0.0;
    double logZv = 0.0;
    double u0 = 0.0;
    double v0 = 0.0;

    for (size_t i = 0; i < numComponents; i++) {
      logZu += LogGamma(p.Mean(i));
      logZv += LogGamma(myval[i]);
      u0 += p.Mean(i);
      v0 += myval[i];
    }

    logZu -= LogGamma(u0);
    logZv -= LogGamma(v0);

    cost = logZu - logZv;

    double a = DiGamma(v0);
    
    for (size_t i = 0; i < numComponents; i++) {
      cost += (DiGamma(myval[i]) - a) * (myval[i] - p.Mean(i));
    }

    costuptodate = true;
  }
    
  return cost;
}

// Points val.vec at the stored expectations (no copy) and returns true.
// The request flags in req are not inspected.
bool Dirichlet::GetRealV(DVH &val, DFlags req)
{
  val.vec = &expts;
  return true;
}
  
// Returns the type name of this node class, "Dirichlet".
string Dirichlet::GetType() const
{
  return "Dirichlet";
}

// Writes the stored value, the expectations and the cached cost to
// saver, then lets the Variable base class save its own data.
void Dirichlet::Save(NetSaver *saver)
{
  saver->SetNamedDV("myval", myval);
  saver->SetNamedDVSet("expts", expts);
  saver->SetNamedDouble("cost", cost);
  Variable::Save(saver);
}

// Sets every entry of the stored value to the parent's mean entry plus
// the matching entry of the child gradient, then refreshes the
// expectations and invalidates the cached cost.
//
// Gradient entries that are absent (grad.mean shorter than
// numComponents, e.g. still empty because no child sized it) count as
// zero, so the stored value then equals the parent's mean instead of
// being computed from an out-of-range read.
void Dirichlet::MyUpdate()
{
  DVSet grad;
  ChildGradRealV(grad);

  DVH p;
  ParRealV(0, p, DFlags(true));

  BBASSERT2(p.vec != 0);
  BBASSERT2(p.vec->mean.size() == numComponents);

  const size_t gradLen = grad.mean.size();

  for (size_t i = 0; i < numComponents; i++) {
    const double g = (i < gradLen) ? grad.mean[i] : 0.0;
    myval[i] = p.Mean(i) + g;
  }

  ComputeExpectations();

  costuptodate = false;
}

// Fills expts.mean, expts.var and expts.ex for every component from the
// stored value and the sum of its entries.
void Dirichlet::ComputeExpectations()
{
  size_t comp;

  double total = 0.0;
  for (comp = 0; comp < numComponents; ++comp)
    total += myval[comp];

  const double varDenom = (Sqr(total) * (total + 1));
  const double digammaTotal = DiGamma(total);

  for (comp = 0; comp < numComponents; ++comp) {
    const double own = myval[comp];
    expts.mean[comp] = own / total;
    expts.var[comp] = (own * (total - own)) / varDenom;
    expts.ex[comp] = DiGamma(own) - digammaTotal;
  }
}


/* class RectificationV */

// Builds the node with single parent n; all setup is done by the base
// class constructors.
RectificationV::RectificationV(Net *net, Label label, Node *n)
  : Function(net, label, n), UniParNode(n)
{
}

// Forwards the request to GetRectRealV of parent 0.  Throws
// StructureException when parent 0 is not a GaussRectV.
bool RectificationV::GetRealV(DVH &val, DFlags req)
{
  GaussRectV *rect = dynamic_cast<GaussRectV *>(GetParent(0));
  if (!rect)
    throw StructureException("Invalid parent for RectificationV");

  return rect->GetRectRealV(val, req);
}

// Passes the children's gradient through unchanged; ptr is not used.
void RectificationV::GradRealV(DVSet &val, const Node *ptr)
{
  ChildGradRealV(val);
}

// Saves only the Function base-class data; the node adds nothing of
// its own.
void RectificationV::Save(NetSaver *saver)
{
  Function::Save(saver);
}

// Returns the type name of this node class, "RectificationV".
string RectificationV::GetType() const
{
  return "RectificationV";
}

/* class Rectification */

// Builds the node with single parent n; all setup is done by the base
// class constructors.
Rectification::Rectification(Net *net, Label label, Node *n)
  : Function(net, label, n), UniParNode(n)
{
}

// Forwards the request to GetRectReal of parent 0.  Throws
// StructureException when parent 0 is not a GaussRect.
bool Rectification::GetReal(DSSet &val, DFlags req)
{
  GaussRect *rect = dynamic_cast<GaussRect *>(GetParent(0));
  if (!rect)
    throw StructureException("Invalid parent for Rectification");

  return rect->GetRectReal(val, req);
}

// Passes the children's gradient through unchanged; ptr is not used.
void Rectification::GradReal(DSSet &val, const Node *ptr)
{
  ChildGradReal(val);
}

// Saves only the Function base-class data; the node adds nothing of
// its own.
void Rectification::Save(NetSaver *saver)
{
  Function::Save(saver);
}

// Returns the type name of this node class, "Rectification".
string Rectification::GetType() const
{
  return "Rectification";
}

// class ProdV : public Function : public Node

// Provides the mean and/or variance of the element-wise product of the
// two parents, recomputing only the requested parts that are stale.
// Returns false when a parent cannot supply its values; otherwise
// val.vec points at the stored result and the return value tells
// whether every requested flag could be served (mean and var can).
bool ProdV::GetRealV(DVH &val, DFlags req)
{
  const bool wantMean = req.mean && !uptodate.mean;
  const bool wantVar = req.var && !uptodate.var;

  if (wantMean || wantVar) {
    DVH a, b;
    // The parent means are always requested, the variances only when
    // the variance has to be recomputed.
    if (!ParRealV(0, a, DFlags(true, wantVar)))
      return false;
    if (!ParRealV(1, b, DFlags(true, wantVar)))
      return false;

    const size_t len = net->Time();

    if (wantMean) {
      myval.mean.resize(len);
      for (size_t t = 0; t < len; ++t)
        myval.mean[t] = a.Mean(t) * b.Mean(t);
      uptodate.mean = true;
    }

    if (wantVar) {
      myval.var.resize(len);
      for (size_t t = 0; t < len; ++t)
        myval.var[t] = (Sqr(a.Mean(t)) + a.Var(t)) * b.Var(t) +
          a.Var(t) * Sqr(b.Mean(t));
      uptodate.var = true;
    }
  }

  // Whatever remains set in req is something this node cannot provide.
  req.mean = false;
  req.var = false;
  val.vec = &myval;
  return req.AllFalse();
}

// Adds to val the gradient for the vector parent ptr, built from the
// children's gradient and the values of both parents.  Nothing is added
// unless the child gradient has both a mean and a variance part.
void ProdV::GradRealV(DVSet &val, const Node *ptr)
{
  const int self = ParIdentity(ptr);

  DVSet childGrad;
  ChildGradRealV(childGrad);

  if (childGrad.mean.size() == 0 || childGrad.var.size() == 0)
    return;

  DVH own, other;
  ParRealV(self, own, DFlags(true));
  ParRealV(1 - self, other, DFlags(true, true));

  const size_t len = net->Time();
  val.mean.resize(len);
  val.var.resize(len);

  for (size_t t = 0; t < len; ++t) {
    val.mean[t] += childGrad.mean[t] * other.Mean(t) +
      2 * childGrad.var[t] * own.Mean(t) * other.Var(t);
    val.var[t] += childGrad.var[t] * (Sqr(other.Mean(t)) + other.Var(t));
  }
}

// Adds to val the gradient for the scalar parent ptr, summed over all
// net->Time() elements.  Nothing is added unless the child gradient has
// both a mean and a variance part.
void ProdV::GradReal(DSSet &val, const Node *ptr)
{
  const int self = ParIdentity(ptr);

  DVSet childGrad;
  ChildGradRealV(childGrad);

  if (childGrad.mean.size() == 0 || childGrad.var.size() == 0)
    return;

  DSSet own;
  DVH other;
  ParReal(self, own, DFlags(true));
  ParRealV(1 - self, other, DFlags(true, true));

  for (size_t t = 0; t < net->Time(); ++t) {
    val.mean += childGrad.mean[t] * other.Mean(t) +
      2 * childGrad.var[t] * own.mean * other.Var(t);
    val.var += childGrad.var[t] * (Sqr(other.Mean(t)) + other.Var(t));
  }
}

// Writes the stored value only when the saver asks for function values,
// then lets the Function base class save its own data.
void ProdV::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDVSet("myval", myval);
  }
  Function::Save(saver);
}


// class Sum2V : public Function : public Node

// Provides mean, variance and/or ex of the element-wise sum of the two
// parents, recomputing only the requested parts that are stale.  Means
// and variances are added; the ex values are multiplied.  Returns false
// when a parent cannot supply its values.
bool Sum2V::GetRealV(DVH &val, DFlags req)
{
  const bool wantMean = req.mean && !uptodate.mean;
  const bool wantVar = req.var && !uptodate.var;
  const bool wantEx = req.ex && !uptodate.ex;

  if (wantMean || wantVar || wantEx) {
    DVH a, b;
    if (!ParRealV(0, a, DFlags(wantMean, wantVar, wantEx)))
      return false;
    if (!ParRealV(1, b, DFlags(wantMean, wantVar, wantEx)))
      return false;

    const size_t len = net->Time();
    size_t t;

    if (wantMean) {
      myval.mean.resize(len);
      for (t = 0; t < len; ++t)
        myval.mean[t] = a.Mean(t) + b.Mean(t);
      uptodate.mean = true;
    }

    if (wantVar) {
      myval.var.resize(len);
      for (t = 0; t < len; ++t)
        myval.var[t] = a.Var(t) + b.Var(t);
      uptodate.var = true;
    }

    if (wantEx) {
      myval.ex.resize(len);
      for (t = 0; t < len; ++t)
        myval.ex[t] = a.Exp(t) * b.Exp(t);
      uptodate.ex = true;
    }
  }

  req.mean = false;
  req.var = false;
  req.ex = false;
  val.vec = &myval;
  return req.AllFalse();
}

// Adds to val the gradient for the vector parent ptr.  While no ex
// values are stored the children's gradient is passed straight through;
// otherwise the mean and variance parts are copied over and the ex part
// is scaled by the other parent's ex values.
void Sum2V::GradRealV(DVSet &val, const Node *ptr)
{
  const int self = ParIdentity(ptr);

  if (myval.ex.empty()) {
    ChildGradRealV(val);
    return;
  }

  DVSet childGrad;
  ChildGradRealV(childGrad);

  const size_t len = net->Time();
  size_t t;

  if (childGrad.mean.size()) {
    val.mean.resize(len);
    for (t = 0; t < len; ++t)
      val.mean[t] += childGrad.mean[t];
  }

  if (childGrad.var.size()) {
    val.var.resize(len);
    for (t = 0; t < len; ++t)
      val.var[t] += childGrad.var[t];
  }

  if (childGrad.ex.size()) {
    DVH other;
    ParRealV(1 - self, other, DFlags(false, false, true));
    val.ex.resize(len);
    for (t = 0; t < len; ++t)
      val.ex[t] += childGrad.ex[t] * other.Exp(t);
  }
}

// Adds to val the gradient for the scalar parent ptr.  While no ex
// values are stored the children's scalar gradient is passed straight
// through; otherwise each part of the children's vector gradient is
// summed over all net->Time() elements, the ex part scaled by the other
// parent's ex values.
void Sum2V::GradReal(DSSet &val, const Node *ptr)
{
  const int self = ParIdentity(ptr);

  if (myval.ex.empty()) {
    ChildGradReal(val);
    return;
  }

  DVSet childGrad;
  ChildGradRealV(childGrad);

  const size_t len = net->Time();
  size_t t;

  if (childGrad.mean.size()) {
    for (t = 0; t < len; ++t)
      val.mean += childGrad.mean[t];
  }

  if (childGrad.var.size()) {
    for (t = 0; t < len; ++t)
      val.var += childGrad.var[t];
  }

  if (childGrad.ex.size()) {
    DVH other;
    ParRealV(1 - self, other, DFlags(false, false, true));
    for (t = 0; t < len; ++t)
      val.ex += childGrad.ex[t] * other.Exp(t);
  }
}

// Writes the stored value only when the saver asks for function values,
// then lets the Function base class save its own data.
void Sum2V::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDVSet("myval", myval);
  }
  Function::Save(saver);
}

// class SumNV : public Function : public Node

// Provides the mean and/or variance of the element-wise sum over all
// parents, recomputing only the requested parts that are stale.  With
// keepupdated set, every parent's contribution is also cached in
// parentval.  If any parent cannot supply its values all up-to-date
// flags are cleared and false is returned.
bool SumNV::GetRealV(DVH &val, DFlags req)
{
  const bool wantMean = req.mean && !uptodate.mean;
  const bool wantVar = req.var && !uptodate.var;

  if (wantMean || wantVar) {
    size_t t;

    // Reset the accumulators that are about to be rebuilt.
    if (wantMean) {
      myval.mean.resize(net->Time());
      for (t = 0; t < net->Time(); ++t)
        myval.mean[t] = 0.0;
      uptodate.mean = true;
    }
    if (wantVar) {
      myval.var.resize(net->Time());
      for (t = 0; t < net->Time(); ++t)
        myval.var[t] = 0.0;
      uptodate.var = true;
    }

    for (size_t par = 0; par < NumParents(); ++par) {
      DVH p;
      if (!ParRealV(par, p, DFlags(wantMean, wantVar, false))) {
        uptodate = DFlags(false, false, false);
        return false;
      }

      if (wantMean) {
        for (t = 0; t < net->Time(); ++t)
          myval.mean[t] += p.Mean(t);
      }
      if (wantVar) {
        for (t = 0; t < net->Time(); ++t)
          myval.var[t] += p.Var(t);
      }

      if (!keepupdated)
        continue;

      // Remember this parent's contribution.
      BBASSERT2(parentval.size() == NumParents());
      if (wantMean) {
        parentval[par].mean.resize(net->Time());
        for (t = 0; t < net->Time(); ++t)
          parentval[par].mean[t] = p.Mean(t);
      }
      if (wantVar) {
        parentval[par].var.resize(net->Time());
        for (t = 0; t < net->Time(); ++t)
          parentval[par].var[t] = p.Var(t);
      }
    }
  }

  req.mean = false;
  req.var = false;
  val.vec = &myval;
  return req.AllFalse();
}

// Passes the children's vector gradient through unchanged; ptr is not
// used.
void SumNV::GradRealV(DVSet &val, const Node *ptr)
{
  ChildGradRealV(val);
}

// Passes the children's scalar gradient through unchanged; ptr is not
// used.
void SumNV::GradReal(DSSet &val, const Node *ptr)
{
  ChildGradReal(val);
}

// Registers n as a further parent, clears the up-to-date flags and
// outdates the children.  With keepupdated set, an empty cache slot is
// created for the new parent.  Always returns true; the parent itself
// is only checked later, in GetRealV.
bool SumNV::AddParent(Node *n)
{
  Node::AddParent(n, true);

  const int slot = ParIdentity(n);
  if (keepupdated) {
    parentval.resize(slot + 1, DVSet());
    parentval[slot] = DVSet();
  }

  uptodate = DFlags(false, false, false);
  OutdateChild();
  return true;
}

// Stores the keepupdated setting.  Turning it on rebuilds the mean and
// variance sums and the per-parent cache from scratch.
void SumNV::SetKeepUpdated(const bool _keepupdated)
{
  keepupdated = _keepupdated;
  if (!keepupdated)
    return;

  UpdateFromScratch(DFlags(true, true, false));
}

// Rebuilds the parts of the sum selected by req, and the matching
// per-parent cache in parentval, from the parents' current values.
//
// The up-to-date flags are cleared first and set to req only when every
// parent delivered its values; if a parent fails the function returns
// early with all flags still cleared.
//
// Only the vectors selected by req are resized and zeroed.  In
// particular a variance-only rebuild does not write to myval.mean,
// which may never have been sized.
void SumNV::UpdateFromScratch(DFlags req)
{
  uptodate = DFlags(false, false, false);
  parentval.resize(NumParents());

  if (req.mean) {
    myval.mean.resize(net->Time());
    for (size_t j = 0; j < net->Time(); j++) {
      myval.mean[j] = 0.0;
    }
  }
  if (req.var) {
    myval.var.resize(net->Time());
    for (size_t j = 0; j < net->Time(); j++) {
      myval.var[j] = 0.0;
    }
  }

  for (size_t i = 0; i < NumParents(); i++) {
    DVH p;
    if (!ParRealV(i, p, req))
      return;
    if (req.mean) {
      parentval[i].mean.resize(net->Time());
      for (size_t j = 0; j < net->Time(); j++) {
        parentval[i].mean[j] = p.Mean(j);
        myval.mean[j] += p.Mean(j);
      }
    }
    if (req.var) {
      parentval[i].var.resize(net->Time());
      for (size_t j = 0; j < net->Time(); j++) {
        parentval[i].var[j] = p.Var(j);
        myval.var[j] += p.Var(j);
      }
    }
  }

  uptodate = req;
}

// Reacts to a change in parent ptr and always outdates the children.
// Without keepupdated the up-to-date flags are simply cleared.  With
// keepupdated the stored sums are corrected in place by the difference
// between the parent's new values and its cached ones; if the cache
// size no longer matches the number of parents everything is rebuilt
// from scratch instead.
void SumNV::Outdate(const Node *ptr)
{
  // Nothing cached: only the children have to be told.
  if (!uptodate.mean && !uptodate.var) {
    OutdateChild();
    return;
  }

  if (!keepupdated) {
    uptodate = DFlags(false, false, false);
    OutdateChild();
    return;
  }

  if (NumParents() != parentval.size()) {
    UpdateFromScratch(uptodate);
    OutdateChild();
    return;
  }

  const int slot = ParIdentity(ptr);
  BBASSERT2(0 <= slot && slot < (int)NumParents());

  DVH p;
  if (!ParRealV(slot, p, DFlags(uptodate.mean, uptodate.var, false))) {
    uptodate = DFlags(false, false, false);
  } else {
    if (uptodate.mean) {
      for (size_t t = 0; t < net->Time(); ++t) {
        myval.mean[t] += p.Mean(t) - parentval[slot].mean[t];
        parentval[slot].mean[t] = p.Mean(t);
      }
    }
    if (uptodate.var) {
      for (size_t t = 0; t < net->Time(); ++t) {
        myval.var[t] += p.Var(t) - parentval[slot].var[t];
        parentval[slot].var[t] = p.Var(t);
        // A negative sum marks the variance as stale; the remaining
        // elements are still corrected.
        if (myval.var[t] < 0)
          uptodate.var = false;
      }
    }
  }

  OutdateChild();
}


// Writes the stored value only when the saver asks for function values,
// then lets the Function base class save its own data.
void SumNV::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDVSet("myval", myval);
  }
  Function::Save(saver);
}


// class DelayV : public Function : public Node

// Provides the vector of parent 1 shifted by lendelay positions, with
// the vacated positions filled from the scalar parent 0.  A
// non-negative lendelay fills the first lendelay elements, a negative
// one the last -lendelay elements.  Only requested parts that are stale
// are recomputed.  Returns false when a parent cannot supply its values.
bool DelayV::GetRealV(DVH &val, DFlags req)
{
  const bool wantMean = req.mean && !uptodate.mean;
  const bool wantVar = req.var && !uptodate.var;
  const bool wantEx = req.ex && !uptodate.ex;

  if (wantMean || wantVar || wantEx) {
    DSSet pad;     // scalar parent, used for the vacated positions
    DVH series;    // vector parent, supplies the shifted elements
    if (!ParReal(0, pad, DFlags(wantMean, wantVar, wantEx)) ||
        !ParRealV(1, series, DFlags(wantMean, wantVar, wantEx)))
      return false;
    BBASSERT2((int)-net->Time() <= lendelay && lendelay <= (int)net->Time());

    const size_t len = net->Time();
    const bool forward = lendelay >= 0;
    // Index at which each branch below switches from its first loop to
    // its second one.
    const size_t cut = forward ? (size_t)lendelay : len + lendelay;
    size_t i;

    if (wantMean) {
      myval.mean.resize(len);
      if (forward) {
        for (i = 0; i < cut; i++)
          myval.mean[i] = pad.mean;
        for (i = cut; i < len; i++)
          myval.mean[i] = series.Mean(i-lendelay);
      } else {
        for (i = 0; i < cut; i++)
          myval.mean[i] = series.Mean(i-lendelay);
        for (i = cut; i < len; i++)
          myval.mean[i] = pad.mean;
      }
      uptodate.mean = true;
    }

    if (wantVar) {
      myval.var.resize(len);
      if (forward) {
        for (i = 0; i < cut; i++)
          myval.var[i] = pad.var;
        for (i = cut; i < len; i++)
          myval.var[i] = series.Var(i-lendelay);
      } else {
        for (i = 0; i < cut; i++)
          myval.var[i] = series.Var(i-lendelay);
        for (i = cut; i < len; i++)
          myval.var[i] = pad.var;
      }
      uptodate.var = true;
    }

    if (wantEx) {
      myval.ex.resize(len);
      if (forward) {
        for (i = 0; i < cut; i++)
          myval.ex[i] = pad.ex;
        for (i = cut; i < len; i++)
          myval.ex[i] = series.Exp(i-lendelay);
      } else {
        for (i = 0; i < cut; i++)
          myval.ex[i] = series.Exp(i-lendelay);
        for (i = cut; i < len; i++)
          myval.ex[i] = pad.ex;
      }
      uptodate.ex = true;
    }
  }

  req.mean = false;
  req.var = false;
  req.ex = false;
  val.vec = &myval;
  return req.AllFalse();
}

// Adds to val the children's gradient shifted back by lendelay, i.e.
// element i of the gradient goes to element i-lendelay of val.  Only
// the elements that were taken from the vector parent take part.
// ptr is not used.
void DelayV::GradRealV(DVSet &val, const Node *ptr)
{
  DVSet childGrad;
  ChildGradRealV(childGrad);

  const size_t len = net->Time();
  // Range [first, last) of output positions that came from the vector
  // parent.
  const size_t first = (lendelay >= 0) ? (size_t)lendelay : 0;
  const size_t last = (lendelay >= 0) ? len : len + lendelay;
  size_t i;

  if (childGrad.mean.size()) {
    val.mean.resize(len);
    for (i = first; i < last; i++)
      val.mean[i-lendelay] += childGrad.mean[i];
  }

  if (childGrad.var.size()) {
    val.var.resize(len);
    for (i = first; i < last; i++)
      val.var[i-lendelay] += childGrad.var[i];
  }

  if (childGrad.ex.size()) {
    val.ex.resize(len);
    for (i = first; i < last; i++)
      val.ex[i-lendelay] += childGrad.ex[i];
  }
}

// Adds to val the sum of the children's gradient over the positions
// that were filled from the scalar parent: the first lendelay elements
// for a non-negative lendelay, the last -lendelay elements otherwise.
// ptr is not used.
void DelayV::GradReal(DSSet &val, const Node *ptr)
{
  DVSet childGrad;
  ChildGradRealV(childGrad);

  const size_t len = net->Time();
  // Range [first, last) of output positions filled from the scalar
  // parent.
  const size_t first = (lendelay >= 0) ? 0 : len + lendelay;
  const size_t last = (lendelay >= 0) ? (size_t)lendelay : len;
  size_t i;

  if (childGrad.mean.size()) {
    for (i = first; i < last; i++)
      val.mean += childGrad.mean[i];
  }

  if (childGrad.var.size()) {
    for (i = first; i < last; i++)
      val.var += childGrad.var[i];
  }

  if (childGrad.ex.size()) {
    for (i = first; i < last; i++)
      val.ex += childGrad.ex[i];
  }
}

// Writes the stored value (only when the saver asks for function
// values) and the delay length, then lets the Function base class save
// its own data.
void DelayV::Save(NetSaver *saver)
{
  if (saver->GetSaveFunctionValue()) {
    saver->SetNamedDVSet("myval", myval);
  }
  saver->SetNamedInt("lendelay", lendelay);
  Function::Save(saver);
}

// Returns the current delay length.
int DelayV::GetDelayLength()
{
  return lendelay;
}

// Stores len as the new delay length.  The up-to-date flags are not
// touched here.
void DelayV::SetDelayLength(int len)
{
  lendelay = len;
}

// class GaussianV : public Variable : public Node

// Builds the node with mean parent m and variance parent v.  The stored
// mean starts at the mean parent's mean and the stored variance at the
// reciprocal of the variance parent's ex value, element by element.
GaussianV::GaussianV(Net *_net, Label label, Node *m, Node *v) :
  Variable(_net, label, m, v), BiParNode(m, v)
{
  sstate = 0;
  sstep = 0;
  cost = 0;
  exuptodate = false;
  costuptodate = false;

  myval.mean.resize(net->Time());
  myval.var.resize(net->Time());

  CheckParent(0, REALV_MV);
  CheckParent(1, REALV_ME);

  DVH meanPar, varPar;
  ParRealV(0, meanPar, DFlags(true));
  ParRealV(1, varPar, DFlags(false, false, true));

  for (size_t t = 0; t < net->Time(); ++t) {
    myval.mean[t] = meanPar.Mean(t);
    myval.var[t] = 1/varPar.Exp(t);
  }
}

// Copies the mean and variance of element t into state, which is
// resized to two entries: [0] is the mean, [1] the variance.
void GaussianV::GetState(DV *state, size_t t)
{
  BBASSERT2(t < net->Time());

  DV &out = *state;
  out.resize(2);
  out[0] = myval.mean[t];
  out[1] = myval.var[t];
}

// Overwrites the mean and variance of element t with state[0] and
// state[1], marks cost and ex values as stale and outdates the
// children.
void GaussianV::SetState(DV *state, size_t t)
{
  BBASSERT2(t < net->Time());
  BBASSERT2(state->size() == 2);

  const DV &in = *state;
  myval.mean[t] = in[0];
  myval.var[t] = in[1];

  costuptodate = false;
  exuptodate = false;

  OutdateChild();
}

double GaussianV::Cost()
{
  if (!clamped && children.empty())
    return 0;
  if (!costuptodate) {
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true, true));
    ParRealV(1, p1, DFlags(true, false, true));

    double c = 0;
    if (clamped)
      for (size_t i = 0; i < net->Time(); i++) {
	// assert(myval.var[i] == 0);
	c += (Sqr(myval.mean[i] - p0.Mean(i)) + p0.Var(i) + myval.var[i]) *
	  p1.Exp(i) - p1.Mean(i);
      }
    else
      for (size_t i = 0; i < net->Time(); i++)
	c += (Sqr(myval.mean[i] - p0.Mean(i)) + p0.Var(i) + myval.var[i]) *
	  p1.Exp(i) - p1.Mean(i) - log(myval.var[i]);
    cost = clamped ? c/2 + net->Time() * _5LOG2PI : (c - net->Time()) / 2;
    costuptodate = true;
  }
  return cost;
}

// Points val.vec at the stored value.  When ex is requested and stale
// it is recomputed first as exp(mean + var/2) for every element.
bool GaussianV::GetRealV(DVH &val, DFlags req)
{
  const bool refreshEx = req.ex && !exuptodate;
  if (refreshEx) {
    const size_t len = net->Time();
    myval.ex.resize(len);
    for (size_t t = 0; t < len; ++t)
      myval.ex[t] = exp(myval.mean[t] + myval.var[t] / 2);
    exuptodate = true;
  }

  req.mean = false;
  req.var = false;
  req.ex = false;
  val.vec = &myval;
  return req.AllFalse();
}

// Adds to val this node's element-wise gradient for the vector parent
// ptr.  Parent 1 (variance) gets a mean and an ex part, parent 0 (mean)
// a mean and a var part.  Nothing is added when the node is unclamped
// and has no children.
void GaussianV::GradRealV(DVSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  const size_t len = net->Time();

  if (ParIdentity(ptr)) {
    // Gradient for the variance parent.
    DVH meanPar;
    ParRealV(0, meanPar, DFlags(true, true));
    val.mean.resize(len);
    val.ex.resize(len);
    for (size_t t = 0; t < len; ++t) {
      val.mean[t] -= 0.5;
      val.ex[t] += (Sqr(myval.mean[t] - meanPar.Mean(t)) + meanPar.Var(t) +
                    myval.var[t]) / 2;
    }
    return;
  }

  // Gradient for the mean parent.
  DVH meanPar, varPar;
  ParRealV(0, meanPar, DFlags(true));
  ParRealV(1, varPar, DFlags(false, false, true));
  val.mean.resize(len);
  val.var.resize(len);
  for (size_t t = 0; t < len; ++t) {
    val.mean[t] += (meanPar.Mean(t) - myval.mean[t]) * varPar.Exp(t);
    val.var[t] += varPar.Exp(t) / 2;
  }
}

// Adds to val this node's gradient for the scalar parent ptr, summed
// over all net->Time() elements.  Parent 1 (variance) gets a mean and
// an ex part, parent 0 (mean) a mean and a var part.  Nothing is added
// when the node is unclamped and has no children.
void GaussianV::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  if (ParIdentity(ptr)) {
    // Gradient for the variance parent.
    DVH meanPar;
    ParRealV(0, meanPar, DFlags(true, true));
    for (size_t t = 0; t < net->Time(); ++t) {
      val.ex += (Sqr(myval.mean[t] - meanPar.Mean(t)) + meanPar.Var(t) +
                 myval.var[t]) / 2;
    }
    val.mean -= 0.5 * net->Time();
    return;
  }

  // Gradient for the mean parent.
  DVH meanPar, varPar;
  ParRealV(0, meanPar, DFlags(true));
  ParRealV(1, varPar, DFlags(false, false, true));
  for (size_t t = 0; t < net->Time(); ++t) {
    val.mean += (meanPar.Mean(t) - myval.mean[t]) * varPar.Exp(t);
    val.var += varPar.Exp(t) / 2;
  }
}

// Runs VarNewton on the elements listed in indices only, using the
// children's gradient and both parents, then marks ex values and cost
// as stale.  Does nothing when the node has no children.  Gradient
// parts the children did not provide are taken as zero.
void GaussianV::MyPartialUpdate(IntV *indices)
{
  if (NumChildren() == 0)
    return;

  DVSet childGrad;
  ChildGradRealV(childGrad);

  DVH meanPar, varPar;
  ParRealV(0, meanPar, DFlags(true));
  ParRealV(1, varPar, DFlags(false, false, true));

  const bool hasMean = childGrad.mean.size() != 0;
  const bool hasVar = childGrad.var.size() != 0;
  const bool hasEx = childGrad.ex.size() != 0;

  for (size_t pos = 0; pos < indices->size(); ++pos) {
    const size_t t = (*indices)[pos];

    const double gm = hasMean ? childGrad.mean[t] : 0;
    const double gv = hasVar ? childGrad.var[t] : 0;
    const double ge = hasEx ? childGrad.ex[t] : 0;

    VarNewton(myval.mean[t], myval.var[t],
              (myval.mean[t] - meanPar.Mean(t)) * varPar.Exp(t) + gm,
              varPar.Exp(t)/2 + gv, ge, label);
  }

  exuptodate = false;
  costuptodate = false;
}

void GaussianV::MyUpdate()
{
  if (NumChildren() == 0) {
    return;
  }

  DVSet grad;
  ChildGradRealV(grad);

  DVH p0, p1;
  ParRealV(0, p0, DFlags(true));
  ParRealV(1, p1, DFlags(false, false, true));

  double gm = 0, gv = 0, ge = 0;
  for (size_t i = 0; i < net->Time(); i++) {
    if (grad.mean.size()) gm = grad.mean[i];
    if (grad.var.size()) gv = grad.var[i];
    if (grad.ex.size()) ge = grad.ex[i];

    VarNewton(myval.mean[i], myval.var[i], (myval.mean[i] - p0.Mean(i)) *
	      p1.Exp(i) + gm, p1.Exp(i)/2 + gv, ge, label);

  }
  exuptodate = false; costuptodate = false;
}

// Writes the stored value, the cached cost and the exuptodate flag to
// saver, followed by the saved state and saved step if they exist, and
// finally lets the Variable base class save its own data.
void GaussianV::Save(NetSaver *saver)
{
  saver->SetNamedDVSet("myval", myval);
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedBool("exuptodate", exuptodate);
  if (sstate)
    saver->SetNamedDVSet("sstate", *sstate);
  if (sstep)
    saver->SetNamedDVSet("sstep", *sstep);
  Variable::Save(saver);
}

// Sets every element's mean to m and every variance to zero, and marks
// the ex values as stale.  Always returns true.
bool GaussianV::MyClamp(double m)
{
  exuptodate = false;
  fill(myval.var.begin(), myval.var.end(), 0);
  fill(myval.mean.begin(), myval.mean.end(), m);
  return true;
}

// Sets every element's mean to m and every variance to v, and marks the
// ex values as stale.  Always returns true.
bool GaussianV::MyClamp(double m, double v)
{
  exuptodate = false;
  fill(myval.var.begin(), myval.var.end(), v);
  fill(myval.mean.begin(), myval.mean.end(), m);
  return true;
}

// Copies m into the stored means, sets every variance to zero and marks
// the ex values as stale.  Throws TypeException, leaving the node
// unchanged, when m does not have as many elements as the stored mean.
bool GaussianV::MyClamp(const DV &m)
{
  if (m.size() != myval.mean.size()) {
    ostringstream msg;
    msg << "GaussianV::MyClamp: wrong vector size " << m.size() << " != "
        << myval.mean.size();
    throw TypeException(msg.str());
  }

  copy(m.begin(), m.end(), myval.mean.begin());
  fill(myval.var.begin(), myval.var.end(), 0);
  exuptodate = false;
  return true;
}

// Copies m into the stored means and v into the stored variances and
// marks the ex values as stale.  Throws TypeException when either
// vector has the wrong size; a wrong-sized v is detected only after m
// has already been copied.
bool GaussianV::MyClamp(const DV &m, const DV &v)
{
  if (m.size() != myval.mean.size()) {
    ostringstream msg;
    msg << "GaussianV::MyClamp: wrong vector size for mean " << m.size()
        << " != " << myval.mean.size();
    throw TypeException(msg.str());
  }
  copy(m.begin(), m.end(), myval.mean.begin());

  if (v.size() != myval.var.size()) {
    ostringstream msg;
    msg << "GaussianV::MyClamp: wrong vector size for var " << v.size()
        << " != " << myval.var.size();
    throw TypeException(msg.str());
  }
  copy(v.begin(), v.end(), myval.var.begin());

  exuptodate = false;
  return true;
}

// Stores a snapshot of the current value in sstate, creating it on
// first use.  The mean is stored as is, the variance as its logarithm.
// Always returns true.
bool GaussianV::MySaveState()
{
  if (sstate == 0)
    sstate = new DVSet;

  const size_t len = net->Time();
  sstate->mean.resize(len);
  sstate->var.resize(len);

  for (size_t t = 0; t < len; ++t) {
    sstate->mean[t] = myval.mean[t];
    sstate->var[t] = log(myval.var[t]);
  }

  return true;
}

// Stores in sstep the step taken from the saved state to the current
// value, creating sstep on first use.  Returns false when no state has
// been saved.
//
// The variance step is always the difference of the log-variances.
// The mean step depends on hookeflags:
//   0, 1: difference  myval.mean - sstate->mean
//   2, 3: ratio       myval.mean / sstate->mean
// For the ratio, an element whose saved mean is zero cannot be divided
// and gets -1.0 instead.  The division guard therefore tests the
// divisor sstate->mean[i], not the sstep->mean[i] that is about to be
// overwritten.
bool GaussianV::MySaveStep()
{
  if (!sstate) return false;
  if (!sstep)
    sstep = new DVSet;

  sstep->mean.resize(net->Time()); sstep->var.resize(net->Time());
  switch (hookeflags) {
  case 0:
  case 1:
    for (size_t i=0; i<net->Time(); i++) {
      sstep->mean[i] = myval.mean[i] - sstate->mean[i];
      sstep->var[i]  = log(myval.var[i]) - sstate->var[i];
    }
    break;
  case 2:
  case 3:
    for (size_t i=0; i<net->Time(); i++) {
      if (sstate->mean[i] == 0)
        sstep->mean[i] = -1.0;
      else
        sstep->mean[i] = myval.mean[i] / sstate->mean[i];
      sstep->var[i]  = log(myval.var[i]) - sstate->var[i];
    }
    break;
  }
  return true;
}

// Advances the saved state by alpha times the saved step.  Returns
// false, after printing a message to cerr, when no state or step has
// been saved.
//
// The case layout follows MySaveStep and MyRepeatStep, where hookeflags
// 0 and 1 use a difference step for the mean and 2 and 3 a ratio step,
// and where 1 and 3 leave the variance alone:
//   0: additive mean, variance advanced
//   1: additive mean only
//   2: multiplicative mean (positive ratios only), variance advanced
//   3: multiplicative mean (positive ratios only)
bool GaussianV::MySaveRepeatedState(double alpha)
{
  if (!sstate || !sstep) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return false;
  }
  switch (hookeflags) {
  case 0:
    for (size_t i=0; i<net->Time(); i++) {
      sstate->mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
      sstate->var[i] = sstate->var[i] + alpha * sstep->var[i];
    }
    break;
  case 1:
    for (size_t i=0; i<net->Time(); i++) {
      sstate->mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
    }
    break;
  case 2:
    for (size_t i=0; i<net->Time(); i++) {
      // ratio^alpha; non-positive ratios leave the mean as it is
      if (sstep->mean[i] > 0)
        sstate->mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
      sstate->var[i] = sstate->var[i] + alpha * sstep->var[i];
    }
    break;
  case 3:
    for (size_t i=0; i<net->Time(); i++) {
      if (sstep->mean[i] > 0)
        sstate->mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
    }
    break;
  }
  return true;
}

// Set myval to "saved state + alpha * saved step", using the step
// representation chosen by hookeflags (see MySaveStep()):
//   0: additive mean step, log-variance step applied
//   1: additive mean step, variance untouched
//   2: ratio mean step,    log-variance step applied
//   3: ratio mean step,    variance untouched
// In ratio mode the mean becomes sstate->mean * ratio^alpha, so alpha == 1
// reproduces the mean that MySaveStep() observed.  Entries whose stored ratio
// is not positive (including the -1.0 sentinel) keep their current mean.
// Logs to cerr and returns without changes if no state/step has been saved.
// Invalidates the cached expectation and cost.
void GaussianV::MyRepeatStep(double alpha)
{
  if (!sstate || !sstep) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return;
  }
  switch (hookeflags) {
  case 0:
    for (size_t i=0; i<net->Time(); i++) {
      myval.mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
      myval.var[i] = exp(sstate->var[i] + alpha * sstep->var[i]);
    }
    break;
  case 1:
    for (size_t i=0; i<net->Time(); i++) {
      myval.mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
    }
    break;
  case 2:
    for (size_t i=0; i<net->Time(); i++) {
      // The step is a ratio, so it must be applied multiplicatively (was
      // "+", which disagrees with MySaveRepeatedState() and with the scalar
      // GaussNonlin::MyRepeatStep()).
      if (sstep->mean[i] > 0)
        myval.mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
      myval.var[i] = exp(sstate->var[i] + alpha * sstep->var[i]);
    }
    break;
  case 3:
    for (size_t i=0; i<net->Time(); i++) {
      // Multiplicative for the same reason as in case 2.
      if (sstep->mean[i] > 0)
        myval.mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
    }
    break;
  }
  exuptodate = false; costuptodate = false;
}

// Release the saved state and saved step, leaving both pointers null.
// Always returns true.
bool GaussianV::MyClearStateAndStep()
{
  // delete on a null pointer is a no-op, so no guards are needed.
  delete sstate;
  sstate = 0;

  delete sstep;
  sstep = 0;

  return true;
}

// class SparseGaussV : public GaussianV : public Variable : public Node

double SparseGaussV::Cost()
{
  if (!clamped) {
    return GaussianV::Cost();
  }
  if (!costuptodate) {
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true, true));
    ParRealV(1, p1, DFlags(true, false, true));

    double c = 0;
    size_t j = 0;
    for (size_t i = 0; i < net->Time(); i++) {
      if (j < missing.size() && missing[j] == (int)i) {
	j++;
	if (!children.empty()) {
	  c += (Sqr(myval.mean[i] - p0.Mean(i)) + p0.Var(i) + myval.var[i]) *
	    p1.Exp(i) - p1.Mean(i) - log(myval.var[i]);
	}
      } else {
	c += (Sqr(myval.mean[i] - p0.Mean(i)) + p0.Var(i) + myval.var[i]) *
	  p1.Exp(i) - p1.Mean(i);
      }
    }
    cost = (c-(children.empty()?0:missing.size()))/2 + \
      (net->Time() - missing.size()) * _5LOG2PI;
    costuptodate = true;
  }
  return cost;
}

// Update step.  An unclamped node uses the generic Variable::Update().
// A clamped node only re-estimates its missing elements, each with one
// VarNewton() call fed by the parent terms and the children's gradients;
// with no missing elements nothing happens.
void SparseGaussV::Update()
{
  if (!clamped) {
    Variable::Update();
    return;
  }
  if (!missing.size())
    return;

  DVSet grad;
  ChildGradRealV(grad);

  DVH p0, p1;
  ParRealV(0, p0, DFlags(true));
  ParRealV(1, p1, DFlags(false, false, true));

  // A gradient component that the children did not fill in stays at zero.
  const bool hasmean = grad.mean.size() != 0;
  const bool hasvar = grad.var.size() != 0;
  const bool hasex = grad.ex.size() != 0;

  double gm = 0, gv = 0, ge = 0;
  for (size_t k = 0; k < missing.size(); k++) {
    const size_t t = missing[k];
    if (hasmean) gm = grad.mean[t];
    if (hasvar) gv = grad.var[t];
    if (hasex) ge = grad.ex[t];
    VarNewton(myval.mean[t], myval.var[t], (myval.mean[t] - p0.Mean(t)) *
              p1.Exp(t) + gm, p1.Exp(t)/2 + gv, ge, label);
  }
  exuptodate = false;
  costuptodate = false;
}

// Per-element gradient of this node's cost with respect to parent `ptr`.
//
// GaussianV::GradRealV() handles the cases that need no knowledge of the
// missing values: an unclamped node, and the mean parent (index 0) asking
// while the node has children.  Otherwise:
//   - parent 1: every element contributes, except missing elements when the
//     node has no children;
//   - parent 0 (only reached without children): missing elements are skipped.
void SparseGaussV::GradRealV(DVSet &val, const Node *ptr)
{
  if (!clamped) {
    GaussianV::GradRealV(val, ptr);
    return;
  }
  const int which = ParIdentity(ptr);
  if (which == 0 && !children.empty()) {
    GaussianV::GradRealV(val, ptr);
    return;
  }

  size_t next = 0;   // position in `missing` of the next missing index
  if (which) { // ParMean(1)
    DVH p0;
    ParRealV(0, p0, DFlags(true, true));
    val.mean.resize(net->Time());
    val.ex.resize(net->Time());
    for (size_t t = 0; t < net->Time(); t++) {
      const bool ismissing = next < missing.size() && missing[next] == (int)t;
      if (ismissing) {
        next++;
        if (children.empty())
          continue;
      }
      val.mean[t] -= 0.5;
      val.ex[t] += (Sqr(myval.mean[t] - p0.Mean(t)) + p0.Var(t) +
                    myval.var[t]) / 2;
    }
  }
  else {       // ParMean(0)
    BBASSERT2(children.empty()); // guaranteed by the delegation above
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    val.mean.resize(net->Time());
    val.var.resize(net->Time());
    for (size_t t = 0; t < net->Time(); t++) {
      if (next < missing.size() && missing[next] == (int)t) {
        next++;
        continue;
      }
      val.mean[t] += (p0.Mean(t) - myval.mean[t]) * p1.Exp(t);
      val.var[t] += p1.Exp(t) / 2;
    }
  }
}

// Scalar gradient of this node's cost with respect to parent `ptr`, summed
// over time.  Delegation to GaussianV::GradReal() follows the same rule as
// in SparseGaussV::GradRealV(): unclamped node, or mean parent (index 0)
// asking while the node has children.
void SparseGaussV::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped) {
    GaussianV::GradReal(val, ptr);
    return;
  }
  const int which = ParIdentity(ptr);
  if (which == 0 && !children.empty()) {
    GaussianV::GradReal(val, ptr);
    return;
  }

  size_t next = 0;   // position in `missing` of the next missing index
  if (which) { // ParMean(1)
    DVH p0;
    ParRealV(0, p0, DFlags(true, true));
    for (size_t t = 0; t < net->Time(); t++) {
      const bool ismissing = next < missing.size() && missing[next] == (int)t;
      if (ismissing) {
        next++;
        // Without children a missing element carries no information.
        if (children.empty())
          continue;
      }
      val.ex += (Sqr(myval.mean[t] - p0.Mean(t)) + p0.Var(t) +
                 myval.var[t]) / 2;
    }
    val.mean -= 0.5 * (net->Time() - (children.empty() ? missing.size() : 0));
  }
  else {       // ParMean(0)
    BBASSERT2(children.empty()); // guaranteed by the delegation above
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    for (size_t t = 0; t < net->Time(); t++) {
      if (next < missing.size() && missing[next] == (int)t) {
        next++;
        continue;
      }
      val.mean += (p0.Mean(t) - myval.mean[t]) * p1.Exp(t);
      val.var += p1.Exp(t) / 2;
    }
  }
}

// Clamp the node to the values in `m` and declare the indices in `mis` as
// missing.  `mis` must be strictly increasing with every entry in
// [0, net->Time()).  The first violated rule is reported through a
// TypeException, after the stored missing list has been emptied; note that
// Clamp(m) has already been applied by then.  On success the missing list is
// replaced and Update() is run.
void SparseGaussV::SparseClampDV(const DV &m, const IntV &mis) {
  ostringstream msg;
  Clamp(m);

  const bool nonempty = mis.size() != 0;
  bool bad = false;

  if (nonempty && mis.front() < 0) {
    msg << "SparseGaussV::SparseClamp: first item in vector too small "
        << mis.front() << " < 0 " << endl;
    bad = true;
  }
  if (!bad) {
    for (size_t k = 1; k < mis.size(); k++) {
      if (mis[k-1] >= mis[k]) {
        msg << "SparseGaussV::SparseClamp: vector not sorted " << mis[k-1]
            << " >= " << mis[k] << endl;
        bad = true;
        break;
      }
    }
  }
  if (!bad && nonempty && mis.back() >= (int)net->Time()) {
    msg << "SparseGaussV::SparseClamp: last item in vector too big "
        << mis.back() << " >= " << net->Time() << endl;
    bad = true;
  }

  if (bad) {
    missing.resize(0);
    throw TypeException(msg.str());
  }

  missing.resize(mis.size());
  copy(mis.begin(), mis.end(), missing.begin());
  Update();
}

// Replace the list of missing indices with a copy of `mis`.
// Throws StructureException if `mis` is not strictly increasing; no range
// check against net->Time() is performed here.
void SparseGaussV::SetMissing(IntV &mis)
{
  // Each entry must be strictly greater than its predecessor.
  for (size_t k = 1; k < mis.size(); k++) {
    if (mis[k-1] >= mis[k])
      throw StructureException("missing vector not sorted");
  }

  missing.resize(mis.size());
  copy(mis.begin(), mis.end(), missing.begin());
}

// Serialize this node: writes the missing-index list under the name
// "missing", then lets GaussianV::Save() write the remaining state.
void SparseGaussV::Save(NetSaver *saver)
{
  saver->SetNamedIntV("missing", missing);
  GaussianV::Save(saver);
}


// class DelayGaussV : public Variable : public Node

// Construct a delayed Gaussian vector node with five parents:
//   0: m   (vector, mean/var)   1: v   (vector, mean/exp)
//   2: a   (vector, mean/var)   3: m0  (scalar, mean/var)
//   4: v0  (scalar, mean/exp)
// m and v are handed to the Variable base; a, m0 and v0 are registered here.
// In Cost() parent 2 multiplies the previous element's mean, and parents 3
// and 4 are used for element 0 only.
DelayGaussV::DelayGaussV(Net *_net, Label label, Node *m,
			 Node *v, Node *a,
			 Node *m0, Node *v0) :
  Variable(_net, label, m, v), NParNode(m, v, a, m0, v0)
{
  AddParent(a, false); AddParent(m0, false); AddParent(v0, false);
  sstate = 0; sstep = 0;
  cost = 0;
  // One entry per time index for each stored statistic.
  myval.mean.resize(net->Time()); 
  myval.var.resize(net->Time());
  myval.ex.resize(net->Time());

  // Verify that each parent provides the quantities this node reads.
  CheckParent(0, REALV_MV);
  CheckParent(1, REALV_ME);
  CheckParent(2, REALV_MV);
  CheckParent(3, REAL_MV);
  CheckParent(4, REAL_ME);
  // MyUpdate() returns immediately while the node has no children.
  MyUpdate();
}

// Cost of the delayed Gaussian chain, cached in `cost` until costuptodate
// is reset.  An unclamped node without children costs nothing.
//
// Element 0 is measured against the scalar parents 3 and 4.  Element t > 0
// is measured against mean[t-1] * p2 + p0 with precision parent 1.
// NOTE(review): element 0 is read unconditionally, so net->Time() is
// presumably expected to be at least 1 -- confirm.
double DelayGaussV::Cost()
{
  if (!clamped && children.empty())
    return 0;
  if (costuptodate)
    return cost;

  DVH p0, p1, p2;
  ParRealV(0, p0, DFlags(true, true));
  ParRealV(1, p1, DFlags(true, false, true));
  ParRealV(2, p2, DFlags(true, true));
  DSSet p3, p4;
  ParReal(3, p3, DFlags(true, true));
  ParReal(4, p4, DFlags(true, false, true));

  // First element: prior comes from the scalar parents.
  double c;
  if (clamped)
    c = (Sqr(myval.mean[0] - p3.mean) + p3.var) * p4.ex - p4.mean;
  else
    c = (Sqr(myval.mean[0] - p3.mean) + p3.var + myval.var[0]) * p4.ex -
      p4.mean - log(myval.var[0]);

  // Remaining elements: prior mean depends on the previous element.
  for (size_t t = 1; t < net->Time(); t++) {
    const double m = myval.mean[t-1] * p2.Mean(t) + p0.Mean(t);
    double v;
    if (clamped) {
      v = Sqr(myval.mean[t-1]) * p2.Var(t) + p0.Var(t) + myval.var[t];
      c += (Sqr(myval.mean[t] - m) + v) * p1.Exp(t) - p1.Mean(t);
    }
    else {
      v = (Sqr(myval.mean[t-1]) + myval.var[t-1]) * p2.Var(t) +
        myval.var[t-1] * Sqr(p2.Mean(t)) + p0.Var(t) + myval.var[t];
      c += (Sqr(myval.mean[t] - m) + v) * p1.Exp(t) - p1.Mean(t) - log(myval.var[t]);
    }
  }

  if (clamped)
    cost = c/2 + net->Time() * _5LOG2PI;
  else
    cost = (c - net->Time()) / 2;
  costuptodate = true;
  return cost;
}

// Hand out a view of this node's value vectors through val.vec.
// If the exponential expectation is requested and stale, it is recomputed
// first as exp(mean + var/2) for every element.  The mean, var and ex
// requests are marked as served; the return value tells whether anything
// else was left in `req`.
bool DelayGaussV::GetRealV(DVH &val, DFlags req)
{
  const bool refreshex = req.ex && !exuptodate;
  if (refreshex) {
    const size_t len = net->Time();
    myval.ex.resize(len);
    for (size_t t = 0; t < len; ++t)
      myval.ex[t] = exp(myval.mean[t] + myval.var[t] / 2);
    exuptodate = true;
  }

  val.vec = &myval;

  req.ex = false;
  req.var = false;
  req.mean = false;
  return req.AllFalse();
}

// Per-element gradient of this node's cost with respect to one of the
// vector parents (index 0, 1 or 2).  Contributions start at element 1;
// element 0 depends on the scalar parents, which are served by GradReal().
// Nothing is added for an unclamped node without children or for any other
// parent index.
void DelayGaussV::GradRealV(DVSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  const int which = ParIdentity(ptr);

  if (which == 0 || which == 2) {
    DVH p0, p1, p2;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    ParRealV(2, p2, DFlags(true));
    val.mean.resize(net->Time());
    val.var.resize(net->Time());
    for (size_t t = 1; t < net->Time(); t++) {
      // Prior mean of element t minus its posterior mean.
      const double resid = p0.Mean(t) + myval.mean[t-1] * p2.Mean(t) -
        myval.mean[t];
      if (which == 0) {
        val.mean[t] += resid * p1.Exp(t);
        val.var[t] += p1.Exp(t) / 2;
      }
      else {
        val.mean[t] += p1.Exp(t) * (myval.var[t-1] * p2.Mean(t) +
                                    myval.mean[t-1] * resid);
        val.var[t] += (Sqr(myval.mean[t-1]) + myval.var[t-1]) * p1.Exp(t) / 2;
      }
    }
  }
  else if (which == 1) {
    DVH p0, p2;
    ParRealV(0, p0, DFlags(true, true));
    ParRealV(2, p2, DFlags(true, true));
    val.mean.resize(net->Time());
    val.ex.resize(net->Time());
    for (size_t t = 1; t < net->Time(); t++) {
      // Posterior-variance terms are only present when not clamped.
      double extra = 0;
      if (!clamped)
        extra = myval.var[t-1] * (Sqr(p2.Mean(t)) + p2.Var(t)) +
          myval.var[t];
      val.mean[t] -= 0.5;
      val.ex[t] += (Sqr(myval.mean[t] - myval.mean[t-1] * p2.Mean(t) - p0.Mean(t)) +
                    p0.Var(t) + Sqr(myval.mean[t-1]) * p2.Var(t) +
                    extra) / 2;
    }
  }
}

// Scalar gradient of this node's cost with respect to parent `ptr`.
// Parents 0-2 accumulate over elements 1..Time()-1; parents 3 and 4 only
// concern element 0.  Nothing is added for an unclamped node without
// children or for an unknown parent index.
void DelayGaussV::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  const int which = ParIdentity(ptr);

  if (which == 0 || which == 2) {
    DVH p0, p1, p2;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    ParRealV(2, p2, DFlags(true));
    for (size_t t = 1; t < net->Time(); t++) {
      // Prior mean of element t minus its posterior mean.
      const double resid = p0.Mean(t) + myval.mean[t-1] * p2.Mean(t) -
        myval.mean[t];
      if (which == 0) {
        val.mean += resid * p1.Exp(t);
        val.var += p1.Exp(t) / 2;
      }
      else {
        val.mean += p1.Exp(t) * (myval.var[t-1] * p2.Mean(t) +
                                 myval.mean[t-1] * resid);
        val.var += (Sqr(myval.mean[t-1]) + myval.var[t-1]) * p1.Exp(t) / 2;
      }
    }
  }
  else if (which == 1) {
    DVH p0, p2;
    ParRealV(0, p0, DFlags(true, true));
    ParRealV(2, p2, DFlags(true, true));
    for (size_t t = 1; t < net->Time(); t++) {
      // Posterior-variance terms are only present when not clamped.
      double extra = 0;
      if (!clamped)
        extra = myval.var[t-1] * (Sqr(p2.Mean(t)) + p2.Var(t)) +
          myval.var[t];
      val.ex += (Sqr(myval.mean[t] - myval.mean[t-1] * p2.Mean(t) - p0.Mean(t)) +
                 p0.Var(t) + Sqr(myval.mean[t-1]) * p2.Var(t) +
                 extra) / 2;
    }
    // One -0.5 per element that took part in the loop above.
    if (net->Time() > 0) val.mean -= 0.5 * (net->Time() - 1);
  }
  else if (which == 3) {
    DSSet p3, p4;
    ParReal(3, p3, DFlags(true));
    ParReal(4, p4, DFlags(false, false, true));
    val.mean += (p3.mean - myval.mean[0]) * p4.ex;
    val.var += p4.ex / 2;
  }
  else if (which == 4) {
    DSSet p3;
    ParReal(3, p3, DFlags(true, true));
    val.mean -= 0.5;
    if (clamped)
      val.ex += (Sqr(myval.mean[0] - p3.mean) + p3.var) / 2;
    else
      val.ex += (Sqr(myval.mean[0] - p3.mean) + p3.var + myval.var[0]) / 2;
  }
}

// Re-estimate every element of the chain with VarNewton(), sweeping
// backwards in time: first the last element, then the interior elements
// (which see both their predecessor and their successor), and finally
// element 0, whose prior comes from the scalar parents 3 and 4.
// Does nothing when the node has no children.  Invalidates the cached
// expectation and cost.
// Known limitation (see the comment on `i` below): element i-1 is read for
// the last element, so the sweep needs net->Time() >= 2.
void DelayGaussV::MyUpdate()
{
  if (NumChildren() == 0) {
    return;
  }

  DVSet grad;
  ChildGradRealV(grad);

  // Diagnostic output before the assertion below fires: list the children
  // that failed to deliver a mean gradient.
  if (grad.mean.size() < 1) {
    cout << GetLabel() << ": grad.mean.size() < 1, children:" << endl;
    for (size_t i = 0; i < NumChildren(); i++) {
      Node *c = GetChild(i);
      cout << c->GetLabel() << endl;
    }
  }

  BBASSERT2(grad.mean.size() > 0);

  DVH p0, p1, p2;
  ParRealV(0, p0, DFlags(true));
  ParRealV(1, p1, DFlags(false, false, true));
  ParRealV(2, p2, DFlags(true, true));
  DSSet p3, p4;
  ParReal(3, p3, DFlags(true));
  ParReal(4, p4, DFlags(false, false, true));

  // Gradient components that the children did not fill in stay at zero.
  double gm = 0, gv = 0, ge = 0;
  int i = net->Time() - 1;         // *** Does not work if net->Time() < 2
  if (grad.mean.size()) gm = grad.mean[i];
  if (grad.var.size()) gv = grad.var[i];
  if (grad.ex.size()) ge = grad.ex[i];


  // Last element: only its own prior term plus the children's gradient.
  VarNewton(myval.mean[i], myval.var[i], (myval.mean[i] - myval.mean[i-1] * p2.Mean(i) - p0.Mean(i)) *
            p1.Exp(i) + gm, p1.Exp(i)/2 + gv, ge, label);

  // Interior elements, from Time()-2 down to 1.  Each one appears both in
  // its own prior (index i) and in the prior of its successor (index i+1).
  for (i--; i > 0; i--) {

    gm = (myval.mean[i] - myval.mean[i-1] * p2.Mean(i) - p0.Mean(i)) * p1.Exp(i) +
      (p2.Var(i+1) * myval.mean[i] + p2.Mean(i+1) * (p0.Mean(i+1) + myval.mean[i] * p2.Mean(i+1) -
						     myval.mean[i+1])) * p1.Exp(i+1);
    if (grad.mean.size()) gm += grad.mean[i];
    if (grad.var.size()) gv = grad.var[i];
    if (grad.ex.size()) ge = grad.ex[i];


    VarNewton(myval.mean[i], myval.var[i], gm,
	      p1.Exp(i)/2 + (Sqr(p2.Mean(i+1)) + p2.Var(i+1)) * p1.Exp(i+1) / 2 + gv, ge, label);


  }
  // Here i == 0 (given Time() >= 2): the first element, whose own prior is
  // given by the scalar parents p3/p4.
  // NOTE(review): the mean gradient below has no p2.Var(i+1) * myval.mean[i]
  // term, unlike the interior elements above -- confirm this is intended.
  if (grad.mean.size()) gm = grad.mean[i];
  if (grad.var.size()) gv = grad.var[i];
  if (grad.ex.size()) ge = grad.ex[i];
  VarNewton(myval.mean[i], myval.var[i], (myval.mean[i] - p3.mean) * p4.ex +
            p2.Mean(i+1) * (p0.Mean(i+1) + myval.mean[i] * p2.Mean(i+1) - myval.mean[i+1]) * p1.Exp(i+1) + gm,
            p4.ex/2 + (Sqr(p2.Mean(i+1)) + p2.Var(i+1)) * p1.Exp(i+1) / 2 + gv, ge, label);
  exuptodate = false; costuptodate = false;
}

// Serialize this node: value vectors, cached cost and the exuptodate flag,
// followed by the saved state/step when they exist, then the Variable base.
// costuptodate is not written here.
void DelayGaussV::Save(NetSaver *saver)
{
  saver->SetNamedDVSet("myval", myval);
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedBool("exuptodate", exuptodate);
  // sstate/sstep are optional; they are only written when allocated.
  if (sstate)
    saver->SetNamedDVSet("sstate", *sstate);
  if (sstep)
    saver->SetNamedDVSet("sstep", *sstep);
  Variable::Save(saver);
}

// Snapshot the current posterior into sstate (allocated on first use).
// Means are copied as-is; variances are stored as log(var).
// Always returns true.
bool DelayGaussV::MySaveState()
{
  if (sstate == 0)
    sstate = new DVSet;

  const size_t len = net->Time();
  sstate->mean.resize(len);
  sstate->var.resize(len);

  for (size_t t = 0; t < len; ++t) {
    sstate->var[t]  = log(myval.var[t]);
    sstate->mean[t] = myval.mean[t];
  }
  return true;
}

// Record the step taken since MySaveState() into sstep (allocated on first
// use).  Returns false if no state has been saved, true otherwise.
//
// The representation of the mean step depends on hookeflags:
//   0, 1: difference  myval.mean - sstate->mean
//   2, 3: ratio       myval.mean / sstate->mean, or the sentinel -1.0 when
//         the saved mean is exactly zero (MyRepeatStep() and
//         MySaveRepeatedState() skip non-positive ratios).
// The variance step is always the difference of log-variances.
// For any other hookeflags value the step vectors are only resized.
bool DelayGaussV::MySaveStep()
{
  if (!sstate) return false;
  if (!sstep)
    sstep = new DVSet;

  sstep->mean.resize(net->Time()); sstep->var.resize(net->Time());
  switch (hookeflags) {
  case 0:
  case 1:
    for (size_t i=0; i<net->Time(); i++) {
      sstep->mean[i] = myval.mean[i] - sstate->mean[i];
      sstep->var[i]  = log(myval.var[i]) - sstate->var[i];
    }
    break;
  case 2:
  case 3:
    for (size_t i=0; i<net->Time(); i++) {
      // Guard the division on the *saved* mean (the divisor).  The previous
      // code tested sstep->mean[i], i.e. the slot being written, which is
      // zero right after a fresh resize and otherwise holds a stale step.
      if (sstate->mean[i] == 0)
        sstep->mean[i] = -1.0;
      else
        sstep->mean[i] = myval.mean[i] / sstate->mean[i];
      sstep->var[i]  = log(myval.var[i]) - sstate->var[i];
    }
    break;
  }
  return true;
}

// Move the saved state forward by alpha times the saved step, in place.
// Returns false (after logging to cerr) when either the state or the step
// has not been saved; true otherwise.
//
//   hookeflags 0, 2: mean and log-variance are advanced additively.
//   hookeflags 1, 3: mean is scaled by sstep->mean^alpha where that value is
//                    positive; the log-variance is left alone.
// NOTE(review): MySaveStep()/MyRepeatStep() group the flags as {0,1} vs
// {2,3}; the {0,2} vs {1,3} grouping here looks different on purpose or by
// accident -- confirm before relying on it.
bool DelayGaussV::MySaveRepeatedState(double alpha)
{
  const bool ready = sstate && sstep;
  if (!ready) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return false;
  }

  const size_t len = net->Time();
  if (hookeflags == 0 || hookeflags == 2) {
    for (size_t t = 0; t < len; ++t) {
      sstate->mean[t] = sstate->mean[t] + alpha * sstep->mean[t];
      sstate->var[t] = sstate->var[t] + alpha * sstep->var[t];
    }
  }
  else if (hookeflags == 1 || hookeflags == 3) {
    for (size_t t = 0; t < len; ++t) {
      if (sstep->mean[t] > 0)
        sstate->mean[t] = sstate->mean[t] * exp(alpha * log(sstep->mean[t]));
    }
  }
  return true;
}

// Set myval to "saved state + alpha * saved step", using the step
// representation chosen by hookeflags (see MySaveStep()):
//   0: additive mean step, log-variance step applied
//   1: additive mean step, variance untouched
//   2: ratio mean step,    log-variance step applied
//   3: ratio mean step,    variance untouched
// In ratio mode the mean becomes sstate->mean * ratio^alpha, so alpha == 1
// reproduces the mean that MySaveStep() observed.  Entries whose stored ratio
// is not positive (including the -1.0 sentinel) keep their current mean.
// Logs to cerr and returns without changes if no state/step has been saved.
// Invalidates the cached expectation and cost.
void DelayGaussV::MyRepeatStep(double alpha)
{
  if (!sstate || !sstep) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return;
  }
  switch (hookeflags) {
  case 0:
    for (size_t i=0; i<net->Time(); i++) {
      myval.mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
      myval.var[i] = exp(sstate->var[i] + alpha * sstep->var[i]);
    }
    break;
  case 1:
    for (size_t i=0; i<net->Time(); i++) {
      myval.mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
    }
    break;
  case 2:
    for (size_t i=0; i<net->Time(); i++) {
      // The step is a ratio, so it must be applied multiplicatively (was
      // "+", which disagrees with MySaveRepeatedState() and with the scalar
      // GaussNonlin::MyRepeatStep()).
      if (sstep->mean[i] > 0)
        myval.mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
      myval.var[i] = exp(sstate->var[i] + alpha * sstep->var[i]);
    }
    break;
  case 3:
    for (size_t i=0; i<net->Time(); i++) {
      // Multiplicative for the same reason as in case 2.
      if (sstep->mean[i] > 0)
        myval.mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
    }
    break;
  }
  exuptodate = false; costuptodate = false;
}

// Release the saved state and saved step, leaving both pointers null.
// Always returns true.
bool DelayGaussV::MyClearStateAndStep()
{
  // delete on a null pointer is a no-op, so no guards are needed.
  delete sstate;
  sstate = 0;

  delete sstep;
  sstep = 0;

  return true;
}


// class GaussNonlin : public Variable : public Node

// Fit the Gaussian posterior of s (mean *mean1, variance *var1) and return
// the implied mean/variance of exp(-s^2) in *mean2 / *var2.
//
//   mean1, var1 : in/out, posterior mean and variance of s
//   mean2, var2 : in/out, mean and variance of exp(-s^2); must be current on
//                 entry (they are read to build the target and first cost)
//   gme, gva    : gradient terms with respect to the output mean and
//                 variance; they define the target via
//                 foptvar = 0.5/gva and fopt = *mean2 - foptvar*gme
//   sprmean     : prior mean of s
//   sprvar_inv  : prior inverse variance of s
//
// The search keeps s on the side of zero selected by the sign of sprmean.
void finds_gauss(double *mean1, double *var1, double *mean2, double *var2, double gme, double gva, double sprmean, double sprvar_inv)
  // Find an s with a prior N(sprmean,1/sprvar_inv)
  //    such that exp(-s^2) is close to N(fopt,foptvar)
  // mean2 and var2 have to be up-to-date and they are updated.
{
  // +1 when sprmean > 0, otherwise -1.  This local hides the file-level
  // sign(double) helper inside this function.
  int sign = (sprmean > 0) * 2 - 1;
  int iter1 = 0, iter2;
  double sE = fabs(*mean1) * sign;
  double svar = *var1;
  double sE_n, svar_n, d1, d2, C0, C1, fe, f2e, svar12, svar14;
  double foptvar, fopt;

  svar12 = 1 + 2 * svar;
  svar14 = 1 + 4 * svar;

  // No variance gradient: fall back to the prior for s.
  // NOTE(review): fe/f2e below are evaluated at the mean/variance captured
  // before *mean1/*var1 were overwritten with the prior -- confirm intended.
  if (gva == 0) {
    *mean1 = sprmean;
    *var1 = 1 / sprvar_inv;
    fe = exp(-Sqr(sE)/(svar12))/sqrt(svar12);
    f2e = exp(-2*Sqr(sE)/(svar14))/sqrt(svar14);
    *mean2 = fe;
    *var2 = f2e - Sqr(fe);
    return;
  }

  foptvar = 0.5 / gva;
  fopt = *mean2 - foptvar * gme;

  // expected myval2 and expected squared myval2:
  fe = *mean2;
  f2e = *var2 + Sqr(fe);
  //  fe = exp(-Sqr(sE)/(svar12))/sqrt(svar12);
  //  f2e = exp(-2*Sqr(sE)/(svar14))/sqrt(svar14);
  // cost at the beginning
  C0 = (( sprvar_inv*(Sqr(sE - sprmean) + svar) 
          + (Sqr(fopt) - 2*fopt * fe + f2e) / foptvar) - log(svar)) / 2;

  // Outer iteration, capped at 100 rounds.
  while (iter1 < 100)  // an infinite loop should do!
  {
    // part of partial derivative of cost wrt. var:
    d2 = ( fopt*(-2*Sqr(sE)+svar12)*fe / Sqr(svar12)
	   + (4*Sqr(sE)-svar14)*f2e / Sqr(svar14)
	   ) / foptvar;
    // A negative contribution is dropped, so d2 >= 0.5*sprvar_inv below.
    if (d2 < 0)
      d2 = 0;
    d2 += 0.5*sprvar_inv;
    // partial derivative of cost wrt. mean:
    d1 = sprvar_inv*(sE-sprmean)         
         +sE*(sprvar_inv
	      +((2*fopt*fe/(svar12) - 2*f2e/(svar14))
		/foptvar));
    // fixed point iteration:
    svar_n = 1 / (2*d2);
    // approximate newton iteration:
    sE_n = sE - svar_n * d1;
    // Do not let the mean cross to the other side of zero.
    if (sE_n*sign < 0) 
      sE_n = 0;
    iter2 = 0;
    // Backtracking, capped at 100 rounds.  If the cap is reached the last
    // candidate is used even though it did not lower the cost.
    while (iter2 < 100)
    {
      fe = exp(-Sqr(sE_n)/(1+2*svar_n))/sqrt(1+2*svar_n);
      f2e = exp(-2*Sqr(sE_n)/(1+4*svar_n))/sqrt(1+4*svar_n);
      C1 = (( sprvar_inv*(Sqr(sE_n - sprmean) + svar_n) 
	      + (Sqr(fopt) - 2*fopt * fe + f2e) / foptvar) - log(svar_n)) / 2;
      // if the new cost is better than the old take the step.
      if (C1<=C0)
        break;
      iter2++;
      // too long step, take half back.
      // (arithmetic midpoint for the mean, geometric mean for the variance)
      sE_n = 0.5 * (sE + sE_n);
      svar_n = sqrt(svar * svar_n);
    }
    // if changed less than epsilon, exit loop
    if (fabs(sE_n - sE) + fabs(svar_n - svar) < NL_EPSILON)
      break;
    iter1++;
    sE = sE_n;
    svar = svar_n;
    svar12 = 1 + 2 * svar;
    svar14 = 1 + 4 * svar;
    C0 = C1;
  }
  // Report the last candidate and the fe/f2e evaluated for it.
  *mean1 = sE_n;
  *var1 = svar_n;
  *mean2 = fe;
  *var2 = f2e - Sqr(fe);
}

// Cost of the node, computed from myval1 (the value before the
// nonlinearity) and the two parents, and cached in `cost` until
// costuptodate is reset.  An unclamped node without children costs nothing.
double GaussNonlin::Cost()
{
  if (!clamped && children.empty())
    return 0;
  if (costuptodate)
    return cost;

  DSSet p0, p1;
  ParReal(0, p0, DFlags(true, true));
  ParReal(1, p1, DFlags(true, false, true));

  if (clamped) {
    cost = ((Sqr(myval1.mean - p0.mean) + p0.var) * p1.ex - p1.mean) / 2
      + _5LOG2PI;
  }
  else {
    cost = ((Sqr(myval1.mean - p0.mean) + p0.var + myval1.var) *
            p1.ex - p1.mean - log(myval1.var) - 1) / 2;
  }
  costuptodate = true;
  return cost;
}

// Deliver the mean and/or variance of the node's output (myval2), bringing
// the cached values up to date first.  Returns whether every request in
// `req` could be served.
bool GaussNonlin::GetReal(DSSet &val, DFlags req)
{
  const bool wantmean = req.mean;
  const bool wantvar = req.var;

  // The variance is derived from the mean, so the mean is refreshed first
  // in either case.
  if (wantmean || wantvar)
    UpdateMean();

  if (wantmean) {
    req.mean = false;
    val.mean = myval2.mean;
  }
  if (wantvar) {
    UpdateVar();
    req.var = false;
    val.var = myval2.var;
  }

  return req.AllFalse();
}

// Gradient of this node's cost with respect to parent `ptr`, computed from
// myval1.  Nothing is added for an unclamped node without children.
void GaussNonlin::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  if (!ParIdentity(ptr)) { // ParMean(0)
    DSSet p0, p1;
    ParReal(0, p0, DFlags(true));
    ParReal(1, p1, DFlags(false, false, true));
    val.mean += (p0.mean - myval1.mean) * p1.ex;
    val.var += p1.ex / 2;
    return;
  }

  // ParMean(1)
  DSSet p0;
  ParReal(0, p0, DFlags(true, true));
  val.mean -= 0.5;
  if (!clamped)
    val.ex += (Sqr(myval1.mean - p0.mean) + p0.var + myval1.var) / 2;
  else
    val.ex += (Sqr(myval1.mean - p0.mean) + p0.var) / 2;
}

// Serialize this node: both value sets (myval1, myval2), the cached cost and
// the mean/var freshness flags, followed by the saved state/step when they
// exist, then the Variable base.
void GaussNonlin::Save(NetSaver *saver)
{
  saver->SetNamedDSSet("myval1", myval1);
  saver->SetNamedDSSet("myval2", myval2);
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedBool("meanuptodate", meanuptodate);
  saver->SetNamedBool("varuptodate", varuptodate);
  // sstate/sstep are optional; they are only written when allocated.
  if (sstate)
    saver->SetNamedDSSet("sstate", *sstate);
  if (sstep)
    saver->SetNamedDSSet("sstep", *sstep);
  Variable::Save(saver);
}


// Re-estimate the node from the children's gradient and the parents.
// finds_gauss() requires myval2 to be current, so both cached output
// statistics are refreshed before it runs; it then rewrites myval1 and
// myval2 together, which leaves mean and variance fresh and the cost stale.
void GaussNonlin::MyUpdate()
{
  DSSet grad, p0, p1;

  ChildGradReal(grad);
  ParReal(0, p0, DFlags(true));
  ParReal(1, p1, DFlags(false, false, true));

  UpdateMean();
  UpdateVar();

  finds_gauss(&myval1.mean, &myval1.var,
              &myval2.mean, &myval2.var,
              grad.mean, grad.var,
              p0.mean, p1.ex);

  costuptodate = false;
  meanuptodate = true;
  varuptodate = true;
}

// Refresh the cached output mean myval2.mean from myval1 if it is stale:
//   exp(-mean^2 / (1 + 2 var)) / sqrt(1 + 2 var)
void GaussNonlin::UpdateMean()
{
  if (meanuptodate)
    return;

  const double spread = 1+2*myval1.var;
  myval2.mean = exp(-Sqr(myval1.mean)/spread)/sqrt(spread);
  meanuptodate = true;
}

// Refresh the cached output variance myval2.var from myval1 if it is stale:
//   exp(-2 mean^2 / (1 + 4 var)) / sqrt(1 + 4 var) - myval2.mean^2
// Reads myval2.mean, so UpdateMean() has to run first.
void GaussNonlin::UpdateVar()
{
  if (varuptodate)
    return;

  const double spread = 1+4*myval1.var;
  myval2.var = exp(-2*Sqr(myval1.mean)/spread)/sqrt(spread)
               - Sqr(myval2.mean);
  varuptodate = true;
}

// Snapshot myval1 into sstate (allocated on first use): the mean as-is, the
// variance as log(var).  Always returns true.
bool GaussNonlin::MySaveState()
{
  if (sstate == 0)
    sstate = new DSSet;

  sstate->var  = log(myval1.var);
  sstate->mean = myval1.mean;
  return true;
}

// Record the step taken since MySaveState() into sstep (allocated on first
// use).  Returns false if no state has been saved, true otherwise.
//   hookeflags 0, 1: mean step is the difference to the saved mean
//   hookeflags 2, 3: mean step is the ratio to the saved mean, or -1.0 when
//                    the saved mean is exactly zero
// The variance step is the difference of log-variances in both modes; other
// flag values record nothing.
bool GaussNonlin::MySaveStep()
{
  if (sstate == 0)
    return false;
  if (sstep == 0)
    sstep = new DSSet;

  if (hookeflags == 0 || hookeflags == 1) {
    sstep->mean = myval1.mean - sstate->mean;
    sstep->var  = log(myval1.var) - sstate->var;
  }
  else if (hookeflags == 2 || hookeflags == 3) {
    if (sstate->mean != 0)
      sstep->mean = myval1.mean / sstate->mean;
    else
      sstep->mean = -1.0;
    sstep->var  = log(myval1.var) - sstate->var;
  }

  return true;
}

// Move the saved state forward by alpha times the saved step, in place.
// Returns false (after logging to cerr) when either the state or the step
// has not been saved; true otherwise.
//   hookeflags 0, 2: mean and log-variance are advanced additively.
//   hookeflags 1, 3: mean is scaled by sstep->mean^alpha when that value is
//                    positive; the log-variance is left alone.
// NOTE(review): MySaveStep()/MyRepeatStep() group the flags as {0,1} vs
// {2,3}; confirm that the {0,2} vs {1,3} grouping here is intended.
bool GaussNonlin::MySaveRepeatedState(double alpha)
{
  const bool ready = sstate && sstep;
  if (!ready) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return false;
  }

  if (hookeflags == 0 || hookeflags == 2) {
    sstate->mean = sstate->mean + alpha * sstep->mean;
    sstate->var = sstate->var + alpha * sstep->var;
  }
  else if (hookeflags == 1 || hookeflags == 3) {
    if (sstep->mean > 0)
      sstate->mean = sstate->mean * exp(alpha * log(sstep->mean));
  }
  return true;
}

// Set myval1 to the saved state advanced by alpha times the saved step.
//   hookeflags 0, 1: mean = saved mean + alpha * step
//   hookeflags 2, 3: mean = saved mean * step^alpha, only for a positive step
//   hookeflags 0, 2: variance = exp(saved log-var + alpha * log-var step)
// Logs to cerr and returns without changes if no state/step has been saved.
// Marks the cached output mean, output variance and cost as stale.
void GaussNonlin::MyRepeatStep(double alpha)
{
  if (!sstate || !sstep) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return;
  }

  const bool additive = (hookeflags == 0 || hookeflags == 1);
  const bool ratio = (hookeflags == 2 || hookeflags == 3);
  const bool withvar = (hookeflags == 0 || hookeflags == 2);

  if (additive) {
    myval1.mean = sstate->mean + alpha * sstep->mean;
  }
  else if (ratio) {
    if (sstep->mean > 0)
      myval1.mean = sstate->mean * exp(alpha * log(sstep->mean));
  }

  if (withvar)
    myval1.var = exp(sstate->var + alpha * sstep->var);

  meanuptodate = false;
  varuptodate = false;
  costuptodate = false;
}

// Release the saved state and saved step, leaving both pointers null.
// Always returns true.
bool GaussNonlin::MyClearStateAndStep()
{
  // delete on a null pointer is a no-op, so no guards are needed.
  delete sstate;
  sstate = 0;

  delete sstep;
  sstep = 0;

  return true;
}


// class GaussNonlinV : public Variable : public Node

// Construct the vector version of the Gaussian-with-nonlinearity node with
// parents m (index 0, vector mean/var) and v (index 1, vector mean/exp).
// myval1 holds the value before the nonlinearity, myval2 the value after it;
// both get one entry per time index.
GaussNonlinV::GaussNonlinV(Net *_net, Label label, Node *m, Node *v) : 
  Variable(_net, label, m, v), BiParNode(m, v)
{
  sstate = 0; sstep = 0;
  cost = 0;
  myval1.mean.resize(net->Time()); myval1.var.resize(net->Time());
  myval2.mean.resize(net->Time()); myval2.var.resize(net->Time());

  // Verify that each parent provides the quantities this node reads.
  CheckParent(0, REALV_MV);
  CheckParent(1, REALV_ME);
  // NOTE(review): MyUpdate() reads meanuptodate/varuptodate via
  // UpdateMean()/UpdateVar(); they are presumably initialized by a base
  // class or in-class -- confirm.
  MyUpdate();
}

// Cost of the node summed over time, computed from myval1 and the two
// parents, and cached in `cost` until costuptodate is reset.  An unclamped
// node without children costs nothing.
double GaussNonlinV::Cost()
{
  if (!clamped && children.empty())
    return 0;
  if (costuptodate)
    return cost;

  DVH p0, p1;
  ParRealV(0, p0, DFlags(true, true));
  ParRealV(1, p1, DFlags(true, false, true));

  double c = 0;
  if (clamped) {
    for (size_t t = 0; t < net->Time(); t++)
      c += (Sqr(myval1.mean[t] - p0.Mean(t)) + p0.Var(t)) *
        p1.Exp(t) - p1.Mean(t);
    cost = c/2 + net->Time() * _5LOG2PI;
  }
  else {
    for (size_t t = 0; t < net->Time(); t++)
      c += (Sqr(myval1.mean[t] - p0.Mean(t)) + p0.Var(t) + myval1.var[t]) *
        p1.Exp(t) - p1.Mean(t) - log(myval1.var[t]);
    cost = (c - net->Time()) / 2;
  }
  costuptodate = true;
  return cost;
}

// Scalar gradient of this node's cost with respect to parent `ptr`, summed
// over time and computed from myval1.  Nothing is added for an unclamped
// node without children.
void GaussNonlinV::GradReal(DSSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  if (!ParIdentity(ptr)) { // ParMean(0)
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    for (size_t t = 0; t < net->Time(); t++) {
      val.mean += (p0.Mean(t) - myval1.mean[t]) * p1.Exp(t);
      val.var += p1.Exp(t) / 2;
    }
    return;
  }

  // ParMean(1)
  DVH p0;
  ParRealV(0, p0, DFlags(true, true));
  for (size_t t = 0; t < net->Time(); t++) {
    // The posterior variance only enters when the node is not clamped.
    val.ex += (Sqr(myval1.mean[t] - p0.Mean(t)) + p0.Var(t) +
               (clamped ? 0 : myval1.var[t])) / 2;
  }
  val.mean -= 0.5 * net->Time();
}

// Hand out a view of the node's output vectors (myval2) through val.vec,
// refreshing whichever cached statistics were requested.  The mean and var
// requests are marked as served; the return value tells whether anything
// else was left in `req`.
bool GaussNonlinV::GetRealV(DVH &val, DFlags req)
{
  if (req.var) {
    // The variance is derived from the mean, so the mean goes first.
    UpdateMean();
    UpdateVar();
  }
  else if (req.mean) {
    UpdateMean();
  }

  val.vec = &myval2;

  req.var = false;
  req.mean = false;
  return req.AllFalse();
}

// Per-element gradient of this node's cost with respect to parent `ptr`,
// computed from myval1.  Nothing is added for an unclamped node without
// children.
void GaussNonlinV::GradRealV(DVSet &val, const Node *ptr)
{
  if (!clamped && children.empty())
    return;

  if (!ParIdentity(ptr)) { // ParMean(0)
    DVH p0, p1;
    ParRealV(0, p0, DFlags(true));
    ParRealV(1, p1, DFlags(false, false, true));
    val.mean.resize(net->Time());
    val.var.resize(net->Time());
    for (size_t t = 0; t < net->Time(); t++) {
      val.mean[t] += (p0.Mean(t) - myval1.mean[t]) * p1.Exp(t);
      val.var[t] += p1.Exp(t) / 2;
    }
    return;
  }

  // ParMean(1)
  DVH p0;
  ParRealV(0, p0, DFlags(true, true));
  val.mean.resize(net->Time());
  val.ex.resize(net->Time());
  for (size_t t = 0; t < net->Time(); t++) {
    val.mean[t] -= 0.5;
    // The posterior variance only enters when the node is not clamped.
    val.ex[t] += (Sqr(myval1.mean[t] - p0.Mean(t)) + p0.Var(t) +
                  (clamped ? 0 : myval1.var[t])) / 2;
  }
}

// Serialize this node: both value sets (myval1, myval2), the cached cost and
// the mean/var freshness flags, followed by the saved state/step when they
// exist, then the Variable base.
void GaussNonlinV::Save(NetSaver *saver)
{
  saver->SetNamedDVSet("myval1", myval1);
  saver->SetNamedDVSet("myval2", myval2);
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedBool("meanuptodate", meanuptodate);
  saver->SetNamedBool("varuptodate", varuptodate);
  // sstate/sstep are optional; they are only written when allocated.
  if (sstate)
    saver->SetNamedDVSet("sstate", *sstate);
  if (sstep)
    saver->SetNamedDVSet("sstep", *sstep);
  Variable::Save(saver);
}


// Re-estimate every element independently with finds_gauss(), using the
// children's gradient and the parents.  The cached output statistics are
// refreshed first because finds_gauss() needs them current; afterwards mean
// and variance are fresh and the cost is stale.
void GaussNonlinV::MyUpdate()
{
  DVSet grad;
  ChildGradRealV(grad);

  DVH p0, p1;
  ParRealV(0, p0, DFlags(true));
  ParRealV(1, p1, DFlags(false, false, true));

  UpdateMean();
  UpdateVar();

  // A gradient component that the children did not fill in stays at zero.
  const bool hasmean = grad.mean.size() != 0;
  const bool hasvar = grad.var.size() != 0;

  double gm = 0, gv = 0;
  for (size_t t = 0; t < net->Time(); t++) {
    if (hasmean) gm = grad.mean[t];
    if (hasvar) gv = grad.var[t];
    finds_gauss(&myval1.mean[t], &myval1.var[t],
                &myval2.mean[t], &myval2.var[t],
                gm, gv, p0.Mean(t), p1.Exp(t));
  }

  costuptodate = false;
  meanuptodate = true;
  varuptodate = true;
}

// Refresh the cached output means myval2.mean from myval1 if they are stale:
//   exp(-mean^2 / (1 + 2 var)) / sqrt(1 + 2 var)   for every element
void GaussNonlinV::UpdateMean()
{
  if (meanuptodate)
    return;

  for (size_t t = 0; t < myval1.mean.size(); t++) {
    const double spread = 1+2*myval1.var[t];
    myval2.mean[t] = exp(-Sqr(myval1.mean[t])/spread) / sqrt(spread);
  }
  meanuptodate = true;
}

// Refresh the cached output variances myval2.var from myval1 if they are
// stale:
//   exp(-2 mean^2 / (1 + 4 var)) / sqrt(1 + 4 var) - myval2.mean^2
// Reads myval2.mean, so UpdateMean() has to run first.
void GaussNonlinV::UpdateVar()
{
  if (varuptodate)
    return;

  for (size_t t = 0; t < myval1.mean.size(); t++) {
    const double spread = 1+4*myval1.var[t];
    myval2.var[t] = exp(-2*Sqr(myval1.mean[t])/spread) /
      sqrt(spread) - Sqr(myval2.mean[t]);
  }
  varuptodate = true;
}

// Snapshot myval1 into sstate (allocated on first use).
// Means are copied as-is; variances are stored as log(var).
// Always returns true.
bool GaussNonlinV::MySaveState()
{
  if (sstate == 0)
    sstate = new DVSet;

  const size_t len = net->Time();
  sstate->mean.resize(len);
  sstate->var.resize(len);

  for (size_t t = 0; t < len; ++t) {
    sstate->var[t]  = log(myval1.var[t]);
    sstate->mean[t] = myval1.mean[t];
  }
  return true;
}

// Record the step taken since MySaveState() into sstep (allocated on first
// use).  Returns false if no state has been saved, true otherwise.
//
// The representation of the mean step depends on hookeflags:
//   0, 1: difference  myval1.mean - sstate->mean
//   2, 3: ratio       myval1.mean / sstate->mean, or the sentinel -1.0 when
//         the saved mean is exactly zero (MyRepeatStep() and
//         MySaveRepeatedState() skip non-positive ratios).
// The variance step is always the difference of log-variances.
// For any other hookeflags value the step vectors are only resized.
bool GaussNonlinV::MySaveStep()
{
  if (!sstate) return false;
  if (!sstep)
    sstep = new DVSet;

  sstep->mean.resize(net->Time()); sstep->var.resize(net->Time());
  switch (hookeflags) {
  case 0:
  case 1:
    for (size_t i=0; i<net->Time(); i++) {
      sstep->mean[i] = myval1.mean[i] - sstate->mean[i];
      sstep->var[i]  = log(myval1.var[i]) - sstate->var[i];
    }
    break;
  case 2:
  case 3:
    for (size_t i=0; i<net->Time(); i++) {
      // Guard the division on the *saved* mean (the divisor), as the scalar
      // GaussNonlin::MySaveStep() does.  The previous code tested
      // sstep->mean[i], i.e. the slot being written, which is zero right
      // after a fresh resize and otherwise holds a stale step.
      if (sstate->mean[i] == 0)
        sstep->mean[i] = -1.0;
      else
        sstep->mean[i] = myval1.mean[i] / sstate->mean[i];
      sstep->var[i]  = log(myval1.var[i]) - sstate->var[i];
    }
    break;
  }
  return true;
}

// Move the saved state forward by alpha times the saved step, in place.
// Returns false (after logging to cerr) when either the state or the step
// has not been saved; true otherwise.
//
//   hookeflags 0, 2: mean and log-variance are advanced additively.
//   hookeflags 1, 3: mean is scaled by sstep->mean^alpha where that value is
//                    positive; the log-variance is left alone.
// NOTE(review): MySaveStep()/MyRepeatStep() group the flags as {0,1} vs
// {2,3}; the {0,2} vs {1,3} grouping here looks different on purpose or by
// accident -- confirm before relying on it.
bool GaussNonlinV::MySaveRepeatedState(double alpha)
{
  const bool ready = sstate && sstep;
  if (!ready) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return false;
  }

  const size_t len = net->Time();
  if (hookeflags == 0 || hookeflags == 2) {
    for (size_t t = 0; t < len; ++t) {
      sstate->mean[t] = sstate->mean[t] + alpha * sstep->mean[t];
      sstate->var[t] = sstate->var[t] + alpha * sstep->var[t];
    }
  }
  else if (hookeflags == 1 || hookeflags == 3) {
    for (size_t t = 0; t < len; ++t) {
      if (sstep->mean[t] > 0)
        sstate->mean[t] = sstate->mean[t] * exp(alpha * log(sstep->mean[t]));
    }
  }
  return true;
}

// Set myval1 to "saved state + alpha * saved step", using the step
// representation chosen by hookeflags (see MySaveStep()):
//   0: additive mean step, log-variance step applied
//   1: additive mean step, variance untouched
//   2: ratio mean step,    log-variance step applied
//   3: ratio mean step,    variance untouched
// In ratio mode the mean becomes sstate->mean * ratio^alpha, so alpha == 1
// reproduces the mean that MySaveStep() observed.  Entries whose stored ratio
// is not positive (including the -1.0 sentinel) keep their current mean.
// Logs to cerr and returns without changes if no state/step has been saved.
// Marks the cached output mean, output variance and cost as stale.
void GaussNonlinV::MyRepeatStep(double alpha)
{
  if (!sstate || !sstep) {
    cerr << label << ": No saved state when trying to repeat step!" << endl;
    return;
  }
  switch (hookeflags) {
  case 0:
    for (size_t i=0; i<net->Time(); i++) {
      myval1.mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
      myval1.var[i] = exp(sstate->var[i] + alpha * sstep->var[i]);
    }
    break;
  case 1:
    for (size_t i=0; i<net->Time(); i++) {
      myval1.mean[i] = sstate->mean[i] + alpha * sstep->mean[i];
    }
    break;
  case 2:
    for (size_t i=0; i<net->Time(); i++) {
      // The step is a ratio, so it must be applied multiplicatively (was
      // "+", which disagrees with MySaveRepeatedState() and with the scalar
      // GaussNonlin::MyRepeatStep()).
      if (sstep->mean[i] > 0)
        myval1.mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
      myval1.var[i] = exp(sstate->var[i] + alpha * sstep->var[i]);
    }
    break;
  case 3:
    for (size_t i=0; i<net->Time(); i++) {
      // Multiplicative for the same reason as in case 2.
      if (sstep->mean[i] > 0)
        myval1.mean[i] = sstate->mean[i] * exp(alpha * log(sstep->mean[i]));
    }
    break;
  }
  meanuptodate = false; varuptodate = false; costuptodate = false;
}

// Release the saved state and the saved step and reset both pointers to
// null.  Always returns true.
bool GaussNonlinV::MyClearStateAndStep()
{
  // deleting a null pointer is a no-op, so no guards are needed
  delete sstate;
  sstate = 0;

  delete sstep;
  sstep = 0;

  return true;
}




// class Discrete : public Variable : public Node

double Discrete::Cost()
{
  if (!costuptodate) {
    cost = 0;
    exsum = 0;
    DSSet p; DFlags req(true, false, true);
    if (clamped)
      for (size_t i=0; i<NumParents(); i++) {
	ParReal(i, p, req);
	cost -= myval[i] * p.mean;
	exsum += p.ex;
      }
    else
      for (size_t i=0; i<NumParents(); i++) {
	ParReal(i, p, req);
	// 0 * log(0) = 0
	if (myval[i])
	  cost += myval[i] * ( log(myval[i]) - p.mean );
	exsum += p.ex;
      }
    cost += log(exsum);
    costuptodate = true;
    exuptodate = true;
  }
  return cost;
}


// Hand out a pointer to this node's own distribution (myval).  The pointer
// refers to the member itself, not to a copy.  Always returns true.
bool Discrete::GetDiscrete(DD *&val)
{
  val = &myval;
  return true;
}

// Add this node's gradient contribution for the parent `ptr` to `val`.
// Nothing is added when the node is neither clamped nor has any children.
// Otherwise exsum is refreshed, myval[i] is subtracted from val.mean and
// 1 / exsum is added to val.ex, where i is the index of `ptr` among the
// parents as reported by ParIdentity().
void Discrete::GradReal(DSSet &val, const Node *ptr)
{
  const bool contributes = clamped || !children.empty();
  if (contributes) {
    UpdateExpSum();
    const int idx = ParIdentity(ptr);
    val.mean -= myval[idx];
    val.ex += 1 / exsum;
  }
}

// Write this node's state to the saver: the distribution myval, the cached
// cost and exsum values and the exuptodate flag, followed by whatever
// Variable::Save() stores.
void Discrete::Save(NetSaver *saver)
{
  saver->SetNamedDD("myval", myval);
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedDouble("exsum", exsum);
  saver->SetNamedBool("exuptodate", exuptodate);
  Variable::Save(saver);
}


void Discrete::MyUpdate()
{
  DD grad(NumParents());
  double maxval = 0;
  double sum = 0;
  myval.Resize(NumParents());

  ChildGradDiscrete(grad);
  DSSet p; DFlags req(true, false, false);
  // myval[i] first contains the value in log scale
  // maximum value is subtracted before exponentation to prevent overflows
  for (size_t i=0; i<NumParents(); i++) {
    ParReal(i, p, req);
    myval[i] = p.mean - grad[i];
    if (!i || maxval < myval[i])
      maxval = myval[i];
  }
  for (size_t i=0; i<NumParents(); i++) {
    myval[i] = exp(myval[i] - maxval);
    sum += myval[i];
  }
  for (size_t i=0; i<NumParents(); i++)
    myval[i] /= sum;
  costuptodate = false;
}

void Discrete::UpdateExpSum()
{
  DFlags req(false, false, true);
  exsum = 0;
  DSSet p;
  for (size_t i = 0; i < NumParents(); i++) {
    ParReal(i, p, req);
    exsum += p.ex;
  }
}



// class DiscreteV : public Variable : public Node

// Construct a vector-valued discrete node with an optional first parent n.
// myval gets one slot per time index of the net.  When a parent is given,
// it is checked with CheckParent(0, REALV_ME) and the distribution is
// initialised through MyUpdate().
DiscreteV::DiscreteV(Net *_net, Label label, Node *n) :
  Variable(_net, label, n), NParNode(n)
{
  cost = 0;
  exuptodate = false;
  myval.Resize(net->Time());

  if (!n)
    return;

  CheckParent(0, REALV_ME);
  MyUpdate();
}

double DiscreteV::Cost()
{
  if (!costuptodate) {
    cost = 0;
    UpdateExpSum();
    DVH p; DFlags req(true, false, false);
    for (size_t i=0; i<NumParents(); i++) {
      ParRealV(i, p, req);
      if (clamped)
	for (size_t j=0; j<net->Time(); j++)
	  cost -= myval[j][i] * p.Mean(j);
      else
	for (size_t j=0; j<net->Time(); j++)
	  // 0 * log(0) = 0
	  if (myval[j][i])
	    cost += myval[j][i] * ( log(myval[j][i]) - p.Mean(j) );
    }
    for (size_t j=0; j<net->Time(); j++)
      cost += log(exsum[j]);
    costuptodate = true;
  }
  return cost;
}

// Point val.vec at this node's own vector of distributions (myval); no copy
// is made.  Always returns true.
bool DiscreteV::GetDiscreteV(VDDH &val)
{
  val.vec = &myval;
  return true;
}

// Add this node's gradient for the scalar parent `ptr` to `val`, summed
// over all time indices.  Nothing is added when the node is neither clamped
// nor has any children.  For each time index j, myval[j][i] is subtracted
// from val.mean and 1 / exsum[j] is added to val.ex, where i is the index
// of `ptr` among the parents.
void DiscreteV::GradReal(DSSet &val, const Node *ptr)
{
  if (clamped || !children.empty()) {
    const int idx = ParIdentity(ptr);
    UpdateExpSum();
    for (size_t t = 0; t < net->Time(); t++) {
      val.mean -= myval[t][idx];
      val.ex += 1 / exsum[t];
    }
  }
}

// Add this node's per-time-index gradient for the vector parent `ptr` to
// `val`.  Nothing is done when the node is neither clamped nor has any
// children.  val.mean and val.ex are resized to the net's time length
// before accumulation; for each index j, myval[j][i] is subtracted from
// val.mean[j] and 1 / exsum[j] is added to val.ex[j].
void DiscreteV::GradRealV(DVSet &val, const Node *ptr)
{
  if (clamped || !children.empty()) {
    const int idx = ParIdentity(ptr);
    UpdateExpSum();

    val.mean.resize(net->Time());
    val.ex.resize(net->Time());
    for (size_t t = 0; t < net->Time(); t++) {
      val.mean[t] -= myval[t][idx];
      val.ex[t] += 1 / exsum[t];
    }
  }
}

// Write this node's state to the saver: the vector of distributions myval,
// the cached cost, the per-time-index exsum vector and the exuptodate flag,
// followed by whatever Variable::Save() stores.
void DiscreteV::Save(NetSaver *saver)
{
  saver->SetNamedVDD("myval", myval);
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedDV("exsum", exsum);
  saver->SetNamedBool("exuptodate", exuptodate);
  Variable::Save(saver);
}


void DiscreteV::MyUpdate()
{
  VDD grad(net->Time(), NumParents());
  DV maxval(net->Time());
  double sum;
  
  for (size_t i=0; i<net->Time(); i++) {
    if (myval.GetDDp(i))
      myval[i].Resize(NumParents());
    else
      myval.GetDDp(i) = new DD(NumParents());
  }

  ChildGradDiscreteV(grad);
  DVH p; DFlags req(true, false, false);
  // myval[i] first contains the value in log scale
  // maximum value is subtracted before exponentation to prevent overflows
  for (size_t i=0; i<NumParents(); i++) {
    ParRealV(i, p, req);
    for (size_t j=0; j<net->Time(); j++) {
      myval[j][i] = p.Mean(j) - grad[j][i];
      if (i==0 || maxval[j] < myval[j][i])
	maxval[j] = myval[j][i];
    }
  }
  for (size_t j=0; j<net->Time(); j++) {
    sum = 0;
    for (size_t i=0; i<NumParents(); i++) {
      myval[j][i] = exp(myval[j][i] - maxval[j]);
      sum += myval[j][i];
    }
    for (size_t i=0; i<NumParents(); i++)
      myval[j][i] /= sum;
  }
  costuptodate = false;
}

void DiscreteV::UpdateExpSum()
{
  exsum = DV(net->Time());
  DVH p; DFlags req(false, false, true);
  for (size_t i = 0; i < NumParents(); i++) {
    ParRealV(i, p, req);
    for (size_t j = 0; j < net->Time(); j++)
      exsum[j] += p.Exp(j);
  }
  exuptodate = true;
}


// class Memory : public Variable : public Node

// The potential is stored as oldval.mean x + oldval.var x^2 + oldval.ex exp(x)

void Memory::MyUpdate()
{
  DSSet grad;
  ChildGradReal(grad);

  bool needmean = ((grad.mean != 0) || (grad.var != 0));
  bool needvar = (grad.var != 0);
  bool needexp = (grad.ex != 0);

  DSSet val;
  ParReal(0, val, DFlags(needmean, needvar, needexp));
  if (grad.var) {
    grad.mean -= 2 * val.mean * grad.var;
  }
  oldval.mean += grad.mean;
  oldval.var += grad.var;
  oldval.ex += grad.ex;

  oldcost += grad.mean * val.mean + grad.var * val.mean*val.mean +
    grad.var * val.var + grad.ex * val.ex;

  double d = net->Decay();
  oldval.mean *= d;
  oldval.var *= d;
  oldval.ex *= d;
  oldcost *= d;
  costuptodate = false;
}

// Return the difference gained by the "old children"

double Memory::Cost()
{
  if (!costuptodate) {
    bool needmean = ((oldval.mean != 0) || (oldval.var != 0));
    bool needvar = (oldval.var != 0);
    bool needexp = (oldval.ex != 0);

    DSSet val;
    ParReal(0, val, DFlags(needmean, needvar, needexp));

    cost = oldval.mean * val.mean + oldval.var * val.mean*val.mean +
      oldval.var * val.var + oldval.ex * val.ex - oldcost;
    costuptodate = true;
  }
  return cost;
}

// Gradient passed on to the parent: the gradient collected from the
// current children plus the stored potential oldval.  When the stored var
// term is non-zero, 2 * (parent mean) * oldval.var is added to the mean
// gradient as well.  The `ptr` argument is not used.
void Memory::GradReal(DSSet &grad, const Node *ptr)
{
  ChildGradReal(grad);

  grad.ex += oldval.ex;
  grad.var += oldval.var;
  grad.mean += oldval.mean;

  if (oldval.var != 0) {
    DSSet par;
    ParReal(0, par, DFlags(true));
    grad.mean += 2 * par.mean * oldval.var;
  }
}

// Write this node's state to the saver: the stored potential oldval, the
// stored reference cost oldcost and the cached cost, followed by whatever
// Variable::Save() stores.
void Memory::Save(NetSaver *saver)
{
  saver->SetNamedDSSet("oldval", oldval);
  saver->SetNamedDouble("oldcost", oldcost);
  saver->SetNamedDouble("cost", cost);

  Variable::Save(saver);
}


// abstract class OnLineDelay : public Node

// Construct an on-line delay node with up to two parents.  Each non-null
// parent is attached with AddParent(parent, false); the node then sets its
// persist and timetype fields and registers itself with the net through
// AddOnLineDelay().
OnLineDelay::OnLineDelay(Net *ptr, Label label, Node *n1, Node *n2) :
  Node(ptr, label)
{
  if (n1)
    AddParent(n1, false);
  if (n2)
    AddParent(n2, false);
  persist = 1 | 4; // OnLineDelays usually need all parents and at least one child
  timetype = 3;
  net->AddOnLineDelay(this, label);
}

// OnLineDelay adds no state of its own here; saving is delegated to
// Node::Save().
void OnLineDelay::Save(NetSaver *saver)
{
  Node::Save(saver);
}


// class OLDelayS : public OnLineDelay

// Take over the current mean of parent 1 as the delayed value, mark the
// cached exponential as stale and notify the child via OutdateChild().
// (ResetTime() does the same with parent 0.)
void OLDelayS::StepTime()
{
  DSSet tmp;
  ParReal(1, tmp, DFlags(true));
  oldmean = tmp.mean;
  exuptodate = false;
  OutdateChild();
}

// Take over the current mean of parent 0 as the delayed value, mark the
// cached exponential as stale and notify the child via OutdateChild().
// NOTE(review): presumably parent 0 supplies the initial value used at the
// start of a sequence -- confirm against the class declaration.
void OLDelayS::ResetTime()
{
  DSSet tmp;
  ParReal(0, tmp, DFlags(true));
  oldmean = tmp.mean;
  exuptodate = false;
  OutdateChild();
}

// Write the delayed mean, its cached exponential and the flag telling
// whether that cache is valid, then the state handled by
// OnLineDelay::Save().
void OLDelayS::Save(NetSaver *saver)
{
  saver->SetNamedDouble("oldmean", oldmean);
  saver->SetNamedDouble("oldexp", oldexp);
  saver->SetNamedBool("exuptodate", exuptodate);
  OnLineDelay::Save(saver);
}

// Provide the delayed value as a point estimate: mean = oldmean, var = 0
// and ex = exp(oldmean).  Only the fields requested in `req` are written.
// The exponential is computed lazily and cached in oldexp.
// Returns whether every requested field could be served.
bool OLDelayS::GetReal(DSSet &val, DFlags req)
{
  if (req.mean) {
    val.mean = oldmean;
    req.mean = false;
  }
  if (req.var) {
    val.var = 0;
    req.var = false;
  }
  if (req.ex) {
    if (!exuptodate) {
      oldexp = exp(oldmean);
      exuptodate = true;
    }
    val.ex = oldexp;
    req.ex = false;
  }
  return req.AllFalse();
}


// class OLDelayD : public OnLineDelay

void OLDelayD::StepTime()
{
  DD *tmp;
  ParDiscrete(1, tmp);
  oldval = *tmp;
  OutdateChild();
}

void OLDelayD::ResetTime()
{
  DD *tmp;
  ParDiscrete(0, tmp);
  oldval = *tmp;
  OutdateChild();
}

// Write the delayed distribution, then the state handled by
// OnLineDelay::Save().
void OLDelayD::Save(NetSaver *saver)
{
  saver->SetNamedDD("oldval", oldval);
  OnLineDelay::Save(saver);
}


// Hand out a pointer to the delayed distribution (oldval); no copy is made.
// Always returns true.
bool OLDelayD::GetDiscrete(DD *&val)
{
  val = &oldval;
  return true;
}


// class Proxy : public Node

// Construct a proxy that stands in for the node labelled `rlabel`.  The
// proxy starts without a parent; CheckRef() attaches the referenced node
// later.  Throws StructureException when `label` and `rlabel` are equal,
// i.e. the proxy would reference itself.  Otherwise the target label is
// remembered, the proxy is registered with the net through AddProxy() and
// the pending discrete-request flags are cleared.
Proxy::Proxy(Net *ptr, Label label, Label rlabel) :
  Node(ptr, label), UniParNode(0)
{
  if (label == rlabel) {
    ostringstream msg;
    msg << "Proxy node " << label << " tries to reference itself.";
    throw StructureException(msg.str());
  }
  reflabel = rlabel;
  net->AddProxy(this, label);
  req_discrete = false; req_discretev = false;
}

// Write the proxy's state to the saver: the label of the referenced node,
// the pending request flags (discrete, discrete vector, real, real vector)
// and finally the state handled by Node::Save().
void Proxy::Save(NetSaver *saver)
{
  // Refresh the stored label from the actual parent when there is one.  A
  // proxy that CheckRef() has not resolved yet has no parent, so the label
  // given at construction is kept rather than dereferencing a missing
  // parent.
  if (NumParents() > 0)
    reflabel = GetParent(0)->GetLabel();
  saver->SetNamedLabel("reflabel", reflabel);
  saver->SetNamedBool("req_discrete", req_discrete);
  saver->SetNamedBool("req_discretev", req_discretev);
  saver->SetNamedDFlags("real_flags", real_flags);
  saver->SetNamedDFlags("realv_flags", realv_flags);
  Node::Save(saver);
}


// Forward a scalar real-value request to the referenced node.  While the
// proxy has no parent yet, the requested flags are only recorded in
// real_flags (CheckRef() replays them later), `val` is left untouched and
// true is returned.
bool Proxy::GetReal(DSSet &val, DFlags req)
{
  if (NumParents() != 0)
    return ParReal(0, val, req);

  real_flags.Add(req);
  return true;
}

// Forward a vector real-value request to the referenced node.  While the
// proxy has no parent yet, the requested flags are only recorded in
// realv_flags (CheckRef() replays them later), `val` is left untouched and
// true is returned.
bool Proxy::GetRealV(DVH &val, DFlags req)
{
  if (NumParents() != 0)
    return ParRealV(0, val, req);

  realv_flags.Add(req);
  return true;
}

// Forward a discrete-value request to the referenced node.  While the proxy
// has no parent yet, the request is only remembered in req_discrete
// (CheckRef() replays it later), `val` is left untouched and true is
// returned.
bool Proxy::GetDiscrete(DD *&val)
{
  if (NumParents() != 0)
    return ParDiscrete(0, val);

  req_discrete = true;
  return true;
}

// Forward a discrete-vector request to the referenced node.  While the
// proxy has no parent yet, the request is only remembered in req_discretev
// (CheckRef() replays it later), `val` is left untouched and true is
// returned.
bool Proxy::GetDiscreteV(VDDH &val)
{
  if (NumParents() != 0)
    return ParDiscreteV(0, val);

  req_discretev = true;
  return true;
}

// Resolve the proxy: look up the node named reflabel in the net, attach it
// as the parent and replay every request that was recorded while the proxy
// was still unresolved, to verify that the real node can serve it.
//
// Returns true once the proxy is resolved (immediately if it already has a
// parent).  Throws StructureException when the referenced node does not
// exist, is itself a Proxy, or cannot serve one of the recorded requests.
bool Proxy::CheckRef()
{
  Node *realref;

  if (NumParents() > 0)
    return true;

  realref = net->GetNode(reflabel);
  if (!realref) {
    ostringstream msg;
    msg << "Proxy: real node " << reflabel << " does not exist";
    throw StructureException(msg.str());
  }
  if (realref->GetType() == "Proxy") {
    ostringstream msg;
    // the space before the label was missing, gluing it to "proxy"
    msg << "Proxy " << label << ": tries to reference another proxy "
        << reflabel;
    throw StructureException(msg.str());
  }

  AddParent(realref, true);

  // Replay the recorded requests against the real node and clear them.
  if (!real_flags.AllFalse()) {
    DSSet p;
    if (! realref->GetReal(p, real_flags))
      throw StructureException("Proxy: wrong type of parent");
    real_flags.mean = false; real_flags.var = false; real_flags.ex = false;
  }
  if (!realv_flags.AllFalse()) {
    DVH p;
    if (! realref->GetRealV(p, realv_flags))
      throw StructureException("Proxy: wrong type of parent");
    realv_flags.mean = false; realv_flags.var = false; realv_flags.ex = false;
  }
  if (req_discrete) {
    DD *p;
    if (! realref->GetDiscrete(p))
      throw StructureException("Proxy: wrong type of parent");
    req_discrete = false;
  }
  if (req_discretev) {
    VDDH p;
    if (! realref->GetDiscreteV(p))
      throw StructureException("Proxy: wrong type of parent");
    req_discretev = false;
  }
  return true;
}


// class Evidence : public Variable : public Node

// Write this node's state to the saver: the cached cost, the clamped value
// myval, the weight alpha and the per-step decrement decay, followed by the
// state handled by Variable::Save() and Decayer::Save().
void Evidence::Save(NetSaver *saver)
{
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedDouble("myval", myval);
  saver->SetNamedDouble("alpha", alpha);
  saver->SetNamedDouble("decay", decay);
  Variable::Save(saver);
  Decayer::Save(saver);
}

// Add the gradient of this node's cost with respect to its parent to
// `val`.  Only acts when `ptr` is parent 0:
//   val.mean += (parent mean - myval) * alpha
//   val.var  += alpha / 2
void Evidence::GradReal(DSSet &val, const Node *ptr)
{
  if (ParIdentity(ptr) != 0)
    return;

  DSSet par;
  ParReal(0, par, DFlags(true, true));
  val.mean += (par.mean - myval) * alpha;
  val.var += alpha / 2;
}

double Evidence::Cost()
{
  if (!costuptodate) {
    DSSet p;
    ParReal(0, p, DFlags(true, true));
    cost = (Sqr(myval - p.mean) + p.var) * alpha / 2;
    costuptodate = true;
  }
  return cost;
}

// Reduce alpha by decay.  When alpha drops to zero or below, the node kills
// itself (verbosely if the net's debug level is above 10) and false is
// returned; otherwise the cached cost is invalidated and true is returned.
// The `hook` argument is not used.
bool Evidence::DoDecay(string hook)
{
  alpha -= decay;

  if (alpha <= 0) {
    const int verbose = (net->GetDebugLevel() > 10) ? 1 : 0;
    Die(verbose);
    return false;
  }

  costuptodate = false;
  return true;
}

// Set the evidence: myval becomes `mean` and alpha becomes 1 / var, so a
// smaller variance gives a larger weight.  Always returns true.
// NOTE(review): var == 0 is not checked here -- confirm callers never pass it.
bool Evidence::MyClamp(double mean, double var)
{
  myval = mean;
  alpha = 1 / var;

  return true;
}


// class EvidenceV : public Variable : public Node

// Construct a vector evidence node with parent p.  myval, alpha and decay
// get one entry per time index of the net and are initialised to 0, 1e-10
// and 0 respectively; the cached cost starts at 0.
EvidenceV::EvidenceV(Net *ptr, Label label, Node *p) :
  Variable(ptr, label, p), Decayer(ptr), UniParNode(p)
{
  myval.resize(net->Time());
  alpha.resize(net->Time());
  decay.resize(net->Time());
  cost = 0;

  fill(myval.begin(), myval.end(), 0.0);
  fill(alpha.begin(), alpha.end(), 1e-10);
  fill(decay.begin(), decay.end(), 0.0);
}

// Write this node's state to the saver: the cached cost and the myval,
// alpha and decay vectors, followed by the state handled by
// Variable::Save() and Decayer::Save().
void EvidenceV::Save(NetSaver *saver)
{
  saver->SetNamedDouble("cost", cost);
  saver->SetNamedDV("myval", myval);
  saver->SetNamedDV("alpha", alpha);
  saver->SetNamedDV("decay", decay);
  Variable::Save(saver);
  Decayer::Save(saver);
}

// Add the per-time-index gradient of this node's cost with respect to its
// parent to `val`.  Only acts when `ptr` is parent 0.  val.mean and val.var
// are resized to the net's time length, then for every index i:
//   val.mean[i] += (parent mean at i - myval[i]) * alpha[i]
//   val.var[i]  += alpha[i] / 2
void EvidenceV::GradRealV(DVSet &val, const Node *ptr)
{
  if (ParIdentity(ptr) != 0)
    return;

  DVH par;
  ParRealV(0, par, DFlags(true, true));

  val.mean.resize(net->Time());
  val.var.resize(net->Time());
  for (size_t t = 0; t < net->Time(); t++) {
    val.mean[t] += (par.Mean(t) - myval[t]) * alpha[t];
    val.var[t] += alpha[t] / 2;
  }
}

double EvidenceV::Cost()
{
  if (!costuptodate) {
    cost = 0;
    DVH p;
    ParRealV(0, p, DFlags(true, true));
    for (size_t i=0; i<net->Time(); i++)
      cost += (Sqr(myval[i] - p.Mean(i)) + p.Var(i)) * alpha[i] / 2;
    costuptodate = true;
  }
  return cost;
}

// Set the per-index decay so that alpha[i] is used up in iters[i] calls of
// DoDecay(): decay[i] = alpha[i] / iters[i].
// Throws TypeException when `iters` does not have as many entries as decay.
void EvidenceV::SetDecayTime(const DV &iters) {
  if (iters.size() != decay.size()) {
    ostringstream msg;
    msg << "EvidenceV::SetDecayTime: wrong vector size " << iters.size() << " != "
        << decay.size();
    throw TypeException(msg.str());
  }

  for (size_t t = 0; t < net->Time(); t++)
    decay[t] = alpha[t] / iters[t];
}

// Reduce every alpha[i] by decay[i].  The node stays alive as long as at
// least one alpha is still positive: then the cached cost is invalidated
// and true is returned.  Otherwise the node kills itself quietly and false
// is returned.  The `hook` argument is not used.
bool EvidenceV::DoDecay(string hook)
{
  bool anyPositive = false;
  for (size_t t = 0; t < net->Time(); t++) {
    alpha[t] -= decay[t];
    anyPositive = anyPositive || (alpha[t] > 0);
  }

  if (anyPositive) {
    costuptodate = false;
    return true;
  }

  Die(0);
  return false;
}

// Set the same evidence for every time index: each myval entry becomes
// `mean` and each alpha entry becomes 1 / var.  Always returns true.
bool EvidenceV::MyClamp(double mean, double var)
{
  const double weight = 1/var;

  fill(myval.begin(), myval.end(), mean);
  fill(alpha.begin(), alpha.end(), weight);
  return true;
}

// Set the evidence per time index: myval is copied from `mean` and alpha[i]
// becomes 1 / var[i].  Returns true on success.
//
// Throws TypeException when `mean` does not match the size of myval or
// `var` does not match the size of alpha.  Both sizes are validated before
// anything is modified, so a failed call leaves the node exactly as it was
// (previously myval could already be overwritten when `var` was rejected).
bool EvidenceV::MyClamp(const DV &mean, const DV &var)
{
  if (mean.size() != myval.size()) {
    ostringstream msg;
    msg << "EvidenceV::MyClamp: wrong vector size " << mean.size() << " != "
        << myval.size();
    throw TypeException(msg.str());
  }
  if (var.size() != alpha.size()) {
    ostringstream msg;
    msg << "EvidenceV::MyClamp: wrong vector size " << var.size() << " != "
        << alpha.size();
    throw TypeException(msg.str());
  }

  copy(mean.begin(), mean.end(), myval.begin());
  // iterate over the size that was just validated
  for (size_t i = 0; i < alpha.size(); i++)
    alpha[i] = 1 / var[i];

  return true;
}

