//
// This file is a part of the Bayes Blocks library
//
// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2, or (at your option)
// any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License (included in file License.txt in the
// program package) for more details.
//
// $Id: main.cc 7 2006-10-26 10:26:41Z ah $

#include <cmath>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>

#include "Net.h"
#include "NodeFactory.h"
//#include "Node.h"

// Returns an approximately standard-normal variate.
//
// Twelve independent U[0,1) draws from drand48() have mean 6 and variance
// 12 * (1/12) = 1, so their sum shifted by -6 is roughly N(0, 1) by the
// central limit theorem. The result is bounded to the interval [-6, 6].
// The accumulator starts at -6 and the draws are added one at a time, so
// the result is deterministic for a given drand48() state.
double c_randn()
{
  const int kTerms = 12;
  double acc = -6;
  int remaining = kTerms;

  while (remaining-- > 0)
    acc += drand48();

  return acc;
}

int main()
try {
  int xdim = 50, tdim = 100, sdim = 2, ix, it, is, iter = 0, i;

  Net *net = new Net(tdim);
  NodeFactory *fact = new NodeFactory(net);

  Constant *const0 = fact->GetConstant("const0", 0);

  Gaussian *mvs = fact->GetGaussian("mvs", const0, const0);
  Gaussian *vvs = fact->GetGaussian("vvs", const0, const0);
  Gaussian *mvx = fact->GetGaussian("mvx", const0, const0);
  Gaussian *vvx = fact->GetGaussian("vvx", const0, const0);

  vector<Gaussian *> vs;
  vector<Gaussian *> vs0;
  vector<Gaussian *> vx;
  vector<Gaussian *> mx;

  vector<Gaussian *> b0;
  vector<Gaussian *> b1;
  vector<ProdV *> prb1;
  vector<DelayV *> del1;
  vector<Proxy *> prx1;

  vector<Gaussian *> a;
  vector<DelayGaussV *> s;

  vector<ProdV *> pr;
  vector<Sum2V *> su;

  vector<GaussianV *> x;

  for (is = 0; is < sdim; is++)
    vs.push_back(fact->GetGaussian("vs", mvs, vvs));

  for (is = 0; is < sdim; is++)
    vs0.push_back(fact->GetGaussian("vs0", const0, const0));

  for (ix = 0; ix < xdim; ix++)
    vx.push_back(fact->GetGaussian("vx", mvx, vvx));

  for (ix = 0; ix < xdim; ix++)
    mx.push_back(fact->GetGaussian("mx", const0, const0));

  for (is = 0; is < sdim; is++)
    b0.push_back(fact->GetGaussian("b0", const0, const0));

  for (is = 0; is < sdim; is++)
    b1.push_back(fact->GetGaussian("b1", const0, const0));

  for (is = 0; is < sdim; is++) {
    ostringstream ss;
    ss << "prb1(" << ((is + sdim - 1) % sdim) << ")";
    prx1.push_back(fact->GetProxy("prx1", ss.str()));
    del1.push_back(fact->GetDelayV("del1", const0, prx1[is]));
  }

  for (is = 0; is < sdim; is++)
    for (ix = 0; ix < xdim; ix++)
      a.push_back(fact->GetGaussian("a", const0, const0));

  for (is = 0; is < sdim; is++) {
    ostringstream ss;
    ss << "prb1(" << is << ")";
    s.push_back(fact->GetDelayGaussV("s", del1[is], vs[is], b0[is], const0, vs0[is]));
    prb1.push_back(fact->GetProdV(ss.str(), s[is], b1[is]));
  }

  net->ConnectProxies();

  for (ix = 0; ix < xdim; ix++)
    for (is = 0; is < sdim; is++)
      pr.push_back(fact->GetProdV("pr", a[is * xdim + ix], s[is]));

  for (ix = 0; ix < xdim; ix++) {
    vector<Node *> tmp;
    for (is = 0; is < sdim; is++)
      tmp.push_back(pr[ix * sdim + is]);
    tmp.push_back(mx[ix]);
    while (tmp.size() > 1) {
      Sum2V *tmp2 = fact->GetSum2V("su", tmp[0], tmp[1]);
      tmp.erase(tmp.begin());
      tmp.erase(tmp.begin());
      tmp.push_back(tmp2);
      su.push_back(tmp2);
    }
    x.push_back(fact->GetGaussianV("x", tmp[0], vx[ix]));
  }

  for (ix = 0; ix < xdim; ix++) {
    DV tmp(tdim);
    for (it = 0; it < tdim; it++)
      tmp[it] = sin((double)it+ix)+c_randn()*0.1;
    x[ix]->Clamp(tmp);
  }

  for (i = 0; i < sdim * xdim; i++)
    a[i]->Clamp(drand48() - 0.5);
  cout << iter << ": " << net->Cost() << '\n';
  while (iter < 3) {
    net->UpdateAll();
    cout << ++iter << ": " << net->Cost() << '\n';
  }
  for (i = 0; i < sdim * xdim; i++)
    a[i]->Unclamp();

  try {
    net->SaveToMatFile("mynet.mat", "mynet");
  }
  catch(MatlabException e) {
    cout << "Error in saving: " << e.what();
  }

  while (iter < 20) {
    net->UpdateAll();
    cout << ++iter << ": " << net->Cost() << '\n';
    cout.flush();
  }

  /*
  cout << "\nVX:\n";
  for (ix = 0; ix < xdim; ix++)
    cout << 1/vx[ix]->GetExp() << ' ';
  cout << "\n\nVS:\n";
  for (is = 0; is < sdim; is++)
    cout << 1/vs[is]->GetExp() << ' ';
  cout << "\n\nA: average var\n";
  for (is = 0; is < sdim; is++) {
    double d = 0;
    for (ix = 0; ix < xdim; ix++)
      d += a[is*xdim + ix]->GetVar();
    cout << d / xdim << ' ';
  }
  cout << "\n\nA:\n";
  for (is = 0; is < sdim; is++) {
    for (ix = 0; ix < xdim; ix++)
      cout << a[is*xdim + ix]->GetMean() << ' ';
    cout << '\n';
  }
  */

  delete fact;
  delete net;

  return 0;
}
 catch (std::runtime_error &e) {
   cerr << "Runtime error: " << e.what() << endl;
 }
 catch (std::logic_error &e) {
   cerr << "Logic error: " << e.what() << endl;
 }
 catch (std::exception &e) {
   cerr << "Exception: " << e.what() << endl;
 }
 catch (...) {
   cerr << "Unknown exception." << endl;
 }
