# -*- coding: iso-8859-1 -*-

#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#

from __future__ import nested_scopes #needed for python2.1
from bblocks.Label import Label, Unlabel
import bblocks.Net as Net
import bblocks.Helpers as Helpers
import bblocks.PyNet as PyNet
import bblocks.Learner as Learner
import bblocks.PickleHelpers as PickleHelpers
import math
import random
try:
    import numpy.oldnumeric as Numeric
    import numpy as MLab
except:
    import Numeric
    import MLab


def prunefunc(x):
    """Try to prune the weight variables of the net ``x``.

    This function is handed to ``Learner.Learner`` as its ``prunefunc``
    argument by ``HnfaNet.__init__``.

    Parameters:
    x -- the net to prune; it is passed as the instance argument of the
         ``PyNet.PyNet`` methods called below.

    Returns whatever ``PyNet.PyNet.TryPruning`` returns.
    """
    # Labels of every variable whose label matches the pattern 'A'.
    weight_labels = Helpers.GetLabel(PyNet.PyNet.GetVariables(x, 'A'))
    # NOTE(review): the meaning of the trailing 0 and 0.01 arguments is
    # defined by PyNet.TryPruning and is not visible here -- confirm there.
    return PyNet.PyNet.TryPruning(x, weight_labels, 0, 0.01)

class HnfaNet(PickleHelpers.Pickleable):
    def __init__(self, data, dim, seed=None, directconnect = -1,
                 sourcetype="fa"):
        """Build the network for ``data`` and attach a learner to it.

        Parameters:
        data          -- two-dimensional sequence; it is converted with
                         Numeric.array.  Each row is clamped to one
                         observation node 's(0,i)', and the length of the
                         first row is passed to PyNet.PyNet.
        dim           -- either a single number (the number of top-layer
                         sources, with no hidden nodes) or a tuple
                         (number of hidden nodes, number of sources).
        seed          -- seed for the private random.Random instance used
                         for all random initialisation.
        directconnect -- stored as self.directconnect; it is not used
                         anywhere else in this constructor.
        sourcetype    -- one of 'fa', 'ifa', 'dfa' or 'difa'; selects the
                         source model built by MakeSources().

        Raises:
        ValueError -- if ``dim`` is a tuple whose length is not 2, or if
                      more sources than observation rows are requested.

        After construction self.dim is the tuple
        (number of observation rows, number of hidden nodes, number of
        sources).
        """
        self.data = Numeric.array(data)
        nobs = len(self.data)
        if not isinstance(dim, tuple):
            self.dim = (nobs, 0, dim)
        elif len(dim) == 2:
            self.dim = (nobs,) + dim
        else:
            raise ValueError("dim must be an integer or a tuple of length 2")

        if self.dim[0] < self.dim[2]:
            raise ValueError(
                "Input dimension can't be larger than output dimension:"
                + repr(dim))

        self.seed = seed
        self.random = random.Random(self.seed)
        self.directconnect = directconnect
        self.sourcetype = sourcetype

        self.net = PyNet.PyNet(len(self.data[0]))
        self.fact = PyNet.PyNodeFactory(self.net)

        const0 = self.fact.GetConst0()
        self.fact.GetConstant("const_1_2", -1)
        self.fact.GetConstant("c0", -30)
        self.fact.GetConstant("c1", -7)
        self.fact.GetConstant("c2", -4)
        self.fact.GetConstant("c3", -3)
        # Are those c[0-3] the correct ones?
        # Should the const_1_2 be kept?
        # Note that only 'c1', 'c2', 'const0' and 'const_1_2' are referenced
        # by the prior list below; 'c0' and 'c3' are registered but unused
        # here.
        self.net.priorlist = {
            'mvs(0)': ('const0', 'c1'),
            'vvs(0)': ('const0', 'c2'),
            'mvs(1)': ('const0', 'c1'),
            'vvs(1)': ('const_1_2', 'c2'),
            'mvA_in(1)': ('const0', 'c1'),
            'vvA_in(1)': ('const0', 'c2'),
            'mvA_out(1)': ('const0', 'c1'),
            'vvA_out(1)': ('const0', 'c2'),
            'mvA_in(0)': ('const0', 'c1'),
            'vvA_in(0)': ('const0', 'c2')}

        # Use PCA to get the most important direction(s) and feed them to
        # the sources as initial evidence.
        pcomp = Helpers.DoPCA(self.data, self.dim[2])
        pcompDV = Helpers.Array2DV(pcomp)

        s2 = self.MakeSources()
        for i in range(len(s2)):
            self.fact.EvidenceVNode(s2[i], mean=pcompDV[i],
                                    var=0.001, decay=40.0)

        dataDV = Helpers.Array2DV(self.data)

        # One observation node per data row: a bias term plus a direct
        # (source -> observation) weight for every source.
        for i in range(self.dim[0]):
            prod = []
            for j in range(self.dim[2]):
                # lazy=1 returns the product node instead of wiring it in.
                prod.append(self.AddOneWeight((2, 0, j ,i), lazy=1))
            ms = self.fact.GetGaussian(Label("ms", 0, i), const0, const0)
            total = self.fact.BuildSum2VTree(prod + [ms], ("sums", 0, i))
            s = self.fact.GetGaussianV(Label("s", 0, i), total,
                              self.net.GetGaussianNode(Label("vs", 0, i)))
            s.Clamp(dataDV[i])

        # Index that the next hidden node created by AddHidden2() will get.
        self.nexthidden = 0

        self.learner = Learner.Learner(self.net, prunefunc=prunefunc)
        self.learner.HistoryAdd("Begin")
        self.AddHidden(self.dim[1])
        self.net.SortNodes()

    def MakeSources(self):
        """Create the top-layer source nodes 's(2)' and return them.

        The node types depend on self.sourcetype:
        'fa'   -- 'GaussianV' sources with constant variance parent
        'ifa'  -- 'GaussianV' sources with their own 'vs(2)' variance nodes
        'dfa'  -- 'DelayGaussV' sources with constant variance parent
        'difa' -- 'DelayGaussV' sources with their own 'vs(2)' nodes

        self.dim[2] nodes are requested from the factory in every case.

        Raises RuntimeError for any other sourcetype.
        """
        if self.sourcetype not in ('difa', 'ifa', 'fa', 'dfa'):
            raise RuntimeError('Unknown sourcetype: ' + self.sourcetype)

        fact = self.fact

        def unit_evidence(nodes, feed):
            # Give every node the same (mean=0, var=1, decay=1) evidence.
            for k in range(len(nodes)):
                feed(nodes[k], mean=0, var=1, decay=1)

        mean_parent = fact.GetConst0()

        # Variance parent: a constant, or a learned 'vs(2)' node per source.
        if self.sourcetype in ('fa', 'dfa'):
            var_parent = fact.GetConst0()
        else:
            mean_of_mvs = fact.GetConst0()
            var_of_mvs = fact.GetConst0()
            mvs_nodes = fact.MakeNodes('Gaussian', 'mvs(2)', self.dim[2],
                                       mean_of_mvs, var_of_mvs)
            unit_evidence(mvs_nodes, fact.EvidenceNode)

            mean_of_vvs = fact.GetConst0()
            var_of_vvs = fact.GetConst0()
            vvs_nodes = fact.MakeNodes('Gaussian', 'vvs(2)', self.dim[2],
                                       mean_of_vvs, var_of_vvs)
            unit_evidence(vvs_nodes, fact.EvidenceNode)

            var_parent = fact.MakeNodes('GaussianV', 'vs(2)', self.dim[2],
                                        mvs_nodes, vvs_nodes)
            unit_evidence(var_parent, fact.EvidenceVNode)

        # Static sources need nothing more.
        if self.sourcetype in ('ifa', 'fa'):
            return fact.MakeNodes('GaussianV', 's(2)', self.dim[2],
                                  mean_parent, var_parent)

        # Dynamic sources: extra 'as(2)' nodes plus initial-state parents.
        mean_of_as = fact.GetConst0()
        var_of_as = fact.GetConst0()
        as_nodes = fact.MakeNodes('Gaussian', 'as(2)', self.dim[2],
                                  mean_of_as, var_of_as)
        unit_evidence(as_nodes, fact.EvidenceNode)

        initial_mean = fact.GetConst0()

        mean_of_v0 = fact.GetConst0()
        var_of_v0 = fact.GetConst0()
        initial_var = fact.GetGaussian('v0s(2)', mean_of_v0, var_of_v0)
        fact.EvidenceNode(initial_var, mean=0, var=1, decay=1)

        return fact.MakeNodes('DelayGaussV', 's(2)', self.dim[2],
                              mean_parent, var_parent, as_nodes,
                              initial_mean, initial_var)

    def SaveMeanV(self, outfile, nodes=r's\(2'):
        """Write the mean vectors of ``nodes`` to the text file ``outfile``.

        Parameters:
        outfile -- path of the file to (over)write.
        nodes   -- either a list of nodes, or a string that is passed to
                   self.net.GetNodes() to look the nodes up.  The default
                   pattern selects the 's(2...' nodes.

        The array returned by Helpers.GetMeanV(nodes) is transposed, and
        each row of the result becomes one output line.  Every value is
        written as a space followed by its repr(), so lines start with a
        space.
        """
        if isinstance(nodes, str):
            nodes = self.net.GetNodes(nodes)
        # Compute everything before touching the file so that a failure
        # here does not leave a truncated output file behind.
        rows = Numeric.transpose(Numeric.array(Helpers.GetMeanV(nodes)))
        f = open(outfile, 'w')
        try:
            for row in rows:
                f.write("".join([" " + repr(value) for value in row]))
                f.write('\n')
        finally:
            # Always release the handle, even if a write fails.
            f.close()

    def HistoryAdd(self, record, data = ()):
        """Forward a history entry to the learner.

        Calls self.learner.HistoryAdd(record, data); ``data`` defaults to
        an empty tuple.  Nothing is returned.
        """
        self.learner.HistoryAdd(record, data)

    def AddHidden(self, num=1, parentprob=1.0, childprob=1.0,
                  scalemean = 0.5, scalestd = 0.5, vsdecay = 1000):
        """Add ``num`` randomly initialised hidden nodes to the net.

        Parameters:
        num        -- number of hidden nodes to add; 0 does nothing.
        parentprob -- probability that a hidden node gets a weight from
                      each individual source.
        childprob  -- probability that a hidden node is connected to each
                      individual observation.
        scalemean, scalestd
                   -- each node draws scale = exp(gauss(scalemean,
                      scalestd)); the scale sets the standard deviation of
                      its initial weights and bias.
        vsdecay    -- passed on unchanged to AddHidden2().

        The random initial values are drawn from self.random and the nodes
        themselves are built by AddHidden2().
        """
        if num == 0:
            return
        weights = []
        childind = []
        bias = []
        for j in range(num):
            scale = math.exp(self.random.gauss(scalemean, scalestd))

            # Which sources feed this hidden node.
            parentind = []
            for i in range(self.dim[2]):
                parentind.append(self.random.random() < parentprob)

            # Which observations this hidden node feeds.
            children = []
            for i in range(self.dim[0]):
                if self.random.random() < childprob:
                    children.append(i)
            childind.append(children)

            # Initial weights; None marks a connection that is left out.
            # The mask must be tested per source (index i): testing a fixed
            # element ignores the mask and fails when there is one source.
            row = []
            for i in range(len(parentind)):
                if parentind[i]:
                    row.append(self.random.gauss(
                        0, scale/math.sqrt(len(parentind))))
                else:
                    row.append(None)
            weights.append(row)

            bias.append(self.random.gauss(0, scale))
        self.AddHidden2(weights, bias, childind, vsdecay)

    def AddHidden2(self, weights, bias, childind=None, vsdecay = 1000):
        """Build hidden nodes from explicit initial weights and biases.

        Parameters:
        weights  -- sequence with one entry per new hidden node;
                    weights[j][i] is the initial mean of the weight from
                    source i to new node j, or None to omit that weight.
        bias     -- bias[j] is the initial mean of the bias of new node j.
        childind -- childind[j] lists the observation indices that new
                    node j is connected to; None connects every node to
                    all self.dim[0] observations.
        vsdecay  -- decay of the evidence given to each node's 'vs'
                    variance node.

        Does nothing when ``weights`` is empty.  self.nexthidden is
        increased by one for every node created, and the net is re-sorted
        at the end.
        """
        if len(weights) == 0:
            return
        if childind is None:
            childind = [range(self.dim[0])]*len(weights)
        # sumcache[i] collects the new product nodes feeding observation i.
        sumcache = []
        Alist = []
        for i in range(self.dim[0]):
            sumcache.append([])
        for j in range(len(weights)):
            vA_in = self.net.GetGaussianNode(Label("vA_in",
                                                   1, self.nexthidden))
            vA_out = self.net.GetGaussianNode(Label("vA_out",
                                                    1, self.nexthidden))
            # Source -> hidden weights and their products with the sources.
            prod = []
            for i in range(len(weights[j])):
                if weights[j][i] is None:
                    continue
                A = self.fact.GetGaussian(
                    Label("A", 2, 1, i, self.nexthidden),
                    self.fact.GetConst0(), vA_in)
                self.fact.EvidenceNode(
                    A, mean=weights[j][i],
                    var=0.01, decay=40.0)
                prod.append(self.fact.GetProdV(
                        Label("prA", 2, 1, i, self.nexthidden),
                        A, self.net.GetNode(Label("s", 2, i))))
                A.SetPersist(5)
            # Bias of the hidden node.
            ms = (self.fact.GetGaussian(Label("ms", 1, self.nexthidden),
                               self.fact.GetConst0(),
                               self.fact.GetConst0()))
            ms.SetPersist(5)
            self.fact.EvidenceNode(
                ms, mean=bias[j],
                var=0.01, decay=40.0)
            total = self.fact.BuildSum2VTree(prod + [ms], ("sums",
                                                          1, self.nexthidden))
            vs = self.net.GetGaussianNode(Label("vs", 1, self.nexthidden))
            # NOTE(review): the rationale for the mean=9, var=1 evidence on
            # the variance node is not visible here -- confirm before
            # changing it.
            self.fact.EvidenceNode(vs, mean=9, var=1, decay=vsdecay)
            s = self.fact.GetGaussNonlinV(Label("s", 1, self.nexthidden),
                                          total, vs)
            s.Update()
            s.SetPersist(5)
            # Hidden -> observation weights.
            for i in childind[j]:
                vA = self.fact.GetSum2(Label("vA", 1, 0, self.nexthidden, i),
                          vA_out, self.net.GetGaussianNode(Label("vA_in", 0, i)))
                A = self.fact.GetGaussian(Label("A", 1, 0, self.nexthidden, i),
                                          self.fact.GetConst0(), vA)
                prA = self.fact.GetProdV(
                    Label("prA", 1, 0, self.nexthidden, i),
                    A, s)
                sumcache[i].append(prA)
                Alist.append(A)
            self.nexthidden += 1
        # Attach the collected products to the sum feeding each observation.
        for i in range(self.dim[0]):
            if len(sumcache[i]) == 0:
                continue
            total = self.fact.BuildSum2VTree(sumcache[i], ("sums", 0, i))
            Helpers.AddSum2V(self.net,
                             self.net.GetNode(Label("s", 0, i)),
                             0, total, Label("sums", 0, i))
        # Explicit loop: the updates are wanted for their side effect, and
        # a lazy map() would silently skip them.
        for A in Alist:
            Net.Variable.Update(A)
        self.net.SortNodes()

    def AddHiddenBest(self, num=5, numtest=1000,
                      scalemean = 0.5, scalestd = 0.5, vsdecay = 1000):
        """Add the ``num`` most promising of ``numtest`` random candidates.

        Parameters:
        num       -- number of hidden nodes to add.
        numtest   -- number of random candidates scored by HiddenCost().
        scalemean, scalestd
                  -- passed on to HiddenCost() (or AddHidden()).
        vsdecay   -- passed on to AddHidden2() (or AddHidden()).

        When num == numtest there is nothing to select, so AddHidden() is
        called directly.  Otherwise the candidates are ranked by the sum
        of their cost-change values and the ``num`` highest are added,
        best first.  Candidates that HiddenCost() scored as None are never
        selected.  If no candidate is selected, nothing is added.
        """
        if num == numtest:
            return self.AddHidden(num=num,scalemean=scalemean,
                                  scalestd=scalestd,vsdecay=vsdecay)
        (weights, bias, costchange)=self.HiddenCost(
            numtest=numtest, scalemean=scalemean, scalestd=scalestd)
        # Skip candidates whose score could not be computed.
        valid = [k for k in range(len(costchange))
                 if costchange[k] is not None]
        if not valid:
            return
        totals = [Numeric.sum(costchange[k]) for k in valid]
        order = list(Numeric.argsort(totals))
        # Keep the last ``num`` entries of the ascending order.  Written
        # with an explicit start index so that num <= 0 selects nothing
        # instead of everything.
        best = [valid[k] for k in order[max(len(order) - num, 0):]]
        best.reverse()
        if not best:
            return
        self.AddHidden2(weights=Numeric.take(weights, best),
                        bias=Numeric.take(bias, best),
                        vsdecay=vsdecay)

    def HiddenCost(self, numtest=1, scalemean = 0.5, scalestd = 0.5):
        """Score ``numtest`` random candidate hidden nodes.

        Each candidate gets random weights (one per source) and a random
        bias, drawn from self.random with a standard deviation derived
        from scale = exp(gauss(scalemean, scalestd)).

        Returns the tuple (weights, bias, costchange), three lists of
        length ``numtest``.  costchange[j] holds one value per observation
        for candidate j, or None if an OverflowError occurred while
        scoring it.
        """
        source_nodes = [self.net.GetVariable(Label('s', 2, k))
                        for k in range(self.dim[2])]
        sources = Numeric.array(Helpers.GetMeanV(source_nodes), 'd')

        # Gradients of every observation node w.r.t. its mean parent.
        gradmean = Numeric.zeros((self.dim[0], self.net.Time()), 'd')
        gradvar = Numeric.zeros((self.dim[0], self.net.Time()), 'd')
        # Collected as in the original code; not used further below.
        parent_exp = []
        for obs in range(self.dim[0]):
            node = self.net.GetVariable(Label('s', 0, obs))
            grad = Net.DVSet()
            node.GradRealV(grad, node.GetParent(0))
            gradmean[obs] = grad.mean
            gradvar[obs] = grad.var
            parent_exp.append(Helpers.GetExp(node.GetParent(1)))

        weights = []
        bias = []
        costchange = []
        for trial in range(numtest):
            scale = math.exp(self.random.gauss(scalemean, scalestd))
            row = []
            weights.append(row)
            for k in range(self.dim[2]):
                row.append(self.random.gauss(
                    0, scale/math.sqrt(self.dim[2])))
            offset = self.random.gauss(0, scale)
            bias.append(offset)
            try:
                activation = Numeric.dot(Numeric.array([row]), sources) + offset
                # This clip should take care of overflows.
                activation = Numeric.clip(activation, -18, 18)
                output = Numeric.exp(-(activation**2))
                output = output - MLab.mean(output, 1)
                #output=Numeric.exp(-(mean1**2)/(1+2*mvs1))/math.sqrt(1+2*mvs1)
                # Want to minimise cost2 = sum(a*(x-b*m))^2
                # where a=a[i], x=x[i], m=m[i], b is the parameter
                # gradmean=a*x gradvar=a/2 output=m
                # differentiate -> sum(-m*a*(x-b*m))==0
                # -> b=sum(m*a*x)/sum(m^2*a) = c/d
                # we want the difference sum(a*x^2-a*(x-b*m)^2)=
                # sum(2*a*x*b*m-a*b^2*m^2)=
                # 2*b*sum(a*x*m)-b^2*sum(a*m^2)=2*b*c-b^2*d=2*c^2/d-c^2/d=c^2/d
                numer = Numeric.dot(output, Numeric.transpose(gradmean))
                denom = Numeric.dot(output**2, Numeric.transpose(gradvar))
                costchange.append(((numer**2)/denom)[0])
            except OverflowError:
                costchange.append(None)
        return weights, bias, costchange

    def PruneHidden(self, hidden, cdiff = None):
        """Remove the node labelled ``hidden`` from the net.

        Parameters:
        hidden -- label of the node to remove.
        cdiff  -- cost difference to record in the history; when None it
                  is computed with Helpers.CostDifferenceV().

        A "PruneHidden" history entry is added before the node is killed,
        after which the net is cleaned up and re-sorted.
        """
        victim = self.net.GetNode(hidden)
        if cdiff is None:
            cdiff = Helpers.CostDifferenceV(victim)
        self.learner.HistoryAdd("PruneHidden", (hidden, cdiff))
        victim.Die()
        net = self.net
        net.CleanUp()
        net.SortNodes()

    def TryPruneHidden(self, cdiff = 200, num=1, lazy=0):
        """Prune up to ``num`` hidden nodes whose cost difference is low.

        Parameters:
        cdiff -- a node is pruned only if its cost difference is below
                 this value.
        num   -- maximum number of nodes to prune.
        lazy  -- the ``lazy`` most recently numbered hidden nodes (index
                 >= self.nexthidden - lazy) are not considered.

        Candidates are handled in ascending order of cost difference.
        Returns the number of nodes pruned.
        """
        self.learner.HistoryAdd("TryPruneHidden", (cdiff, num))
        limit = self.nexthidden - lazy
        candidates = [v for v in self.net.GetVariables('s\(1,')
                      if Unlabel(v.GetLabel())[1][1] < limit]
        ranked = [(Helpers.CostDifferenceV(v), v.GetLabel())
                  for v in candidates]
        ranked.sort()
        for done in range(num):
            if not (len(ranked) > done):
                return done
            cost, label = ranked[done][0], ranked[done][1]
            if not (cost < cdiff):
                return done
            self.PruneHidden(label, cost)
        return num

    def AddWeight(self):
        """Placeholder for adding a single random weight; not implemented.

        Raises:
        RuntimeError -- always.
        """
        # TODO: implement.  The earlier unreachable draft chose a random
        # hidden node from self.net.GetVariables('s\(1,'), a random source
        # index in range(self.dim[2]) and a random observation index in
        # range(self.dim[0]), and looked up that hidden node's 'vA_in' and
        # 'vA_out' variance nodes.  It referenced undefined names and never
        # created a weight, so it has been removed.
        raise RuntimeError("AddWeight() not ready for use")

    def AddAllWeights(self, prob = 1.0):
        """Try to add every possible weight, each with probability ``prob``.

        Three families of weights are visited in this order: source ->
        observation, then for every existing hidden node source -> hidden
        and hidden -> observation.  AddOneWeight() does the work, and the
        labels of the weights it reports as added are recorded in an
        "AddAllWeights" history entry together with ``prob``.
        """
        added = []

        def consider(index):
            # One random draw per potential weight, whether or not it exists.
            if self.random.random() < prob:
                if self.AddOneWeight(index):
                    added.append(Label('A', *index))

        for i in range(self.dim[0]):
            for j in range(self.dim[2]):
                consider((2, 0, j, i))

        for hidden in self.net.GetVariables('s\(1,'):
            k = Unlabel(hidden.GetLabel())[1][1]
            for j in range(self.dim[2]):
                consider((2, 1, j, k))
            for i in range(self.dim[0]):
                consider((1, 0, k, i))

        self.net.SortNodes()
        self.learner.HistoryAdd("AddAllWeights", (prob, tuple(added)))
            

    def AddOneWeight(self, index, lazy = 0):
        """Create the weight 'A(index)' if it does not exist yet.

        Parameters:
        index -- tuple (from layer, to layer, from index, to index); the
                 from layer must be 1 or 2 and the to layer 0 or 1.
        lazy  -- when true, the new product node is returned instead of
                 being added to the target node's sum.

        Returns None if the weight already exists, the product node when
        ``lazy`` is true, and 1 otherwise.

        Raises IndexError if index[0] or index[1] is out of range.
        """
        if self.net.GetNode(Label("A", index)) is not None:
            return None

        # Only weights leaving layer 1 have an output-side variance node.
        if index[0] == 1:
            vA_out = self.net.GetGaussianNode(Label("vA_out", 1, index[2]))
        elif index[0] == 2:
            vA_out = None
        else:
            raise IndexError("index[0] out of range")
        if index[1] not in (0, 1):
            raise IndexError("index[1] out of range")

        vA_in = self.net.GetGaussianNode(Label("vA_in", index[1], index[3]))
        if vA_out is not None:
            vA = self.fact.GetSum2(Label("vA", index), vA_out, vA_in)
        else:
            vA = vA_in

        weight = self.fact.GetGaussian(Label("A", index),
                                       self.fact.GetConst0(), vA)
        weight.SetPersist(5)
        source = self.net.GetNode(Label("s", index[0], index[2]))
        product = self.fact.GetProdV(Label("prA", index), weight, source)
        if lazy:
            return product

        target = self.net.GetNode(Label("s", index[1], index[3]))
        Helpers.AddSum2V(self.net, target, 0, product,
                         Label("sums", index[1], index[3]))
        weight.Update()
        self.net.SortNodes()
        return 1

    def TryPruning(self):
        """Return the result of self.learner.TryPruning()."""
        return self.learner.TryPruning()

    def HookeJeeves(self):
        """Return the result of self.learner.HookeJeeves()."""
        return self.learner.HookeJeeves()
        
    def Iteration(self):
        """Return the result of self.learner.Iteration()."""
        return self.learner.Iteration()

    def LearnNet(self, *args, **kws):
        """Call self.learner.LearnNet() with all arguments passed through.

        Positional and keyword arguments are forwarded unchanged and the
        learner's return value is returned.
        """
        return self.learner.LearnNet(*args, **kws)
