#!/usr/bin/env python
# -*- coding: iso-8859-1 -*-

#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: main.py 6 2006-10-26 09:54:58Z ah $
#

from bblocks import PyNet
from bblocks.Label import Label
from bblocks import Helpers

try:
    import numpy.oldnumeric as Numeric
except:
    import Numeric
import math
import random

random.seed(8)

try:
    data = Helpers.LoadMatlabArray("testdata.mat", "data")
    xdim = len(data)
    tdim = len(data[0])
except:
    try:
        data = Helpers.LoadAsciiData("testdata.asc")
        xdim = len(data)
        tdim = len(data[0])
    except:
        print "Data files testdata.mat and testdata.asc not found, generating new data"
        xdim = 50
        tdim = 100
        data = []
        for ix in range(xdim):
            tmp = PyNet.DV(tdim)
            for it in range(tdim):
                tmp[it] = math.sin(it+ix) + random.gauss(0, 0.1)
            data.append(tmp)

# Number of latent sources (the first index of the mixing weights a[i][ix]).
sdim = 15

# Network sized by tdim (the length of each data sequence); all nodes are
# created through a factory bound to this network.
mynet = PyNet.PyNet(tdim)
myfact = PyNet.PyNodeFactory(mynet)

# Constant nodes with values 0 and -2, used as the two parameters of every
# top-level Gaussian prior below.  NOTE(review): whether -2 is a
# log-variance or some other variance parameterisation is defined by
# GetGaussian -- confirm there.
const0 = myfact.GetConstant("const0", 0)
const1 = myfact.GetConstant("const1", -2)

# Top-level hyperprior nodes, created in this exact order:
#   vva      - shared parameter for the weight-variance nodes va[ix]
#   mvs, vvs - parameters for the source-variance nodes vs[i]
#   mmx, vmx - parameters for the observation-bias nodes mx[ix]
#   mvx, vvx - parameters for the observation-variance nodes vx[ix]
vva, mvs, vvs, mmx, vmx, mvx, vvx = (
    myfact.GetGaussian(prior_name, const0, const1)
    for prior_name in ("vva", "mvs", "vvs", "mmx", "vmx", "mvx", "vvx")
)

# Second-level Gaussian nodes, one per observed dimension (range(xdim)) or
# per source (range(sdim)), each parameterised by the top-level hyperpriors:
#   va[ix] - variance-type parameter for the weights a[i][ix]
#   vs[i]  - variance-type parameter for source s[i]
#   mx[ix] - bias term included in the mean of observation x[ix]
#   vx[ix] - variance-type parameter for observation x[ix]
# Creation order (all va, then vs, then mx, then vx) matches node naming
# and registration order in the network.
va = [myfact.GetGaussian(Label("va", i), const0, vva) for i in range(xdim)]
vs = [myfact.GetGaussian(Label("vs", i), mvs, vvs) for i in range(sdim)]
mx = [myfact.GetGaussian(Label("mx", ix), mmx, vmx) for ix in range(xdim)]
vx = [myfact.GetGaussian(Label("vx", ix), mvx, vvx) for ix in range(xdim)]

# Mixing weights: a[src][dim] connects source `src` to observed dimension
# `dim`, with prior mean 0 and variance-type parameter va[dim].  Each weight
# also receives evidence whose mean is drawn from random.gauss(0, 0.31);
# NOTE(review): presumably this randomises the initial weights to break
# symmetry between sources -- confirm against EvidenceNode.
a = []
for src in range(sdim):
    row = []
    a.append(row)
    for dim in range(xdim):
        weight = myfact.GetGaussian(Label("a", src, dim), const0, va[dim])
        row.append(weight)
        myfact.EvidenceNode(weight, mean=random.gauss(0, 0.31))

# Latent sources: one vector-valued Gaussian node per source, mean 0 and
# variance-type parameter vs[i].
s = [myfact.GetGaussianV(Label("s", i), const0, vs[i]) for i in range(sdim)]

# Observations: for each dimension ix the mean is built from the bias
# mx[ix] and the products a[i][ix] * s[i] (GetProdV) for every source,
# combined by BuildSum2VTree (presumably into a sum -- see that helper).
# Each observation node is then clamped to its data sequence data[ix].
x = []
for ix in range(xdim):
    terms = [mx[ix]]
    terms.extend(myfact.GetProdV(Label("pr", i, ix), a[i][ix], s[i])
                 for i in range(sdim))
    sum_node = myfact.BuildSum2VTree(terms, "sumtree")
    obs = myfact.GetGaussianV(Label("x", ix), sum_node, vx[ix])
    x.append(obs)
    obs.Clamp(data[ix])

conns = Numeric.sum(map(len,a))
print "Initially", conns, "connections"
iter = 0
print iter, ":", mynet.Cost()
while iter < 20:
    mynet.UpdateAll()
    iter += 1
    print iter, ":", mynet.Cost()

print "Starting pruning"

for i in s:
    i.SetPersist(5)

while iter < 1001:
    if iter % 10 == 0:
        mynet.TryPruning(mynet.GetVariables(r"a\(.*\)"), 0, 1)
    elif iter % 2 == 1:
        for i in range(sdim):
	    if mynet.GetVariable(Label("s", i)):
                ix = random.randrange(xdim)
                l = Label("a", i, ix)
                if not mynet.GetVariable(l):
                    Helpers.AddWeight(mynet, s[i], x[ix], l,
                                      const0, va[ix], -1)
    nc = len(mynet.GetVariables(r"a\(.*\)"))
    if nc != conns:
        print "Connections change:", conns, "->", nc
        conns = nc
    if iter % 10 == 5:
        mynet.UpdateAllHookeJeeves()
    else:
        mynet.UpdateAll()
    iter += 1
    print iter, ":", mynet.Cost()

print "Finally", conns, "connections"

try:
    mynet.SaveToMatFile("pythonnet.mat","pythonnet")
except:
    print "Saving in Pickle format instead"
    mynet.SaveWithPickle("pythonnet.pickle")
