# -*- coding: iso-8859-1 -*-

#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: Browser.py 5 2006-10-26 09:44:54Z ah $
#

import cmd
import re
import plot
import pprint
import Numeric
import Label

class FieldReader:
    """Callable that walks a fixed path of subscripts into a nested container.

    ``FieldReader('myval', 'mean')(rec)`` evaluates ``rec['myval']['mean']``.
    With no fields, the argument is returned unchanged.
    """
    def __init__(self, *fields):
        # Subscripts applied left to right by __call__.
        self.fields = fields

    def __call__(self, arg):
        current = arg
        for key in self.fields:
            # Descend one level per stored subscript.
            current = current[key]
        return current

class NetBrowser(cmd.Cmd):
    """Interactive command-line browser for a Bayes Blocks network.

    The network is snapshotted once, at construction time, with
    ``net.SaveToPyObject``. Every command except ``matsave`` works on that
    snapshot.

    Nodes are addressed by label. Most commands take a regular expression
    that is matched with ``re.match`` (anchored at the start) against the
    labels of the current view (see ``use``). With no argument, a command
    works on the current node, which is the node last selected by typing
    its label at the prompt.
    """

    def __init__(self, net, debug=0):
        cmd.Cmd.__init__(self)
        try:
            # Imported only for its side effect: when available, readline
            # gives the interactive prompt line editing and history.
            import readline
        except ImportError:
            pass

        self.net = net
        self.cost = self.net.Cost()
        self.py_net = net.SaveToPyObject(debug)
        self.prompt = '(NetBrowser) '
        self.pl = plot.Plotter()
        self.pp = pprint.PrettyPrinter()

        # py_net[0] carries 'node_num' and 'variables'. The node records
        # live at indices 1..node_num.
        header = self.py_net[0]
        records = [self.py_net[k] for k in range(1, header['node_num'] + 1)]
        # label -> node record, for every node in the snapshot.
        self.nodes = dict((rec['label'], rec) for rec in records)
        # Variables are a subset of the nodes and share the same records.
        # Because of that, self.nodes[label] is the right record lookup
        # whichever view is active.
        self.variables = dict((name, self.nodes[name])
                              for name in header['variables'])

        # Active view: the set of labels that commands list and match
        # against. It is either self.variables (the default) or self.nodes,
        # and is switched by the 'use' command.
        self.mydict = self.variables
        # Named label lists recorded by the 'set' command.
        self.ownvars = {}
        # Label of the node last selected by typing its label, or None.
        self.curnode = None
        self.quit = 0

    def Browse(self):
        """Run the interactive command loop until 'quit', 'q' or EOF."""
        self.quit = 0
        self.cmdloop()

    def match_nodes(self, line):
        """Return the sorted list of labels selected by *line*.

        An empty *line* selects the current node. If no node has been
        selected yet, the result is an empty list.

        Otherwise *line* is compiled as a regular expression and matched
        against the start of every label in the active view. If nothing
        matches (or *line* is not a valid regexp) but *line* is itself a
        label in the view, that label alone is returned.
        """
        if not line:
            # Returning [None] here would make every caller fail on the
            # record lookup, and cmd.Cmd does not catch that exception.
            if self.curnode is None:
                return []
            return [self.curnode]
        try:
            pattern = re.compile(line)
            nodes = [key for key in self.mydict.keys() if pattern.match(key)]
        except (re.error, TypeError):
            # Invalid regexp (or a non-string label): fall back to the
            # exact-label check below.
            nodes = []
        if not nodes and line in self.mydict:
            nodes = [line]
        nodes.sort()
        return nodes

    def make_matrix(self, nodes, operation):
        """Collect ``operation(record)`` values into arrays indexed by label.

        Each label is split with ``Label.Unlabel``. The result is presumably
        a (name, index-tuple) pair, and ``Label.Label(*parts)`` rebuilds the
        label. For each distinct name, a zero-filled array of doubles is
        allocated. Its size is one past the largest index seen in each
        dimension, and each value is stored at its index.

        Returns the single array when exactly one name occurs. Otherwise it
        returns a dict mapping name -> array, which is empty when *nodes*
        is empty.
        """
        labels = [Label.Unlabel(node) for node in nodes]
        # name -> per-dimension maximum index seen so far.
        max_index = {}
        for parts in labels:
            name, index = parts[0], parts[1]
            if name in max_index:
                bounds = max_index[name]
                for i, value in enumerate(index):
                    bounds[i] = max(bounds[i], value)
            else:
                max_index[name] = list(index)
        arrays = {}
        for name in max_index:
            # Sizes are max index + 1 in every dimension.
            arrays[name] = Numeric.zeros(
                Numeric.array(max_index[name]) + 1).astype('d')
        for parts in labels:
            arrays[parts[0]][parts[1]] = operation(
                self.nodes[Label.Label(*parts)])
        if len(arrays) == 1:
            return list(arrays.values())[0]
        return arrays

    def default(self, line):
        """Select and print the node labelled *line*; reject other input."""
        if line in self.nodes:
            self.curnode = line
            self.pp.pprint(self.nodes[line])
        else:
            # Standard cmd.Cmd report: '*** Unknown syntax: <line>'.
            cmd.Cmd.default(self, line)

    def do_ls(self, line):
        """ls [regexp]: list labels in the view, optionally only matching ones."""
        if not line:
            self.pp.pprint(sorted(self.mydict.keys()))
        else:
            self.pp.pprint(self.match_nodes(line))

    def multiplot_nodes(self, line, op):
        """Pass ``op(label)`` for every label selected by *line* to multiplot.

        Any error is reported instead of propagated, so the command loop
        keeps running.
        """
        nodes = self.match_nodes(line)
        try:
            vals = [op(node) for node in nodes]
            if nodes and len(vals[0]) > 0:
                self.pl.multiplot(vals)
            else:
                print("Nothing to plot!")
        except Exception:
            print("Error in plotting!  (Please note that only vectors can be plotted.)")

    def do_plot(self, line):
        """plot [regexp]: plot the 'mean' value of the selected vector nodes."""
        self.multiplot_nodes(line, lambda x: self.nodes[x]['myval']['mean'])

    def do_plotvar(self, line):
        """plotvar [regexp]: plot the 'var' value of the selected vector nodes."""
        self.multiplot_nodes(line, lambda x: self.nodes[x]['myval']['var'])

    def do_plotexp(self, line):
        """plotexp [regexp]: plot the 'ex' value of the selected vector nodes."""
        self.multiplot_nodes(line, lambda x: self.nodes[x]['myval']['ex'])

    def do_plotmatrix(self, line):
        """plotmatrix [regexp]: Hinton diagram of the selected nodes' means."""
        nodes = self.match_nodes(line)
        matrix = self.make_matrix(nodes, FieldReader('myval', 'mean'))
        if isinstance(matrix, dict):
            print("No single matrix to plot, restrict your node specification")
        else:
            self.pl.hinton_diagram(matrix)

    def _print_field_matrix(self, line, field):
        # Shared body of mean/var/exp: pretty-print the matrix of one field.
        nodes = self.match_nodes(line)
        self.pp.pprint(self.make_matrix(nodes, FieldReader('myval', field)))

    def do_mean(self, line):
        """mean [regexp]: print the 'mean' values of the selected nodes."""
        self._print_field_matrix(line, 'mean')

    def do_var(self, line):
        """var [regexp]: print the 'var' values of the selected nodes."""
        self._print_field_matrix(line, 'var')

    def do_exp(self, line):
        """exp [regexp]: print the 'ex' values of the selected nodes."""
        self._print_field_matrix(line, 'ex')

    def do_set(self, line):
        """set name regexp: remember the labels matching regexp under name."""
        tok = line.split()
        if len(tok) < 2:
            print("Usage: set name regexp")
            return
        self.ownvars[tok[0]] = self.match_nodes(tok[1])

    def do_use(self, line):
        """use nodes|variables: choose which labels the other commands see."""
        if line == "nodes":
            self.mydict = self.nodes
        elif line == "variables":
            self.mydict = self.variables
        else:
            print("Usage: use nodes|variables")

    def do_parents(self, line):
        """parents [regexp]: print the parents of the selected nodes."""
        for n in self.match_nodes(line):
            self.pp.pprint({n: self.nodes[n]['parents']})

    def do_children(self, line):
        """children [regexp]: print the children of the selected nodes."""
        for n in self.match_nodes(line):
            self.pp.pprint({n: self.nodes[n]['children']})

    def do_children2(self, line):
        """children2 [regexp]: print children as 'label: children' lines."""
        if not line:
            if self.curnode is None:
                print("No current node selected")
                return
            # curnode is a label; look its record up rather than indexing
            # the label string itself.
            self.pp.pprint(self.nodes[self.curnode]['children'])
        else:
            for n in self.match_nodes(line):
                print("%s: %s" % (n, self.nodes[n]['children']))

    def do_matsave(self, line):
        """matsave file netname: save the live network with SaveToMatFile."""
        names = line.split()
        if len(names) != 2:
            print("Usage: matsave file netname")
        else:
            self.net.SaveToMatFile(names[0], names[1])

    def help_ls(self):
        print("List all nodes/variables, optionally those with labels matching given regexp")

    def do_quit(self, line):
        """quit: leave the browser (also 'q' or EOF)."""
        self.quit = 1
        # A true return value makes cmdloop() stop.
        return 1

    do_q = do_quit
    do_EOF = do_quit


if __name__ == '__main__':
    import sys
    # Load the pickled network named on the command line.
    # NOTE(review): LoadWithPickle presumably unpickles the file, and
    # unpickling can run arbitrary code, so only open trusted files.
    if len(sys.argv) < 2:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("Usage: python %s pickled_net.pickle" % sys.argv[0])
    import PickleHelpers
    net = PickleHelpers.LoadWithPickle(sys.argv[1])

    # The pickle holds either a wrapper whose .net attribute is the network,
    # or the network itself. Choosing explicitly with getattr avoids catching
    # AttributeError around the constructor. That would also swallow
    # AttributeErrors raised inside NetBrowser.__init__ and retry with the
    # wrong object.
    b = NetBrowser(getattr(net, 'net', net))
    b.Browse()
