# -*- coding: iso-8859-1 -*-

#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: Tests.py 5 2006-10-26 09:44:54Z ah $
#

import unittest

import Numeric, RandomArray
import MLab
import random, time
import sys

import PyNet, Helpers, PickleHelpers
from Label import Label
from Helpers import GetMean, GetMeanV, GetDiscrete, GetDiscreteV, GetVar, GetVarV

def path(a, b):
    """Return the labels on the first-child chain leading from node *a* to *b*.

    Starting at *a*, ``GetChild(0)`` is followed repeatedly until a node
    whose label equals ``b.GetLabel()`` is reached.  The labels of the nodes
    passed on the way are returned in visiting order; the label of the
    matching node itself is not included, so ``path(b, b) == []``.  If the
    chain ends (``GetChild(0)`` returns None) before a match, None is
    returned.
    """
    def first_child_chain(start):
        # Yield start, then its first child, grandchild, ... until None.
        current = start
        while True:
            yield current
            current = current.GetChild(0)
            if current is None:
                return

    goal = b.GetLabel()
    trail = []
    for current in first_child_chain(a):
        label = current.GetLabel()
        if label != goal:
            trail.append(label)
            continue
        return trail
    return None

class PyNodeFactoryTestCase(unittest.TestCase):
    """Tests for the node-construction helpers of PyNet.PyNodeFactory."""

    def setUp(self):
        # A fresh single-time-step network and factory for every test.
        self.net = PyNet.PyNet(1)
        self.nf = PyNet.PyNodeFactory(self.net)

    def tearDown(self):
        pass

    def test1(self):
        """MakeNodes with a plain label, a parent list and a single parent."""
        clst = [self.nf.GetConstant(Label("c", i), i) for i in range(10)]
        x = self.nf.MakeNodes("Gaussian", "x", 10, clst, clst[0])
        for i in range(10):
            self.assertEqual(x[i].GetType(), "Gaussian")
            self.assertEqual(x[i].GetLabel(), Label("x", i))
            # A list of parents is distributed element-wise over the nodes...
            self.assertEqual(x[i].GetParent(0).GetLabel(), clst[i].GetLabel())
            # ...while a single parent is shared by every created node.
            self.assertEqual(x[i].GetParent(1).GetLabel(), clst[0].GetLabel())

    def test2(self):
        """MakeNodes with a tuple label and parent lists for both parents."""
        clst = [self.nf.GetConstant(Label("c", i), i) for i in range(10)]
        x = self.nf.MakeNodes("Gaussian", ("x", 1, 2), 10, clst, clst)
        for i in range(10):
            self.assertEqual(x[i].GetType(), "Gaussian")
            # The node index is appended after the tuple's label components.
            self.assertEqual(x[i].GetLabel(), Label("x", 1, 2, i))
            self.assertEqual(x[i].GetParent(0).GetLabel(), clst[i].GetLabel())
            self.assertEqual(x[i].GetParent(1).GetLabel(), clst[i].GetLabel())

    def test3(self):
        """BuildSum2Tree over 16 summands puts every summand 4 steps deep."""
        c0 = self.nf.GetConst0()
        summands = self.nf.MakeNodes("Gaussian", "s", 16, c0, c0)
        # Named 'root' (not 'sum') to avoid shadowing the builtin.
        root = self.nf.BuildSum2Tree(summands, "sumtree")
        for term in summands:
            # path() yields the labels strictly between term and root along
            # first-child links, so 4 entries means 4 steps to the root.
            self.assertEqual(len(path(term, root)), 4)

    def test4(self):
        """Smoke test: EvidenceNode on a variable with an observed child."""
        c0 = self.nf.GetConst0()
        s = self.nf.GetGaussian("s", c0, c0)
        x = self.nf.GetGaussian("x", s, c0)
        x.Clamp(1.0)
        # No assertion: the test only checks that this call does not raise.
        self.nf.EvidenceNode(s, mean=0.5)
        

class VariableGroupTestCase(unittest.TestCase):
    """Tests for variable groups, which select variables by label prefix."""

    def setUp(self):
        self.net = PyNet.PyNet(1)
        self.nf = PyNet.PyNodeFactory(self.net)

    def tearDown(self):
        pass

    def _group_labels(self, group):
        """Return the sorted labels of all variables currently in *group*."""
        labels = [self.net.GetGroupVariable(group, i).GetLabel()
                  for i in range(self.net.NumGroupVariables(group))]
        labels.sort()
        return labels

    def test1(self):
        """A group contains the variables whose labels start with its name."""
        c0 = self.nf.GetConstant("c0", 0.0)

        # Keep references to the nodes for the whole test, as the original
        # test did.
        a1 = self.nf.GetGaussian("ab1", c0, c0)
        a2 = self.nf.GetGaussian("ab2", c0, c0)

        b1 = self.nf.GetGaussian("ba1", c0, c0)
        b2 = self.nf.GetGaussian("ba2", c0, c0)

        self.net.DefineVariableGroup("a")
        self.net.DefineVariableGroup("b")

        self.assertEqual(self._group_labels("a"), ["ab1", "ab2"])
        self.assertEqual(self._group_labels("b"), ["ba1", "ba2"])

    def test2(self):
        """Groups track variables removed or created after being defined."""
        c0 = self.nf.GetConstant("c0", 0.0)

        a1 = self.nf.GetGaussian("ab1", c0, c0)
        a2 = self.nf.GetGaussian("ab2", c0, c0)

        b1 = self.nf.GetGaussian("ba1", c0, c0)
        b2 = self.nf.GetGaussian("ba2", c0, c0)

        self.net.DefineVariableGroup("a")
        self.net.DefineVariableGroup("b")

        # Removing ba1 must drop it from group "b"...
        b1.Die()
        self.net.CleanUp()

        # ...and a variable created afterwards must join group "a".
        a3 = self.nf.GetGaussian("ab3", c0, c0)

        self.assertEqual(self._group_labels("a"), ["ab1", "ab2", "ab3"])
        self.assertEqual(self._group_labels("b"), ["ba2"])

    def test3(self):
        """Group names match as prefixes, so "foo" also contains "foobar*"."""
        c0 = self.nf.GetConstant("c0", 0.0)

        foo1 = self.nf.GetGaussian("foo1", c0, c0)
        foo2 = self.nf.GetGaussian("foo2", c0, c0)

        foobar1 = self.nf.GetGaussian("foobar1", c0, c0)
        foobar2 = self.nf.GetGaussian("foobar2", c0, c0)

        self.net.DefineVariableGroup("foo")
        self.net.DefineVariableGroup("foobar")

        self.assertEqual(self._group_labels("foo"),
                         ["foo1", "foo2", "foobar1", "foobar2"])
        self.assertEqual(self._group_labels("foobar"),
                         ["foobar1", "foobar2"])

    def setup4(self):
        """Build two independent Gaussian clusters for the test4* tests.

        Cluster "a" (mean am, variance parameter av) observes 100 samples
        around -1; cluster "b" (bm, bv) observes 100 samples around +1.  Both
        are drawn with standard deviation exp(-1).  test4a/test4b expect the
        variance parameters to reach about 2.0, presumably because the
        variance input is -log(variance) -- NOTE(review): confirm the
        parameterisation.
        """
        data1 = RandomArray.normal(-1, Numeric.exp(-0.5*2), 100)
        data2 = RandomArray.normal(1, Numeric.exp(-0.5*2), 100)

        c0 = self.nf.GetConstant("c0", 0.0)

        def unclamped_gaussian(label):
            # Clamp then immediately unclamp; presumably this initialises
            # the posterior at 0.0 -- NOTE(review): confirm intent.
            node = self.nf.GetGaussian(label, c0, c0)
            node.Clamp(0.0)
            node.Unclamp()
            return node

        m1 = unclamped_gaussian("am")
        v1 = unclamped_gaussian("av")
        m2 = unclamped_gaussian("bm")
        v2 = unclamped_gaussian("bv")

        x1 = []
        for i in range(data1.shape[0]):
            x1.append(self.nf.GetGaussian(Label("x1", i), m1, v1))
            x1[i].Clamp(data1[i])

        x2 = []
        # Bounded by data2's own length, since data2 is what is indexed.
        for i in range(data2.shape[0]):
            x2.append(self.nf.GetGaussian(Label("x2", i), m2, v2))
            x2[i].Clamp(data2[i])

        self.m1 = m1
        self.m2 = m2
        self.v1 = v1
        self.v2 = v2

    def test4a(self):
        """Updating the whole net recovers the parameters of both clusters."""
        self.setup4()

        self.net.DefineVariableGroup("a")
        self.net.DefineVariableGroup("b")

        for i in range(5):
            self.net.UpdateAll()

        self.assert_(abs(GetMean(self.m1) + 1.0) < 0.5)
        self.assert_(abs(GetMean(self.m2) - 1.0) < 0.5)

        self.assert_(abs(GetMean(self.v1) - 2.0) < 0.5)
        self.assert_(abs(GetMean(self.v2) - 2.0) < 0.5)

    def test4b(self):
        """Updating only group "a" recovers the first cluster's parameters."""
        self.setup4()

        self.net.DefineVariableGroup("a")

        for i in range(5):
            self.net.UpdateGroup("a")

        self.assert_(abs(GetMean(self.m1) + 1.0) < 0.5)
        self.assert_(abs(GetMean(self.v1) - 2.0) < 0.5)

class PartialUpdateTestCase(unittest.TestCase):
    """Updates restricted to an enabled group of time indices."""

    def setUp(self):
        pass

    def _check_means(self, expected, means):
        """Assert that means[t] lies within 0.1 of expected[t] for every t."""
        for t, target in enumerate(expected):
            mean = means[t]
            self.assert_((target - mean)**2 < 1e-2,
                         "%f vs. %f" % (target, mean))

    def test1(self):
        tdim = 100
        net = PyNet.PyNet(tdim)
        nf = PyNet.PyNodeFactory(net)

        # Ten distinct random time indices make up the partial-update group.
        chosen = RandomArray.permutation(tdim)[:10].tolist()
        net.DefineTimeIndexGroup("mygroup", Helpers.Array2IntV(chosen))

        c0 = nf.GetConstant("c0", 0.0)
        c_5 = nf.GetConstant("c_5", -5.0)
        c5 = nf.GetConstant("c5", 5.0)

        s = nf.GetGaussianV("s", c0, c0)
        x = nf.GetGaussianV("x", s, c5)

        # Observations are clamped with negligible variance, so the
        # posterior mean of s should follow them closely.
        first = RandomArray.standard_normal(tdim)
        tiny = Numeric.ones(tdim) * 1e-8
        x.Clamp(Helpers.Array2DV(first), Helpers.Array2DV(tiny))

        # Phase 1: a full update makes s follow the first data everywhere.
        net.UpdateAll()
        self._check_means(first, GetMeanV(s))

        # Phase 2: new data, but with the group enabled only its time
        # indices pick up the new values; the rest keep the old ones.
        second = RandomArray.standard_normal(tdim)
        x.Clamp(Helpers.Array2DV(second), Helpers.Array2DV(tiny))
        net.EnableTimeIndexGroup("mygroup")

        net.UpdateAll()
        mixed = []
        for t in range(tdim):
            if t in chosen:
                mixed.append(second[t])
            else:
                mixed.append(first[t])
        self._check_means(mixed, GetMeanV(s))

        # Phase 3: with the group disabled again, s follows the new data
        # at every time index.
        net.DisableTimeIndexGroup("mygroup")

        net.UpdateAll()
        self._check_means(second, GetMeanV(s))

class MoGTestCase(unittest.TestCase):
    """Vector vs. scalar implementations of a two-component MoG model.

    buildnet1 builds the model with vector nodes spanning all data points.
    buildnet2 builds the same model with one scalar node per data point.
    After identical training the two networks are compared by comparenets.
    """

    def setUp(self):
        # Fixed samples (not drawn at run time) keep the test repeatable.
        # N(-1, 0.5^2)
        data1 = Numeric.array(
            [-1.48372132, -0.74105301, -1.46452126, -0.74465525,
             -0.74137861, -0.81603923, -1.22201809, -1.40728462,
             -0.25976342, -1.07168696, -0.80576369, -1.28271085,
             -0.97642539, -1.36678675, -1.10264056, -0.53583777,
             -1.32091546, -0.74083391, -1.34777945, -0.87874509],
            Numeric.Float)

        # N(1, 0.1^2)
        data2 = Numeric.array(
            [ 0.8612323 ,  1.05193536,  1.10445087,  0.89876609,
              0.96364794,  1.14201061,  0.96772116,  0.8137229 ,
              0.9958501 ,  0.99648816], Numeric.Float)

        self.data = Numeric.concatenate((data1, data2))

    def buildnet1(self):
        """Return the vectorised MoG network over all data points."""
        net = PyNet.PyNet(self.data.shape[0])
        nf = PyNet.PyNodeFactory(net)

        numComp = 2

        # All-ones Dirichlet prior, presumably over the mixing proportions.
        prior = Helpers.Array2DV(Numeric.ones(numComp))
        const = nf.GetConstantV("c_prior", prior)
        c = nf.GetDirichlet(Label("c"), const)
        d = nf.GetDiscreteDirichletV("d", c)

        s = nf.GetMoGV("s", d)

        c0 = nf.GetConstant("c0", 0.0)
        c_5 = nf.GetConstant("c_5", -5)

        m1 = nf.GetGaussian("m1", c0, c_5)
        v1 = nf.GetGaussian("v1", c0, c_5)

        m2 = nf.GetGaussian("m2", c0, c_5)
        v2 = nf.GetGaussian("v2", c0, c_5)

        s.AddComponent(m1, v1)
        s.AddComponent(m2, v2)

        # 9.21 ~= -2*log(0.01).
        vx = nf.GetConstant("vx", 9.21)
        x = nf.GetGaussianV("x", s, vx)

        x.Clamp(Helpers.Array2DV(self.data))

        # Evidence on the component means at the two cluster centres,
        # presumably so both nets settle on the same component labelling.
        nf.EvidenceNode(m1, mean=-1.0)
        nf.EvidenceNode(m2, mean=1.0)

        return net

    def buildnet2(self):
        """Return the same MoG model built from per-data-point scalar nodes."""
        net = PyNet.PyNet(1)
        nf = PyNet.PyNodeFactory(net)

        numComp = 2
        dim = self.data.shape[0]

        prior = Helpers.Array2DV(Numeric.ones(numComp))
        const = nf.GetConstantV("c_prior", prior)
        c = nf.GetDirichlet(Label("c"), const)

        c0 = nf.GetConstant("c0", 0.0)
        c_5 = nf.GetConstant("c_5", -5)

        m1 = nf.GetGaussian("m1", c0, c_5)
        v1 = nf.GetGaussian("v1", c0, c_5)

        m2 = nf.GetGaussian("m2", c0, c_5)
        v2 = nf.GetGaussian("v2", c0, c_5)

        nf.EvidenceNode(m1, mean=-1.0)
        nf.EvidenceNode(m2, mean=1.0)

        vx = nf.GetConstant("vx", 9.21)

        # One discrete selector d[i], mixture s[i] and observation x[i] per
        # data point; all share the same components (m1, v1) and (m2, v2).
        d = []
        s = []
        x = []
        for i in range(dim):
            d.append(nf.GetDiscreteDirichlet(Label("d", i), c))
            s.append(nf.GetMoG(Label("s", i), d[i]))
            s[i].AddComponent(m1, v1)
            s[i].AddComponent(m2, v2)
            x.append(nf.GetGaussian(Label("x", i), s[i], vx))
            x[i].Clamp(self.data[i])

        return net

    def costdump(self, vector, scalar):
        """Debug aid: print per-variable costs of both networks side by side."""
        variables = ["c", "m1", "v1", "m2", "v2"]

        for l in variables:
            n1 = vector.GetVariable(l)
            n2 = scalar.GetVariable(l)
            print("%s : %f : %f" % (l, n1.Cost(), n2.Cost()))

        # The scalar net spreads these over one node per data point, so
        # its costs are summed for comparison.
        variables = ["d", "s", "x"]

        for l in variables:
            c1 = vector.GetVariable(l).Cost()
            c2 = MLab.sum([n.Cost() for n in scalar.GetVariableArray(l)])
            print("%s : %f : %f" % (l, c1, c2))

    def comparenets(self, vector, scalar):
        """Assert that a vector net and a scalar net agree to within ~1e-8."""
        dim = self.data.shape[0]

        # This alone is quite a reliable indicator.
        costdiff = abs(vector.Cost() - scalar.Cost())
        self.assert_(costdiff < dim * 1e-8, "costdiff = %f" % costdiff)

        s1 = GetVarV(vector.GetVariable("s"))
        s2 = GetVar(scalar.GetVariableArray("s"))
        e = MLab.max(Numeric.absolute(s1 - s2))
        self.assert_(e < 1e-8, "max s var diff = %g" % e)

        for l in ["m1", "v1", "m2", "v2"]:
            n1 = GetMean(vector.GetVariable(l))
            n2 = GetMean(scalar.GetVariable(l))
            self.assert_(abs(n1 - n2) < 1e-8,
                         "%s : %f : %f" % (l, n1, n2))

        # The scalar net's discrete distributions are stacked the other way
        # round, hence the transpose.
        d1 = GetDiscreteV(vector.GetVariable("d"))
        d2 = Numeric.transpose(GetDiscrete(scalar.GetVariableArray("d")))

        e = MLab.max(Numeric.reshape(
            Numeric.absolute(d1 - d2), (d1.shape[0]*d1.shape[1],)))

        self.assert_(e < 1e-8, "max e = %f" % e)

        c1 = vector.GetVariable("c")
        c2 = scalar.GetVariable("c")

        self.assert_(
            MLab.mean(Numeric.absolute(GetMeanV(c1) - GetMeanV(c2))) < 1e-8)
        self.assert_(
            MLab.mean(Numeric.absolute(GetVarV(c1) - GetVarV(c2))) < 1e-8)

    def _trained_nets(self, iterations=50):
        """Build both nets, run *iterations* UpdateAll sweeps on each, return them."""
        net1 = self.buildnet1()
        net2 = self.buildnet2()

        for i in range(iterations):
            net1.UpdateAll()
            net2.UpdateAll()

        return net1, net2

    def _roundtrip(self, net):
        """Save *net* with PickleHelpers and load it back.

        A private temporary directory is used instead of a fixed /tmp path.
        This lets concurrent runs work and leaves no file behind, because
        the directory is always removed.
        """
        import os, shutil, tempfile
        from PickleHelpers import load, save

        tmpdir = tempfile.mkdtemp(prefix="bblocks_")
        try:
            fname = os.path.join(tmpdir, "bblocks_tmpnet")
            save(net, fname)
            return load(fname)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

    def test1(self):
        """Sameness of scalar and vector implementation of MoGs."""
        net1, net2 = self._trained_nets(50)
        self.comparenets(net1, net2)

    def test2(self):
        """Load/save interface."""
        net1, net2 = self._trained_nets(50)
        self.comparenets(net1, net2)

        net2alt = self._roundtrip(net2)
        self.comparenets(net1, net2alt)

class GaussRectTestCase(unittest.TestCase):
    """Vector vs. scalar implementations of rectified-Gaussian models.

    Static models (buildstatnet*) share one prior (ms, vs) for every s.
    Dynamic models (builddynnet*) use the previous s as the mean of the
    next one.  Each vector/scalar pair is trained identically and compared
    with comparenets.  The load/save tests also compare a copy that was
    pickled and reloaded.
    """

    def setUp(self):
        # Target: a sinusoid rectified at zero, plus small fixed noise.
        t = Numeric.arange(20) / float(20)
        s = Numeric.sin(4*Numeric.pi*t) + 0.5

        # N(0, 0.01^2)
        noise = Numeric.array(
            [-0.00020618, -0.00214195, -0.00860223,  0.0083709 ,
             0.00493691, -0.00303564,  0.00467232,  0.00875158,
             0.01462178,  0.00804469,  0.00791851,  0.00444479,
             -0.00398483, -0.00677404, -0.02249393,  0.00172111,
            -0.00420562,  0.00287512,  0.02004725,  0.00479894],
            Numeric.Float)

        self.data = s * (s > 0) + noise

    def _train(self, first, second, iterations):
        """Run *iterations* sweeps, each calling first.UpdateAll() then second.UpdateAll()."""
        for i in range(iterations):
            first.UpdateAll()
            second.UpdateAll()

    def _roundtrip(self, net):
        """Save *net* with PickleHelpers and load it back.

        A private temporary directory is used instead of a fixed /tmp path.
        This lets concurrent runs work and leaves no file behind, because
        the directory is always removed.
        """
        import os, shutil, tempfile
        from PickleHelpers import load, save

        tmpdir = tempfile.mkdtemp(prefix="bblocks_")
        try:
            fname = os.path.join(tmpdir, "bblocks_tmpnet")
            save(net, fname)
            return load(fname)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

    def buildstatnet1(self):
        """Return the static model built from vector nodes."""
        net = PyNet.PyNet(self.data.shape[0])
        nf = PyNet.PyNodeFactory(net)

        c0 = nf.GetConstant("c0", 0.0)
        c_5 = nf.GetConstant("c_5", -5.0)

        ms = nf.GetGaussian("ms", c0, c_5)
        vs = nf.GetGaussian("vs", c0, c_5)

        nf.EvidenceNode(ms, mean=1.0)
        nf.EvidenceNode(vs, mean=0.0)

        s = nf.GetGaussRectV("s", ms, vs)
        r = nf.GetRectificationV("r", s)

        # Fixed noise parameter -2*log(0.01); presumably the same
        # parameterisation as the 0.01 noise level in setUp -- confirm.
        vx = nf.GetConstant("vx", -2*Numeric.log(0.01))
        x = nf.GetGaussianV("x", r, vx)

        x.Clamp(Helpers.Array2DV(self.data))

        return net

    def buildstatnet2(self):
        """Return the static model built from per-time-step scalar nodes."""
        net = PyNet.PyNet(1)
        nf = PyNet.PyNodeFactory(net)

        c0 = nf.GetConstant("c0", 0.0)
        c_5 = nf.GetConstant("c_5", -5.0)

        ms = nf.GetGaussian("ms", c0, c_5)
        vs = nf.GetGaussian("vs", c0, c_5)

        nf.EvidenceNode(ms, mean=1.0)
        nf.EvidenceNode(vs, mean=0.0)

        vx = nf.GetConstant("vx", -2*Numeric.log(0.01))

        for i in range(self.data.shape[0]):
            s = nf.GetGaussRect(Label("s", i), ms, vs)
            r = nf.GetRectification(Label("r", i), s)
            x = nf.GetGaussian(Label("x", i), r, vx)
            x.Clamp(self.data[i])

        return net

    def comparenets(self, vector, scalar):
        """Assert that a vector net and a scalar net agree to within ~1e-8."""
        dim = self.data.shape[0]

        costdiff = abs(vector.Cost() - scalar.Cost())
        self.assert_(costdiff < dim * 1e-8, "costdiff = %f" % costdiff)

        s1 = GetMeanV(vector.GetVariable("s"))
        s2 = GetMean(scalar.GetVariableArray("s"))

        e = MLab.max(Numeric.absolute(s1 - s2))
        self.assert_(e < 1e-8, "max s mean diff = %g" % e)

        # r is looked up with GetNode rather than GetVariable.
        r1 = GetMeanV(vector.GetNode("r"))
        r2 = GetMean(scalar.GetNodeArray("r"))

        e = MLab.max(Numeric.absolute(r1 - r2))
        self.assert_(e < 1e-8, "max r mean diff = %g" % e)

        # NOTE(review): every builder in this class labels its first mean
        # node "ms" (including builddynnet*, whose local is named m0), so
        # the "m0" branch appears unused here; kept for other callers.
        if vector.GetVariable("ms") is not None:
            self.assert_(abs(GetMean(vector.GetVariable("ms"))
                             - GetMean(scalar.GetVariable("ms"))) < 1e-8)
        else:
            self.assert_(abs(GetMean(vector.GetVariable("m0"))
                             - GetMean(scalar.GetVariable("m0"))) < 1e-8)

        self.assert_(abs(GetMean(vector.GetVariable("vs"))
                         - GetMean(scalar.GetVariable("vs"))) < 1e-8)

    def test1(self):
        """Vector/scalar comparison with static model."""
        net1 = self.buildstatnet1()
        net2 = self.buildstatnet2()
        self._train(net1, net2, 50)
        self.comparenets(net1, net2)

    def builddynnet1(self):
        """Return the dynamic model built from vector nodes.

        A proxy for "s" feeds a delay node, presumably giving s(t-1), with
        m0 standing in before the first step.  This mirrors builddynnet2,
        where s[0] has mean ms and s[i] has mean s[i-1].  The proxy is
        resolved by ConnectProxies once "s" exists.
        """
        net = PyNet.PyNet(self.data.shape[0])
        nf = PyNet.PyNodeFactory(net)

        c0 = nf.GetConstant("c0", 0.0)
        c_5 = nf.GetConstant("c_5", -5.0)

        m0 = nf.GetGaussian("ms", c0, c_5)
        vs = nf.GetGaussian("vs", c0, c_5)

        proxy = nf.GetProxy("proxy", "s")
        delay = nf.GetDelayV("delay", m0, proxy)

        s = nf.GetGaussRectV("s", delay, vs)
        r = nf.GetRectificationV("r", s)

        vx = nf.GetConstant("vx", -2*Numeric.log(0.01))
        x = nf.GetGaussianV("x", r, vx)

        x.Clamp(Helpers.Array2DV(self.data))

        net.ConnectProxies()

        return net

    def builddynnet2(self):
        """Return the dynamic model built as a chain of scalar nodes."""
        net = PyNet.PyNet(1)
        nf = PyNet.PyNodeFactory(net)

        c0 = nf.GetConstant("c0", 0.0)
        c_5 = nf.GetConstant("c_5", -5.0)

        m0 = nf.GetGaussian("ms", c0, c_5)
        vs = nf.GetGaussian("vs", c0, c_5)

        vx = nf.GetConstant("vx", -2*Numeric.log(0.01))

        # s[0] takes its mean from m0; every later s[i] from s[i-1].
        prev = m0

        for i in range(self.data.shape[0]):
            s = nf.GetGaussRect(Label("s", i), prev, vs)
            prev = s
            r = nf.GetRectification(Label("r", i), s)
            x = nf.GetGaussian(Label("x", i), r, vx)
            x.Clamp(self.data[i])

        return net

    def test2(self):
        """Vector/scalar comparison with dynamic model."""
        net1 = self.builddynnet1()
        net2 = self.builddynnet2()
        self._train(net1, net2, 100)
        self.comparenets(net1, net2)

    def test3(self):
        """Load/save vector."""
        ref = self.buildstatnet2()
        net = self.buildstatnet1()
        self._train(ref, net, 100)
        self.comparenets(net, ref)

        alt = self._roundtrip(net)
        self.comparenets(alt, ref)

        ref = self.builddynnet2()
        net = self.builddynnet1()
        self._train(ref, net, 100)
        self.comparenets(net, ref)

        alt = self._roundtrip(net)
        self.comparenets(alt, ref)

    def test4(self):
        """Load/save scalar."""
        ref = self.buildstatnet1()
        net = self.buildstatnet2()
        self._train(ref, net, 100)
        self.comparenets(ref, net)

        alt = self._roundtrip(net)
        self.comparenets(ref, alt)

        ref = self.builddynnet1()
        net = self.builddynnet2()
        self._train(ref, net, 100)
        self.comparenets(ref, net)

        alt = self._roundtrip(net)
        self.comparenets(ref, alt)

if __name__ == '__main__':
    # Run every *TestCase class defined in this module, or only the classes
    # named on the command line.  Unknown names are echoed back with a "?".
    cases = [s for s in dir() if s.endswith("TestCase")]
    suites = []

    requested = sys.argv[1:] or cases
    for name in requested:
        if name in cases:
            # Look the class up by name instead of eval()-ing argv text.
            suites.append(unittest.makeSuite(globals()[name], "test"))
        else:
            sys.stdout.write(name + "?\n")

    bigsuite = unittest.TestSuite(suites)
    runner = unittest.TextTestRunner()

    result = runner.run(bigsuite)
    # Report failures through the exit status so scripts and CI notice them.
    sys.exit(int(not result.wasSuccessful()))

##    import pdb
##    pdb.run("runner.run(bigsuite)")
##    
