# -*- coding: iso-8859-1 -*-

#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: PyNet.py 7 2006-10-26 10:26:41Z ah $
#

from Net import *
from Label import *
import Helpers
import re
import cPickle
import math
import time
import warnings
try:
    import numpy.oldnumeric as Numeric
except:
    import Numeric
import PickleHelpers
from PickleHelpers import LoadWithPickle

def _issequence(x):
    "Is x a sequence? We say it is if it has a __getitem__ method."
    return hasattr(x, '__getitem__') or type(x) == type(Numeric.array(0))

class PyNet(Net, PickleHelpers.Pickleable):
    __safe_for_unpickling__=1

    def __getstate__(self):
        return {'mystate': self.SaveToPyObject(),
                'hj_maxalpha': self.hj_maxalpha,
                'priorlist': self.priorlist}

    def __setstate__(self, dict):
        self.this = CreateNetFromPyObject(dict['mystate'])
        self.hj_maxalpha = dict['hj_maxalpha']
        if dict.has_key('priorlist'):
            self.priorlist = dict['priorlist']
        else:
            self.priorlist = {}

    def __repr__(self):
        return "<PyNet instance wrapping C Net instance at %s>" % (self.this,)

    def __init__(self, *args):
        apply(Net.__init__, (self, ) + args)
        self.hj_maxalpha = 100.0
        self.priorlist = {}

    def GetConst0(self):
        warnings.warn("method deprecated", DeprecationWarning, stacklevel=2)

        const0 = self.GetNode('const0')
        if not const0:
            const0 = Constant(self, 'const0', 0)
        return const0

    def UpdateAllDebug(self):
        epsilon = 1e-8;
        oldcost = self.CostDebug()
        for i in range(self.VariableCount()-1, -1, -1):
            mynode = self.GetVariableByIndex(i)
            mynode.Update()
            newcost = self.CostDebug()
            if oldcost < newcost - epsilon:
                print "Cost function value grew while updating", \
                      mynode.GetType(), "node", mynode.GetLabel()
                print "Values:", oldcost, "->", newcost
            oldcost = newcost
        self.ProcessDecayHook("UpdateAll")

    def CostDebug(self):
        """Doing the cost calculation in python side makes
        it possible to trap bugs using pdb."""
        c = 0.0
        for i in range(self.VariableCount()):
            n = self.GetVariableByIndex(i)
            c += n.Cost()
        return c

    def UpdateTimeIndDebug(self):
        epsilon = 1e-8;
        oldcost = self.Cost()
        for i in range(self.VariableCount()-1, -1, -1):
            mynode = self.GetVariableByIndex(i)
            if mynode.TimeType() == 0:
                mynode.Update()
                newcost = self.Cost()
                if oldcost < newcost - epsilon:
                    print "Cost function value grew while updating", \
                          mynode.GetType(), "node", mynode.GetLabel()
                    print "Values:", oldcost, "->", newcost
                oldcost = newcost
        self.ProcessDecayHook("UpdateTimeInd")

    def __FindOptimalStepGoldsect(self, a, b, epsilon=1e-1):
        alpha = 0.6180339887498949
        l = a + (1-alpha) * (b-a)
        m = a + alpha     * (b-a)

        self.RepeatAllSteps(l)
        fl = self.Cost()
        self.RepeatAllSteps(m)
        fm = self.Cost()
        
        while (b-a > epsilon):
            if fl > fm:
                a = l
                l = m
                m = a + alpha * (b-a)
                fl = fm
                self.RepeatAllSteps(m)
                fm = self.Cost()
            else:
                b = m
                m = l
                l = a + (1-alpha) * (b-a)
                fm = fl
                self.RepeatAllSteps(l)
                fl = self.Cost()
        return (a+b)/2

    def __FindOptimalStepGoldsect2(self, a, b, epsilon=1e-1):
        alpha = 0.6180339887498949

        a = 1.0
        b = 1.0
        self.RepeatAllSteps(b)
        y1 = self.Cost()
        b *= 1+alpha
        self.RepeatAllSteps(b)
        y2 = self.Cost()
        b *= 1+alpha
        self.RepeatAllSteps(b)
        y3 = self.Cost()

        if self.GetDebugLevel() > 5:
            print "FindOptimalStepGoldsect2:", a, b
            print " costs:", y1, y2, y3
            time.sleep(1)

        while (y1 > y2) and (y2 > y3):
            y1 = y2
            y2 = y3
            a *= 1+alpha
            b *= 1+alpha
            self.RepeatAllSteps(b)
            y3 = self.Cost()
            if self.GetDebugLevel() > 5:
                print "Lengthening:", a, b
                print " costs:", y1, y2, y3
                time.sleep(1)

        l = a * (1+alpha)
        m = a + alpha     * (b-a)

        fl = y2
        self.RepeatAllSteps(m)
        fm = self.Cost()
        
        while (b-a > epsilon):
            if fl > fm:
                a = l
                l = m
                m = a + alpha * (b-a)
                fl = fm
                self.RepeatAllSteps(m)
                fm = self.Cost()
            else:
                b = m
                m = l
                l = a + (1-alpha) * (b-a)
                fm = fl
                self.RepeatAllSteps(l)
                fl = self.Cost()
            if self.GetDebugLevel() > 5:
                print "New iteration:", a, b
                time.sleep(1)
        return (a+b)/2

    def __FindOptimalStepQuadratic(self, a, b, epsilon=1e-1):
        epsilon2 = 1e-8
        maxiters = 30
        iters = 0
        x1 = float(a)
        x2 = (a+b)/2.0
        x3 = float(b)
        
        self.RepeatAllSteps(x1)
        y1 = self.Cost()
        self.RepeatAllSteps(x2)
        y2 = self.Cost()
        self.RepeatAllSteps(x3)
        y3 = self.Cost()

        if self.GetDebugLevel() > 5:
            print "FindOptimalStepQuadratic:"
            print " x:", x1, x2, x3
            print " y:", y1, y2, y3
            time.sleep(1)

        while (y1 > y2) and (y2 > y3):
            x2 = x3
            y2 = y3
            x3 = x1 + 2.0 * (x2 - x1)
            self.RepeatAllSteps(x3)
            y3 = self.Cost()
            if self.GetDebugLevel() > 5:
                print "Lengthened:"
                print " x:", x1, x2, x3
                print " y:", y1, y2, y3
                time.sleep(1)
        
        while (((y1 < y2) and (y2 < y3)) or
               (y2 > 1e300 or Helpers.IsNaN(y2) or
                y3 > 1e300 or Helpers.IsNaN(y3))):
            if (x3-x1 < epsilon):
                return x1
            x3 = x2
            y3 = y2
            x2 = x1 + .5 * (x3 - x1)
            self.RepeatAllSteps(x2)
            y2 = self.Cost()
            if self.GetDebugLevel() > 5:
                print "Shortened:"
                print " x:", x1, x2, x3
                print " y:", y1, y2, y3
                time.sleep(1)

        if (y1 < y2) and (y2 > y3):
            if self.GetDebugLevel() > 5:
                print "Non-convex point configuration, reverting back to Goldsect..."
            return self.__FindOptimalStepGoldsect2(x1, x3)

        while (x3-x1 > epsilon) and \
                  (math.fabs(y2-y3) + math.fabs(y3-y1)
                   + math.fabs(y1-y2) > epsilon2):
            z = 2.0*(x1 * (y2-y3) + x2 * (y3-y1) + x3 * (y1-y2))
            if z == 0:
                if (x3-x2) > (x2-x1):
                    xnew = (x2+x3) / 2
                else:
                    xnew = (x1+x2) / 2
            else:
                xnew = (x1**2 * (y2-y3) + x2**2 * (y3-y1) +
                        x3**2 * (y1-y2)) / z
            self.RepeatAllSteps(xnew)
            ynew = float(self.Cost())
            if self.GetDebugLevel() > 5:
                print "Iteration:", iters
                print " x:", x1, x2, x3, xnew
                print " y:", y1, y2, y3, ynew
                time.sleep(1)

            if xnew == x2:     # xnew = x2  =>  trouble...
                if (x3 - x2) > (x2 - x1):
                    xnew = x2 + .5 * (x3 - x2)
                else:
                    xnew = x2 - .5 * (x2 - x1)
                self.RepeatAllSteps(xnew)
                ynew = float(self.Cost())

            if xnew > x2:
                if ynew > y2:
                    x3 = xnew
                    y3 = ynew
                else:
                    x1 = x2
                    y1 = y2
                    x2 = xnew
                    y2 = ynew
            else:
                if ynew > y2:
                    x1 = xnew
                    y1 = ynew
                else:
                    x3 = x2
                    y3 = y2
                    x2 = xnew
                    y2 = ynew

            iters += 1
            if iters > maxiters:
                return self.__FindOptimalStepGoldsect2(x1, x3)


        z = 2.0*(x1 * (y2-y3) + x2 * (y3-y1) + x3 * (y1-y2))
        if z == 0:
            xnew = x2
        else:
            xnew = (x1**2 * (y2-y3) + x2**2 * (y3-y1) +
                    x3**2 * (y1-y2)) / z
        return xnew

    def __SearchStepLens(self, alpha=1.618):
        a = 1.0
        self.RepeatAllSteps(a)
        c_beg = self.Cost()
        c = c_beg - 1
        while c < c_beg:
            a *= alpha
            self.RepeatAllSteps(a)
            c = self.Cost()
        return a

    def UpdateAllHookeJeeves(self, epsilon=1e-1, exploresteps=1, returncost=0):
        costs = []
        self.SaveAllStates()
        for i in range(exploresteps):
            self.UpdateAll()
            if returncost:
                costs.append(self.Cost())
        self.SaveAllSteps()
        alpha = self.__FindOptimalStepQuadratic(1.0, self.hj_maxalpha, epsilon)
        if alpha < 100:
            self.hj_maxalpha = 1.5*alpha
        else:
            self.hj_maxalpha = 150
        if self.GetDebugLevel() > 2:
            print "Length multiplier:", alpha
        self.RepeatAllSteps(alpha)
        self.ClearAllStatesAndSteps()
        if returncost:
            return alpha, costs
        else:
            return alpha
    
    def BuildSum2VTree(self, nodes, labelbase):
        warnings.warn("method deprecated", DeprecationWarning, stacklevel=2)

        if len(nodes) == 1:
            return nodes[0]
        j = 0
        if type(labelbase) in (type(()), type([])):
            leaflabel = list(labelbase)
        else:
            leaflabel = [labelbase]
            labelbase = (labelbase, )
        leaflabel[0] += "_leaf"
        leaflabel = tuple(leaflabel)
        while len(nodes) > 2:
            tmp = Sum2V(self, apply(Label, leaflabel + (j, )),
                        nodes.pop(0), nodes.pop(0))
            j += 1
            nodes.append(tmp)
        return Sum2V(self, apply(Label, labelbase), nodes.pop(0), nodes.pop(0))

    def BuildSum2Tree(self, nodes, labelbase):
        warnings.warn("method deprecated", DeprecationWarning, stacklevel=2)

        if len(nodes) == 1:
            return nodes[0]
        j = 0
        if type(labelbase) in (type(()), type([])):
            leaflabel = list(labelbase)
        else:
            leaflabel = [labelbase]
            labelbase = (labelbase, )
        leaflabel[0] += "_leaf"
        leaflabel = tuple(leaflabel)
        while len(nodes) > 2:
            tmp = Sum2(self, apply(Label, leaflabel + (j, )),
                       nodes.pop(0), nodes.pop(0))
            j += 1
            nodes.append(tmp)
        return Sum2(self, apply(Label, labelbase), nodes.pop(0), nodes.pop(0))

    def TryPruning(self, node, verbose = 0, addcost = 0):
        if type(node) == type([]):
            s = 0
            for n in node:
                s += self.TryPruning(n, verbose, addcost)
            return s
        if type(node) == type(""):
            node = self.GetVariable(node)
            if node is None:
                return 1
        if (Helpers.CostDifference(node) < addcost):
            if (verbose):
                print "Pruning node",node.GetLabel()
            node.Die(verbose)
            self.CleanUp()
            return 1
        else:
            return 0

    def GetVariables(self, regexp):
        res = []
        com = re.compile(regexp)
        for i in range(self.VariableCount()):
            node = self.GetVariableByIndex(i)
            if com.match(node.GetLabel()):
                res.append(node)
        return res

    def GetNodes(self, regexp):
        res = []
        com = re.compile(regexp)
        for i in range(self.NodeCount()):
            node = self.GetNodeByIndex(i)
            if com.match(node.GetLabel()):
                res.append(node)
        return res

    def GetNodeArray(self, regexp, getter="GetNodes"):
        nodes = getattr(self, getter)(regexp)
        labels = map(lambda n: Unlabel(n.GetLabel()), nodes)
        maxinds = {}
        for l in labels:
            if not maxinds.has_key(l[0]):
                maxinds[l[0]] = list(l[1])
            else:
                maxinds[l[0]] = map(max, maxinds[l[0]], list(l[1]))
        if len(maxinds.keys()) > 1:
            raise ValueError("Regexp matches to more than one matrix")
        inds = map(lambda x: x+1, maxinds[maxinds.keys()[0]])
        arraysize = reduce(lambda x, y: x*y, inds, 1)
        val = Numeric.reshape(Numeric.array(
            [None]*arraysize, Numeric.PyObject), inds)
        for k in range(len(nodes)):
            val[labels[k][1]] = nodes[k]
        return val

    def GetVariableArray(self, regexp):
        return self.GetNodeArray(regexp, getter="GetVariables")

    def ShowVariables(self, regexp = '', showtype = 0):
        if showtype == 0:
            scafunc = Helpers.GetMean
            vecfunc = Helpers.GetMeanV
        elif showtype == 1:
            scafunc = Helpers.GetVar
            vecfunc = Helpers.GetVarV
        else:
            scafunc = Helpers.GetExp
            vecfunc = Helpers.GetExpV
        nodes = self.GetVariables(regexp)
        for node in nodes:
            print node.GetLabel(),
            if node.GetType()[-1] == 'V':
                print min(Numeric.array(vecfunc(node))),
                print max(Numeric.array(vecfunc(node)))
            else:
                print scafunc(node)

    def ShowNodes(self, regexp = '', showtype = 0):
        if showtype == 0:
            scafunc = Helpers.GetMean
            vecfunc = Helpers.GetMeanV
        elif showtype == 1:
            scafunc = Helpers.GetVar
            vecfunc = Helpers.GetVarV
        else:
            scafunc = Helpers.GetExp
            vecfunc = Helpers.GetExpV
        nodes = self.GetNodes(regexp)
        for node in nodes:
            print node.GetLabel(),
            if node.GetType()[-1] == 'V':
                print min(Numeric.array(vecfunc(node))),
                print max(Numeric.array(vecfunc(node)))
            else:
                print scafunc(node)


    def MakeNodes(self, nodetype, labelbase, num, *parents):
        warnings.warn("method deprecated", DeprecationWarning, stacklevel=2)

        c = []
        if type(labelbase) != type(()):
            tmp = Unlabel(labelbase)
            labelbase = (tmp[0],) + tmp[1]
        if type([]) not in map(type, parents):
            for i in range(num):
                c.append(apply(nodetype,
                               (self, apply(Label, labelbase + (i, ))) +
                               parents))
        else:
            if len(parents) > 1:
                parents=list(parents)
                for j in range(len(parents)):
                    if type(parents[j]) is not type([]):
                        parents[j] = [parents[j]]*num
                pars = map(tuple, list(Numeric.transpose(Numeric.array(parents))))
            else:
                pars = map(lambda x: (x, ), parents[0])
            for i in range(num):
                c.append(apply(nodetype,
                               (self, apply(Label, labelbase + (i, ))) +
                               pars[i]))
        return c
    
    def MakeGaussians(self, labelbase, mpar, vpar, num):
        """Deprecated, use MakeNodes instead"""
        return self.MakeNodes(Gaussian, labelbase, num, mpar, vpar)

    def MakeGaussianVs(self, labelbase, mpar, vpar, num):
        """Deprecated, use MakeNodes instead"""
        return self.MakeNodes(GaussianV, labelbase, num, mpar, vpar)

    def SortByNet(self,nodes):
        """ Returns nodes sorted in topological order. First calls net to
        sort itself topologically and then sorts the nodes in same order
        as they are returned by net.GetNodeByIndex(i).
        """
        self.SortNodes()
        nodedict = {}
        sorted = []
        for i in nodes:
            nodedict[i.GetLabel()] = 1
        for i in range(self.NodeCount()):
            l = self.GetNodeByIndex(i).GetLabel()
            if nodedict.has_key(l):
                nodedict[l] = 0
                sorted.append(self.GetNode(l))
        assert(len(nodes)==len(sorted))
        return sorted

    def GetVectorNodes(self):
        res = []
        self.SortNodes()
        for i in range(self.NodeCount()):
            node = self.GetNodeByIndex(i)
            if node.GetType()[-1] == 'V' and node.GetType() != 'ConstantV':
                res.append(node)
        return res
            
    def GenerateTestNodes(self, vectors=1, typemap = {}):
        warnings.warn("method deprecated", DeprecationWarning, stacklevel=2)
        nodes=self.GetVectorNodes()
        node_map = {}

        for v in nodes:
            if v.GetType() in ("ConstantV",):
                if vectors == 0:
                    raise RuntimeError, "If vectors is false ConstantV:s can't be copyed"
                else:
                    node_map[v.GetLabel()] = v
            elif v.GetType() in ("SwitchV", "DelayV", "DelayGaussV",
                               "EvidenceV","TestProdV","DiscreteV"):
                raise RuntimeError, "Copying of " + v.GetType() +\
                      " nodes is not implemented."
            elif v.GetType() in ("ProdV", "Sum2V", "GaussianV",
                               "SparseGaussV", "GaussNonlinV"):
                par = []
                i = 0
                while v.GetParent(i) is not None:
                    par.append(node_map.get(v.GetParent(i).GetLabel(),
                                            v.GetParent(i)))
                    assert(not (par[i].GetType()[-1] == 'V' and
                                par[i].GetLabel()[0:5] != "Copy_"))
                    i += 1
            else:
                raise RuntimeError, "Unknown vector node type: " + v.GetType()
            if vectors:
                copy = apply(typemap.get(v.GetLabel(), eval(v.GetType())),
                             (self, "Copy_" + v.GetLabel()) + tuple(par))
            else:
                type = v.GetType()[:-1]
                if type == "SparseGauss":
                    type = "Gaussian"
                copy = apply(typemap.get(v.GetLabel(), eval(type)),
                             (self, "Copy_" + v.GetLabel()) + tuple(par))
            node_map[v.GetLabel()] = copy
        return node_map

    def EvidenceNode(self, node, label=None, mean=0, var=0.01,
                     decay=4.0, hook="UpdateAll"):
        warnings.warn("method deprecated", DeprecationWarning, stacklevel=2)

        if label is None:
            label =  "evidence_" + node.GetLabel()
        tmp = Evidence(self, label, node)
        tmp.Clamp(mean, var)
        tmp.SetDecayTime(decay)
        self.RegisterDecay(tmp, hook)
        node.Update()

    def EvidenceVNode(self, node, label=None, mean=0, var=0.01,
                     decay=4.0, hook="UpdateAll"):
        warnings.warn("method deprecated", DeprecationWarning, stacklevel=2)

        if label is None:
            label =  "evidence_" + node.GetLabel()
        tmp = EvidenceV(self, label, node)
	time = self.Time()
	if not _issequence(mean):
	    mean = [mean] * time
	if not _issequence(var):
	    var = [var] * time
        tmp.Clamp(mean, var)
	if not _issequence(decay):
	    decay = [decay] * time
        tmp.SetDecayTime(decay)
        self.RegisterDecay(tmp, hook)
        node.Update()

    def GetGaussianNode(self, name, recursion=0, noevidence=0):
        if recursion > 5:
            print name
            raise RuntimeError, "maximum recursion depth for GetGaussianNode exceeded"
        node = self.GetNode(name)
        if node is None:
            if self.priorlist.has_key(name):
                (mlabel, vlabel) = self.priorlist[name]
            else:
                tmp = Unlabel(name);
                label = Label(tmp[0], tmp[1][:-1])
                mlabel = 'm' + label
                vlabel = 'v' + label
            m = self.GetGaussianNode(mlabel,recursion=recursion+1)
            v = self.GetGaussianNode(vlabel,recursion=recursion+1)
            if m is None or v is None:
                raise RuntimeError,"Parents not found in " + \
                      "GetNode() " + mlabel + " " + vlabel
            node = Gaussian(self, name, m, v)
            node.SetPersist(5)
            if not noevidence:
                self.EvidenceNode(node, mean=0, var=1, decay=1)
        return node

    def GetCostDict(self):
        d = {}
        for i in range(self.VariableCount()):
            n = self.GetVariableByIndex(i)
            d[n.GetLabel()] = n.Cost()
        return d

    def CheckStructureP(self):
        self.SortNodes()
        variables = self.GetVariables('')
        delays = filter(lambda x: x.GetType() == 'DelayV',
                        self.GetNodes(''))
        vardelay = {}
        for x in variables + delays:
            vardelay[x.GetLabel()] = None
        for v in variables + delays:
            check = [v]
            checkednodes = {v.GetLabel():None}
            while len(check):
                node = check.pop()
                for c in Helpers.ChildList(node):
                    l = c.GetLabel()
                    if checkednodes.has_key(l):
                        raise ValueError, \
                              "Invalid structure of the network." +\
                              " More than one route from " + v.GetLabel() +\
                              " to " + l + "."
                        return 1
                    checkednodes[l]=None
                    if not vardelay.has_key(l):
                        check.append(c)
        return


class PyNodeFactory(NodeFactory):
    """Python extension of the wrapped C++ NodeFactory, bound to a PyNet.

    All node creation utilities should be moved here.  Nodes are created
    through the factory's Get<Type> methods; see MakeNode.  Pickling stores
    only the PyNet, and unpickling re-runs __init__ with it.
    """

    def __getstate__(self):
        return {'pynet': self.pynet}

    def __setstate__(self, state):
        self.__init__(state['pynet'])

    def __init__(self, net):
        """Bind the factory to net.

        Raises ValueError if net is not a PyNet.
        """
        # some type safety enforced
        if not isinstance(net, PyNet):
            raise ValueError("Net is not an instance of PyNet")
        NodeFactory.__init__(self, net)
        self.pynet = net

    def GetPyNet(self):
        """Return the PyNet this factory creates nodes in."""
        return self.pynet

    def __repr__(self):
        return ("<PyNodeFactory instance wrapping C NodeFactory instance "
                "at %s>" % (self.this,))

    def MakeNode(self, nodetype, label, parents):
        """Create one node by calling self.Get<nodetype>(label, *parents).

        nodetype is the type name string (as returned by GetType), and
        parents may be any sequence.
        """
        factmethod = getattr(self, "Get" + nodetype)
        return factmethod(label, *parents)

    def MakeNodes(self, nodetype, labelbase, num, *parents):
        """Create num instances of nodetype (which is the string returned
        by GetType, i.e. the class name).

        Node i is labelled from labelbase with index i appended.  labelbase
        is either a (stem, index, ...) tuple or a label string.  Each entry
        of parents is passed to every node, unless it is a list; a list must
        have length num and supplies one parent per node.  Returns the list
        of new nodes.
        """
        if isinstance(labelbase, tuple):
            stem = labelbase[0]
            indices = labelbase[1:]
        else:
            stem, indices = Unlabel(labelbase)

        if parents:
            columns = []
            for parent in parents:
                if isinstance(parent, list):
                    assert len(parent) == num
                    columns.append(parent)
                else:
                    columns.append([parent] * num)
            # One tuple of parents per node.
            rows = list(zip(*columns))
        else:
            rows = [()] * num

        return [self.MakeNode(nodetype, Label(stem, indices + (i,)), rows[i])
                for i in range(num)]

    def EvidenceNode(self, node, label=None, mean=0, var=0.01,
                     decay=4.0, hook="UpdateAll"):
        """Attach an Evidence node to node.

        The evidence node is labelled label (default: "evidence_" + node's
        label).  It is clamped to (mean, var), given decay time decay and
        registered with the net's decay hook `hook`.  Finally node is
        updated.
        """
        if label is None:
            label = "evidence_" + node.GetLabel()
        tmp = self.GetEvidence(label, node)
        tmp.Clamp(mean, var)
        tmp.SetDecayTime(decay)
        self.GetPyNet().RegisterDecay(tmp, hook)
        node.Update()

    def EvidenceVNode(self, node, label=None, mean=0, var=0.01,
                      decay=4.0, hook="UpdateAll"):
        """Vector counterpart of EvidenceNode (uses EvidenceV).

        Scalar mean, var and decay (values for which _issequence is false)
        are repeated GetPyNet().Time() times.
        """
        if label is None:
            label = "evidence_" + node.GetLabel()
        tmp = self.GetEvidenceV(label, node)
        tlen = self.GetPyNet().Time()
        if not _issequence(mean):
            mean = [mean] * tlen
        if not _issequence(var):
            var = [var] * tlen
        tmp.Clamp(mean, var)
        if not _issequence(decay):
            decay = [decay] * tlen
        tmp.SetDecayTime(decay)
        self.GetPyNet().RegisterDecay(tmp, hook)
        node.Update()

    def BuildBalancedTree(self, nodetype, nodes, labelbase):
        """Combine nodes pairwise with nodetype nodes and return the root.

        Pairs are taken first-in first-out, which yields a roughly balanced
        binary tree.  Intermediate nodes are labelled from labelbase with
        "_leaf" appended to its stem plus a running index, and the root gets
        labelbase itself.  A single node is returned unchanged.  The
        caller's sequence is not modified.
        """
        nodes = list(nodes)

        if len(nodes) == 1:
            return nodes[0]
        if isinstance(labelbase, (tuple, list)):
            leaflabel = list(labelbase)
        else:
            leaflabel = [labelbase]
            labelbase = (labelbase, )
        leaflabel[0] += "_leaf"
        leaflabel = tuple(leaflabel)
        j = 0
        while len(nodes) > 2:
            tmp = self.MakeNode(nodetype, Label(*(leaflabel + (j, ))),
                                (nodes.pop(0), nodes.pop(0)))
            j += 1
            nodes.append(tmp)
        return self.MakeNode(nodetype, Label(*labelbase),
                             (nodes.pop(0), nodes.pop(0)))

    def BuildSum2VTree(self, nodes, labelbase):
        """Sum nodes with a balanced tree of Sum2V nodes."""
        return self.BuildBalancedTree("Sum2V", nodes, labelbase)

    def BuildSum2Tree(self, nodes, labelbase):
        """Sum nodes with a balanced tree of Sum2 nodes."""
        return self.BuildBalancedTree("Sum2", nodes, labelbase)

    def GetConst0(self):
        """Return the node labelled 'const0'.

        If no such node exists, a constant 0.0 is created under that label.
        """
        const0 = self.GetPyNet().GetNode('const0')
        if const0 is None:
            const0 = self.GetConstant('const0', 0.0)
        return const0

    def GenerateTestNodes(self, vectors=1, typemap=None):
        """Create a "Copy_"-prefixed copy of every vector node of the net.

        Parents that were themselves copied are replaced by their copies.
        With vectors true, each copy has the original's type.  Otherwise the
        scalar type is used (the trailing 'V' is dropped, and SparseGauss
        becomes Gaussian).  typemap may map an original label to the type
        name to use instead.

        Returns a dict from original labels to the new nodes.
        Raises RuntimeError for node types that cannot be copied.
        """
        if typemap is None:
            typemap = {}
        nodes = self.pynet.GetVectorNodes()
        node_map = {}

        for v in nodes:
            vtype = v.GetType()
            if vtype in ("ConstantV",):
                # NOTE: GetVectorNodes excludes ConstantV, so this branch is
                # currently unreachable.
                if vectors == 0:
                    raise RuntimeError("If vectors is false ConstantV:s can't be copyed")
                node_map[v.GetLabel()] = v
                # Constants are shared, not copied.  Skip the copy below,
                # which would otherwise use a stale or unbound `par`.
                continue
            elif vtype in ("SwitchV", "DelayV", "DelayGaussV",
                           "EvidenceV", "TestProdV", "DiscreteV"):
                raise RuntimeError("Copying of " + vtype +
                                   " nodes is not implemented.")
            elif vtype in ("ProdV", "Sum2V", "GaussianV",
                           "SparseGaussV", "GaussNonlinV"):
                par = []
                i = 0
                while v.GetParent(i) is not None:
                    parent = v.GetParent(i)
                    par.append(node_map.get(parent.GetLabel(), parent))
                    # Every vector parent must already have been copied.
                    assert(not (par[i].GetType()[-1] == 'V' and
                                par[i].GetLabel()[0:5] != "Copy_"))
                    i += 1
            else:
                raise RuntimeError("Unknown vector node type: " + vtype)

            if vectors:
                newtype = typemap.get(v.GetLabel(), vtype)
            else:
                scalartype = vtype[:-1]
                if scalartype == "SparseGauss":
                    scalartype = "Gaussian"
                newtype = typemap.get(v.GetLabel(), scalartype)
            # Label and parents are separate arguments.  The scalar path
            # used to compute "Copy_" + label + tuple(par), a TypeError.
            copy = self.MakeNode(newtype, "Copy_" + v.GetLabel(), tuple(par))
            node_map[v.GetLabel()] = copy
        return node_map

