//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas stman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: NodeLoader.cc 7 2006-10-26 10:26:41Z ah $

#include "config.h"
#ifdef WITH_PYTHON
#include <Python.h>
#define __PYTHON_H_INCLUDED__
#endif

#include "Templates.h"
#include "Net.h"
#include "Node.h"
#include "Loader.h"


// Restore the common Node state from a saved network.
//
// Reads the node label, the parent list, the (empty) children container
// and the persist/timetype/dying fields, then registers the node with
// `ptr` under its label.
//
// Parents must already exist in the net (the save must be topologically
// sorted); otherwise a StructureException is thrown after closing the
// "parents" container. When `isproxy` is true the parent labels are still
// consumed from the loader but no parent links are created.
Node::Node(Net *ptr, NetLoader *loader, bool isproxy)
{
  persist = 0;
  dying = false;
  timetype = 0;

  net = ptr;
  loader->GetNamedLabel("label",   label);

  Label parlabel;
  loader->StartEnumCont("parents");
  while (loader->GetLabel(parlabel)) {
    // Proxies only drain the label list here; NOTE(review): their real
    // parent is presumably resolved elsewhere (see Proxy "reflabel").
    if (isproxy)
      continue;

    // Single lookup: a missing parent means it was saved after this node.
    Node *par = net->GetNode(parlabel);
    if (!par) {
      loader->FinishEnumCont("parents");
      throw StructureException("The saved network is not topologically sorted, cannot load!");
    }
    AddParent(par, false);
    loader->AddParent(par);
  }
  loader->FinishEnumCont("parents");

  // Children are re-created when the child nodes load and add this node as
  // a parent, so the saved container is only opened and closed.
  loader->StartEnumCont("children");
  loader->FinishEnumCont("children");

  loader->GetNamedInt ("persist",  persist);
  loader->GetNamedInt ("timetype", timetype);
  loader->GetNamedBool("dying",    dying);

  net->AddNode(this, label);
}


// Load a scalar constant node; its value is read from the "cval" field.
Constant::Constant(Net *net, NetLoader *loader)
  : Node(net, loader)
{
  loader->GetNamedDouble("cval", cval);
}


// Load a vector-valued constant node; its values come from "myval".
ConstantV::ConstantV(Net *net, NetLoader *loader)
  : Node(net, loader)
{
  loader->GetNamedDVSet("myval", myval);
}


// Load the common state of a function node. When the save has no
// "uptodate" flags, every cached quantity (mean, var, ex) is marked stale.
Function::Function(Net *ptr, NetLoader *loader)
  : Node(ptr, loader)
{
  const bool flagsloaded = loader->GetNamedDFlags("uptodate", uptodate) ? true : false;
  if (flagsloaded)
    return;

  uptodate.mean = false;
  uptodate.var = false;
  uptodate.ex = false;
}

// Rectification has no state beyond the Function base; nothing extra to load.
Rectification::Rectification(Net *net, NetLoader *loader)
  : Function(net, loader), UniParNode(0)
{
}

// Load a scalar product node: the cached "mean" and "var" of the product.
Prod::Prod(Net *ptr, NetLoader *loader)
  : Function(ptr, loader), BiParNode(0, 0)
{
  loader->GetNamedDouble("mean", mean);
  loader->GetNamedDouble("var", var);
}


// Load a two-input scalar sum node. If the cached "myval" is absent,
// the cached quantities are flagged as needing recomputation.
Sum2::Sum2(Net *ptr, NetLoader *loader)
  : Function(ptr, loader), BiParNode(0, 0)
{
  const bool valloaded = loader->GetNamedDSSet("myval", myval) ? true : false;
  if (valloaded)
    return;

  uptodate.mean = false;
  uptodate.var = false;
  uptodate.ex = false;
}

// Load an N-input scalar sum node. The cached value is read, but all
// cached quantities are then unconditionally marked stale, and
// keepupdated is switched off.
SumN::SumN(Net *ptr, NetLoader *loader)
  : Function(ptr, loader)
{
  loader->GetNamedDSSet("myval", myval);

  keepupdated = false;
  uptodate.ex = false;
  uptodate.var = false;
  uptodate.mean = false;
}

// Relay nodes carry no state of their own beyond the Function base.
Relay::Relay(Net *ptr, NetLoader *loader)
  : Function(ptr, loader), UniParNode(0)
{
}


// Load the common state of a variable node (clamping, cost cache and hook
// flags) and register it with the net's variable list.
//
// "hookeflags" is optional in the saved file. When it is missing, the
// flags default to 0.
Variable::Variable(Net *ptr, NetLoader *loader)
    : Node( ptr, loader )
{
  loader->GetNamedBool("clamped", clamped);
  loader->GetNamedBool("costuptodate", costuptodate);
  // Apply the default only when the field could not be read. The previous
  // condition was inverted: it threw away successfully loaded flags and
  // left the member uninitialised when the field was absent.
  if (!loader->GetNamedInt("hookeflags", hookeflags))
    hookeflags = 0;
  net->AddVariable(this, label);
}


// Load a scalar Gaussian variable: its posterior value, cost and
// expectation flag, plus the optional "sstate"/"sstep" sets. Each optional
// set is heap-allocated only if it was present in the save; otherwise the
// pointer stays null.
Gaussian::Gaussian(Net *_net, NetLoader *loader)
    : Variable(_net, loader), BiParNode(0, 0)
{
  DSSet buffer;
  sstate = 0;
  sstep = 0;

  loader->GetNamedDSSet("myval", myval);
  loader->GetNamedDouble("cost", cost);
  loader->GetNamedBool("exuptodate", exuptodate);

  // Note: the same buffer is reused for both optional sets, as before.
  if (loader->GetNamedDSSet("sstate", buffer)) {
    sstate = new DSSet();
    sstate->mean = buffer.mean;
    sstate->var  = buffer.var;
    sstate->ex   = buffer.ex;
  }

  if (loader->GetNamedDSSet("sstep", buffer)) {
    sstep = new DSSet();
    sstep->mean = buffer.mean;
    sstep->var  = buffer.var;
    sstep->ex   = buffer.ex;
  }
}

// Load a scalar rectified Gaussian variable: value, expectations and cost.
RectifiedGaussian::RectifiedGaussian(Net *_net, NetLoader *loader)
  : Variable(_net, loader), BiParNode(0, 0)
{
  loader->GetNamedDSSet("myval",        myval);
  loader->GetNamedDSSet("expectations", expectations);
  loader->GetNamedDouble("cost",        cost);
}

// Vector rectification keeps no extra state beyond the Function base.
RectificationV::RectificationV(Net *net, NetLoader *loader)
  : Function(net, loader), UniParNode(0)
{
}

// Load a vector product node; only the cached "myval" set is stored.
ProdV::ProdV(Net *ptr, NetLoader *loader)
  : Function(ptr, loader), BiParNode(0, 0)
{
  loader->GetNamedDVSet("myval", myval);
}


// Load a two-input vector sum node; only the cached "myval" set is stored.
Sum2V::Sum2V(Net *ptr, NetLoader *loader)
  : Function(ptr, loader), BiParNode(0, 0)
{
  loader->GetNamedDVSet("myval", myval);
}

// Load an N-input vector sum node. The cached value is read, but all
// cached quantities are then marked stale, and keepupdated is switched off.
SumNV::SumNV(Net *ptr, NetLoader *loader)
  : Function(ptr, loader)
{
  loader->GetNamedDVSet("myval", myval);

  keepupdated = false;
  uptodate.ex = false;
  uptodate.var = false;
  uptodate.mean = false;
}

// Load a vector delay node: cached value plus the delay length. When the
// save has no "lendelay" field, the delay length defaults to 1.
DelayV::DelayV(Net *ptr, NetLoader *loader)
  : Function(ptr, loader), BiParNode(0, 0)
{
  loader->GetNamedDVSet("myval", myval);

  const bool haslen = loader->GetNamedInt("lendelay", lendelay) ? true : false;
  if (!haslen)
    lendelay = 1;
}


// Load a vector Gaussian variable: value, cost and expectation flag, plus
// the optional "sstate"/"sstep" sets. Each optional set is heap-allocated
// only if present; otherwise the pointer stays null.
GaussianV::GaussianV(Net *net, NetLoader *loader)
  : Variable(net, loader), BiParNode(0, 0)
{
  DVSet buffer;
  sstate = 0;
  sstep = 0;

  loader->GetNamedDVSet("myval", myval);
  loader->GetNamedDouble("cost", cost);
  loader->GetNamedBool("exuptodate", exuptodate);

  // The same buffer is reused for both optional sets, as before.
  if (loader->GetNamedDVSet("sstate", buffer)) {
    sstate = new DVSet();
    sstate->mean = buffer.mean;
    sstate->var  = buffer.var;
    sstate->ex   = buffer.ex;
  }

  if (loader->GetNamedDVSet("sstep", buffer)) {
    sstep = new DVSet();
    sstep->mean = buffer.mean;
    sstep->var  = buffer.var;
    sstep->ex   = buffer.ex;
  }
}

// Load a vector rectified Gaussian variable: value, expectations and cost.
RectifiedGaussianV::RectifiedGaussianV(Net *net, NetLoader *loader)
  : Variable(net, loader), BiParNode(0, 0)
{
  loader->GetNamedDVSet("myval",        myval);
  loader->GetNamedDVSet("expectations", expectations);
  loader->GetNamedDouble("cost",        cost);
}

// Load a scalar GaussRect variable. The positive/negative parts and their
// weights are restored from the save. The derived moment and expectation
// caches are then rebuilt rather than loaded.
GaussRect::GaussRect(Net *net, NetLoader *loader)
  : Variable(net, loader), BiParNode(0, 0)
{
  // Saved state.
  loader->GetNamedDSSet("posval",     posval);
  loader->GetNamedDSSet("negval",     negval);
  loader->GetNamedDouble("posweight", posweight);
  loader->GetNamedDouble("negweight", negweight);
  loader->GetNamedDouble("cost",      cost);

  // Derived caches: three moments per side, then recompute.
  posmoments.resize(3);
  negmoments.resize(3);
  UpdateMoments();
  UpdateExpectations();
}

// Load a vector GaussRect variable. The per-part values and weights are
// restored from the save. Every derived buffer is sized by net->Time(),
// and the moment/expectation caches are then recomputed.
GaussRectV::GaussRectV(Net *net, NetLoader *loader)
  : Variable(net, loader), BiParNode(0, 0)
{
  // Saved state.
  loader->GetNamedDVSet("posval",    posval);
  loader->GetNamedDVSet("negval",    negval);
  loader->GetNamedDV("posweights",   posweights);
  loader->GetNamedDV("negweights",   negweights);
  loader->GetNamedDouble("cost",     cost);

  // Expectation buffers, one entry per net->Time().
  expts.mean.resize(net->Time());
  expts.var.resize(net->Time());
  rectexpts.mean.resize(net->Time());
  rectexpts.var.resize(net->Time());

  // Three moment vectors per side, each of length net->Time().
  posmoments.resize(3);
  negmoments.resize(3);
  for (int m = 0; m != 3; ++m) {
    posmoments[m].resize(net->Time());
    negmoments[m].resize(net->Time());
  }

  UpdateMoments();
  UpdateExpectations();
}


// Load a scalar mixture-of-Gaussians variable.
//
// The component mean and variance nodes are stored as label lists
// ("means", "vars"). They are resolved through the net and must already
// be loaded; a missing node throws a StructureException. The per-component
// values ("myval") are copied into newly allocated DSSet objects.
// NOTE(review): the lifetime of these DSSets is owned by the MoG class;
// confirm they are freed in its destructor.
MoG::MoG(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  numComponents = NumComponents();

  loader->GetNamedDouble("cost", cost);
  loader->GetNamedDSSet("expts", expts);

  // Named so it does not shadow the Node::label member.
  Label nodelabel;

  loader->StartEnumCont("means");
  while (loader->GetLabel(nodelabel)) {
    Node *meannode = net->GetNode(nodelabel);
    if (meannode == 0) {
      loader->FinishEnumCont("means");
      throw StructureException("MoG::MoG: The saved network is not topologically sorted, cannot load!");
    }
    means.push_back(meannode);
  }
  loader->FinishEnumCont("means");

  loader->StartEnumCont("vars");
  while (loader->GetLabel(nodelabel)) {
    Node *varnode = net->GetNode(nodelabel);
    if (varnode == 0) {
      loader->FinishEnumCont("vars");
      throw StructureException("MoG::MoG: The saved network is not topologically sorted, cannot load!");
    }
    vars.push_back(varnode);
  }
  loader->FinishEnumCont("vars");

  BBASSERT(means.size() == numComponents && vars.size() == numComponents);

  DSSet compval;
  loader->StartEnumCont("myval");
  while (loader->GetDSSet(compval))
    myval.push_back(new DSSet(compval));
  loader->FinishEnumCont("myval");

  BBASSERT(myval.size() == numComponents);
}

// Load a vector mixture-of-Gaussians variable.
//
// The component mean and variance nodes are stored as label lists
// ("means", "vars"). They are resolved through the net and must already
// be loaded; a missing node throws a StructureException. The per-component
// values ("myval") are copied into newly allocated DVSet objects.
// NOTE(review): the lifetime of these DVSets is owned by the MoGV class;
// confirm they are freed in its destructor.
MoGV::MoGV(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  numComponents = NumComponents();

  loader->GetNamedDouble("cost", cost);
  loader->GetNamedDVSet("expts", expts);

  // Named so it does not shadow the Node::label member.
  Label nodelabel;

  loader->StartEnumCont("means");
  while (loader->GetLabel(nodelabel)) {
    Node *meannode = net->GetNode(nodelabel);
    if (meannode == 0) {
      loader->FinishEnumCont("means");
      throw StructureException("MoGV::MoGV: The saved network is not topologically sorted, cannot load!");
    }
    means.push_back(meannode);
  }
  loader->FinishEnumCont("means");

  loader->StartEnumCont("vars");
  while (loader->GetLabel(nodelabel)) {
    Node *varnode = net->GetNode(nodelabel);
    if (varnode == 0) {
      loader->FinishEnumCont("vars");
      throw StructureException("MoGV::MoGV: The saved network is not topologically sorted, cannot load!");
    }
    vars.push_back(varnode);
  }
  loader->FinishEnumCont("vars");

  BBASSERT(means.size() == numComponents && vars.size() == numComponents);

  DVSet compval;
  loader->StartEnumCont("myval");
  while (loader->GetDVSet(compval))
    myval.push_back(new DVSet(compval));
  loader->FinishEnumCont("myval");

  BBASSERT(myval.size() == numComponents);
}




// Load a Dirichlet variable. The component count is taken from the length
// of the loaded "myval" vector.
Dirichlet::Dirichlet(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  loader->GetNamedDV("myval", myval);
  numComponents = myval.size();

  loader->GetNamedDVSet("expts", expts);
  loader->GetNamedDouble("cost", cost);
}

// Load a scalar discrete Dirichlet variable: value and cost.
DiscreteDirichlet::DiscreteDirichlet(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  loader->GetNamedDD("myval",    myval);
  loader->GetNamedDouble("cost", cost);
}

// Load a vector discrete Dirichlet variable: value and cost.
DiscreteDirichletV::DiscreteDirichletV(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  loader->GetNamedVDD("myval",   myval);
  loader->GetNamedDouble("cost", cost);
}



// Load a sparse vector Gaussian: GaussianV state plus the "missing" index list.
SparseGaussV::SparseGaussV(Net *net, NetLoader *loader)
  : GaussianV(net, loader)
{
  loader->GetNamedIntV("missing", missing);
}

// Load a delayed vector Gaussian variable: value, cost and expectation flag,
// plus the optional "sstate"/"sstep" sets. These are heap-allocated only
// when present; otherwise the pointers stay null.
DelayGaussV::DelayGaussV(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  DVSet buffer;
  sstate = 0;
  sstep = 0;

  loader->GetNamedDVSet("myval", myval);
  loader->GetNamedDouble("cost", cost);
  loader->GetNamedBool("exuptodate", exuptodate);

  // The same buffer is reused for both optional sets, as before.
  if (loader->GetNamedDVSet("sstate", buffer)) {
    sstate = new DVSet();
    sstate->mean = buffer.mean;
    sstate->var  = buffer.var;
    sstate->ex   = buffer.ex;
  }

  if (loader->GetNamedDVSet("sstep", buffer)) {
    sstep = new DVSet();
    sstep->mean = buffer.mean;
    sstep->var  = buffer.var;
    sstep->ex   = buffer.ex;
  }
}


// Load a scalar GaussNonlin variable: its two value sets, cost and the
// mean/var freshness flags, plus the optional "sstate"/"sstep" sets. These
// are heap-allocated only when present; otherwise the pointers stay null.
GaussNonlin::GaussNonlin(Net *net, NetLoader *loader)
  : Variable(net, loader), BiParNode(0, 0)
{
  DSSet buffer;
  sstate = 0;
  sstep = 0;

  loader->GetNamedDSSet("myval1", myval1);
  loader->GetNamedDSSet("myval2", myval2);
  loader->GetNamedDouble("cost", cost);
  loader->GetNamedBool("meanuptodate", meanuptodate);
  loader->GetNamedBool("varuptodate", varuptodate);

  // The same buffer is reused for both optional sets, as before.
  if (loader->GetNamedDSSet("sstate", buffer)) {
    sstate = new DSSet();
    sstate->mean = buffer.mean;
    sstate->var  = buffer.var;
    sstate->ex   = buffer.ex;
  }

  if (loader->GetNamedDSSet("sstep", buffer)) {
    sstep = new DSSet();
    sstep->mean = buffer.mean;
    sstep->var  = buffer.var;
    sstep->ex   = buffer.ex;
  }
}


// Load a vector GaussNonlin variable: its two value sets, cost and the
// mean/var freshness flags, plus the optional "sstate"/"sstep" sets. These
// are heap-allocated only when present; otherwise the pointers stay null.
GaussNonlinV::GaussNonlinV(Net *net, NetLoader *loader)
  : Variable(net, loader), BiParNode(0, 0)
{
  DVSet buffer;
  sstate = 0;
  sstep = 0;

  loader->GetNamedDVSet("myval1", myval1);
  loader->GetNamedDVSet("myval2", myval2);
  loader->GetNamedDouble("cost", cost);
  loader->GetNamedBool("meanuptodate", meanuptodate);
  loader->GetNamedBool("varuptodate", varuptodate);

  // The same buffer is reused for both optional sets, as before.
  if (loader->GetNamedDVSet("sstate", buffer)) {
    sstate = new DVSet();
    sstate->mean = buffer.mean;
    sstate->var  = buffer.var;
    sstate->ex   = buffer.ex;
  }

  if (loader->GetNamedDVSet("sstep", buffer)) {
    sstep = new DVSet();
    sstep->mean = buffer.mean;
    sstep->var  = buffer.var;
    sstep->ex   = buffer.ex;
  }
}

// Load a scalar discrete variable: distribution, cost, the cached
// expectation sum and its freshness flag.
Discrete::Discrete(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  loader->GetNamedDD("myval",          myval);
  loader->GetNamedDouble("cost",       cost);
  loader->GetNamedDouble("exsum",      exsum);
  loader->GetNamedBool("exuptodate",   exuptodate);
}




// Load a vector discrete variable: distributions, cost, the cached
// expectation sums and their freshness flag.
DiscreteV::DiscreteV(Net *net, NetLoader *loader)
  : Variable(net, loader)
{
  loader->GetNamedVDD("myval",       myval);
  loader->GetNamedDouble("cost",     cost);
  loader->GetNamedDV("exsum",        exsum);
  loader->GetNamedBool("exuptodate", exuptodate);
}



// Load a Memory variable. Its single parent must be time-independent
// (TimeType() == 0); otherwise a StructureException is thrown before any
// Memory-specific field is read.
//
// NOTE(review): the throw happens after the Node and Variable base
// constructors have already registered `this` with the net (AddNode,
// AddVariable), so the net may keep a pointer to a half-built object.
// Confirm how callers handle this.
Memory::Memory(Net *net, NetLoader *loader)
  : Variable(net, loader), UniParNode(0)
{
  const bool parentvaries = GetParent(0)->TimeType() ? true : false;
  if (parentvaries) {
    ostringstream errtext;
    errtext << GetIdent() << ": parent must be independent of time";
    throw StructureException(errtext.str());
  }

  loader->GetNamedDSSet("oldval",    oldval);
  loader->GetNamedDouble("oldcost",  oldcost);
  loader->GetNamedDouble("cost",     cost);
}


// Common base for on-line delay nodes: registers the node in the net's
// on-line delay list under its label.
OnLineDelay::OnLineDelay(Net *ptr, NetLoader *loader)
  : Node(ptr, loader)
{
  net->AddOnLineDelay(this, label);
}


// Load a scalar on-line delay: previous mean/expectation and its freshness flag.
OLDelayS::OLDelayS(Net *ptr, NetLoader *loader)
  : OnLineDelay(ptr, loader), BiParNode(0, 0)
{
  loader->GetNamedDouble("oldmean",  oldmean);
  loader->GetNamedDouble("oldexp",   oldexp);
  loader->GetNamedBool("exuptodate", exuptodate);
}


// Load a discrete on-line delay: only the previous value is stored.
OLDelayD::OLDelayD(Net *ptr, NetLoader *loader)
  : OnLineDelay(ptr, loader), BiParNode(0, 0)
{
  loader->GetNamedDD("oldval", oldval);
}


// Load a proxy node. The Node base is built in proxy mode, so saved parent
// labels are skipped. The referenced node's label is read, the proxy is
// registered with the net, and then the requirement flags are restored.
Proxy::Proxy(Net *ptr, NetLoader *loader)
  : Node(ptr, loader, true), UniParNode(0)
{
  loader->GetNamedLabel("reflabel", reflabel);
  net->AddProxy(this, label);

  loader->GetNamedBool("req_discrete",    req_discrete);
  loader->GetNamedBool("req_discretev",   req_discretev);
  loader->GetNamedDFlags("real_flags",    real_flags);
  loader->GetNamedDFlags("realv_flags",   realv_flags);
}


// Load a scalar evidence node: cost, value, alpha and decay.
Evidence::Evidence(Net *ptr, NetLoader *loader)
  : Variable(ptr, loader), Decayer(ptr, loader), UniParNode(0)
{
  loader->GetNamedDouble("cost",  cost);
  loader->GetNamedDouble("myval", myval);
  loader->GetNamedDouble("alpha", alpha);
  loader->GetNamedDouble("decay", decay);
}

// Load a vector evidence node: scalar cost, plus value, alpha and decay vectors.
EvidenceV::EvidenceV(Net *ptr, NetLoader *loader)
  : Variable(ptr, loader), Decayer(ptr, loader), UniParNode(0)
{
  loader->GetNamedDouble("cost", cost);
  loader->GetNamedDV("myval",    myval);
  loader->GetNamedDV("alpha",    alpha);
  loader->GetNamedDV("decay",    decay);
}
