# -*- coding: iso-8859-1 -*-
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: osvar.py,v 1.1.1.1 2006/11/23 09:42:07 mha Exp $
#

from bblocks.Label import *
from bblocks.Helpers import *
from bblocks.PyNet import PyNet, PyNodeFactory
from bblocks import Learner

from Numeric import *
import RandomArray
import MLab

import util
import blocks

class Model:
    """More generic model encapsulating both source- and observation
    variance-models.

    The network is assembled incrementally.  The constructor builds the
    observation layer (``x`` and, optionally, ``ux``).  ``addsources``
    and ``addvarsources`` then create source blocks and rewire parents
    of the existing nodes to the outputs of those blocks.

    Configuration keys read by this class (looked up in the ``config``
    mapping given to the constructor): tdim, xdim, useux, clampwithvar,
    sblock, sdim, smap, sdecay, stox, stoux, rblock, rdim, rmap, rdecay,
    rtos, rtous, rtox, rtoux.
    """

    def __init__(self, config):
        """Store ``config`` and build the observation-level network."""
        self.config = config
        self.net = self.__buildnet()

    def addvarsources(self):
        """Add the top-level ``r`` source block (``[s' us' x' ux']' = B r``).

        Does nothing when ``config["rblock"]`` is None.  Otherwise the
        flags ``rtos``/``rtous``/``rtox``/``rtoux`` select which node
        groups receive an output of the new block as (part of) a parent.
        The block's range dimension is the total size of the selected
        groups.  At least one flag must be set.

        Sets ``self.r``, ``self.ur`` and ``self.rblk``.

        NOTE(review): ``rtos``/``rtous`` use ``self.s``/``self.us``/
        ``self.mus`` (set by ``addsources``), and ``rtoux`` uses
        ``self.ux`` (set only when ``useux``).  Those flags presumably
        require ``addsources`` to have run first, or ``useux`` to be set.
        """
        config = self.config
        if config["rblock"] is None:
            return

        print("Adding variance sources...")

        # One r output per rewired node.  Checked before any node is
        # created, so an invalid flag combination leaves the net untouched.
        rangedim = 0
        if config["rtos"]:
            rangedim += config["sdim"]
        if config["rtous"]:
            rangedim += config["sdim"]
        if config["rtox"]:
            rangedim += config["xdim"]
        if config["rtoux"]:
            rangedim += config["xdim"]
        assert rangedim, "Insane parameters!"

        net = self.net
        mur = blocks.Priors(net, config["rdim"], "mur")
        vur = blocks.Priors(net, config["rdim"], "vur")
        rblk = config["rblock"](net, config["rdim"], rangedim,
                                mur.outputs, vur.outputs, "rblk",
                                linmap=config["rmap"])

        # rblk outputs are handed out in the fixed order s, us, x, ux.
        # `offset` is the index of the first unassigned output, and
        # `initdata` gathers the matching initial values in the same order.
        initdata = ()
        offset = 0

        if config["rtos"]:
            for i in range(config["sdim"]):
                util.replace_parent(self.s[i], self.c.c0,
                                    rblk.outputs[offset + i])
            offset += config["sdim"]
            initdata += (util.getmean(self.s),)

        if config["rtous"]:
            for i in range(config["sdim"]):
                util.replace_parent(self.us[i], self.mus[i],
                                    rblk.outputs[offset + i])
            offset += config["sdim"]
            initdata += (util.getmean(self.us),)

        if config["rtox"]:
            self._rewire_mean_parents(self.x, config["xdim"], "rtox",
                                      config["stox"], rblk.outputs, offset)
            offset += config["xdim"]
            initdata += (util.getmean(self.x),)

        if config["rtoux"]:
            self._rewire_mean_parents(self.ux, config["xdim"], "rtoux",
                                      config["stoux"], rblk.outputs, offset)
            offset += config["xdim"]
            initdata += (util.getmean(self.ux),)

        rblk.initsources(concatenate(initdata), config["rdecay"])

        self.r = rblk.s
        self.ur = rblk.u
        self.rblk = rblk

    def _rewire_mean_parents(self, nodes, count, tag, keep_old, outputs,
                             offset):
        """Redirect parent 0 (the mean parent) of ``nodes[0:count]``.

        Node ``i`` gets ``outputs[offset + i]`` as its new mean parent.
        When ``keep_old`` is true, the previous mean parent is kept by
        combining it with that output through a ``Sum2V`` node labelled
        ``Label(tag, i)`` (presumably their sum).
        """
        for i in range(count):
            node = nodes[i]
            meanparent = node.GetParent(0)
            if keep_old:
                replacement = Sum2V(self.net, Label(tag, i),
                                    outputs[offset + i], meanparent)
            else:
                replacement = outputs[offset + i]
            # Original note: "no child deserves a mean parent".
            # util.replace_parent presumably swaps meanparent for
            # replacement among node's parents.
            util.replace_parent(node, meanparent, replacement)

    def addsources(self):
        """Add the ``s`` source block (``[x' ux']' = A s``).

        Does nothing when ``config["sblock"]`` is None.  With ``stox``,
        outputs of the new block replace ``c.c0`` as parents of ``x``.
        With ``stoux``, they replace ``mux`` as parents of ``ux``.  At
        least one of the two flags must be set.

        Sets ``self.sblk``, ``self.s``, ``self.us`` and ``self.mus``.
        """
        config = self.config
        if config["sblock"] is None:
            return

        print("Adding sources...")

        xdim = config["xdim"]
        # Without any target the init data would be empty and the block
        # unusable.  Fail here, before nodes are created, as
        # addvarsources does.
        rangedim = 0
        if config["stox"]:
            rangedim += xdim
        if config["stoux"]:
            rangedim += xdim
        assert rangedim, "Insane parameters!"

        net = self.net
        mus = blocks.Priors(net, config["sdim"], "mus").outputs
        vus = blocks.Priors(net, config["sdim"], "vus").outputs
        sblk = config["sblock"](net, config["sdim"], rangedim,
                                mus, vus, "sblk",
                                linmap=config["smap"])

        # Outputs are consumed in the order x, ux; see addvarsources.
        initdata = ()
        offset = 0

        if config["stox"]:
            for i in range(xdim):
                util.replace_parent(self.x[i], self.c.c0,
                                    sblk.outputs[offset + i])
            offset += xdim
            initdata += (util.getmean(self.x),)

        if config["stoux"]:
            for i in range(xdim):
                util.replace_parent(self.ux[i], self.mux[i],
                                    sblk.outputs[offset + i])
            offset += xdim
            initdata += (util.getmean(self.ux),)

        sblk.initsources(concatenate(initdata), config["sdecay"])

        self.sblk = sblk
        self.s = sblk.s
        self.us = sblk.u
        self.mus = mus

    def __buildnet(self):
        """Build the observation layer and return the new ``PyNet``.

        For convenience only two linear maps are used as follows:
        [x' ux']'= A s and [s' us' x' ux']' = B r
        (the s and r parts are added later by ``addsources`` and
        ``addvarsources``).

        Creates ``xdim`` ``GaussianV`` nodes ``x`` with first parent
        ``c.c0``.  With ``useux`` their second parent is a layer of
        ``GaussianV`` nodes ``ux`` (parents ``mux``/``vux`` from
        ``Priors``).  Otherwise it is a ``Priors`` layer ``vx``.

        Sets ``self.c`` and ``self.x``.  With ``useux`` it also sets
        ``self.mux`` and ``self.ux``.
        """
        config = self.config
        xdim = config["xdim"]
        net = PyNet(config["tdim"])
        nf = PyNodeFactory(net)
        c = blocks.Constants(net)

        if config["useux"]:
            mux = blocks.Priors(net, xdim, "mux").outputs
            vux = blocks.Priors(net, xdim, "vux").outputs
            ux = nf.MakeNodes("GaussianV", "ux", xdim, mux, vux)
            # Explicit loop rather than map(): under Python 3 map() is lazy
            # and these calls would silently never run.
            # NOTE(review): meaning of persist level 5 is not visible here.
            for node in ux:
                node.SetPersist(5)
            varx = ux
            self.mux = mux
            self.ux = ux
        else:
            varx = blocks.Priors(net, xdim, "vx").outputs

        x = nf.MakeNodes("GaussianV", "x", xdim, c.c0, varx)

        self.c = c
        self.x = x
        return net

    def clamp(self, data):
        """Clamp the observation nodes ``x`` to ``data``.

        ``data`` must have shape ``(xdim, tdim)``; row ``i`` is clamped
        onto ``self.x[i]``.  When ``config["clampwithvar"]`` is None, each
        node is clamped with ``Clamp``.  Otherwise ``ClampWithVarV`` is
        used, with that value repeated ``tdim`` times.
        """
        xdim = self.config["xdim"]
        # tdim lives in config; the model has no `tdim` attribute.
        tdim = self.config["tdim"]
        assert data.shape == (xdim, tdim), "data must have shape (xdim, tdim)"
        var = self.config["clampwithvar"]

        for i in range(xdim):
            if var is None:
                self.x[i].Clamp(data[i, :])
            else:
                self.x[i].ClampWithVarV(data[i, :], repeat([var], tdim))

