#!/usr/bin/env python
# -*- coding: iso-8859-1 -*-

#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: BBVis.py 5 2006-10-26 09:44:54Z ah $
#

"""Tkinter GUI for BBNet visualiser
"""
import os
import tempfile
import Tkinter
from Tkconstants import *
import tkFileDialog
import tkColorChooser
import tkMessageBox
import Pmw
import string
import copy
import DotWriter


# Node count (number of entries in the writer's ``d_act``) above which the
# network is treated as "too large": BBVis drops all label indices
# automatically at start-up, and every (re)rendering of the graph first asks
# the user for confirmation.
LARGENETLIMIT = 100


class BBVis(object):
    """GUI for Bayes Blocks Visualiser.

    The main window holds a canvas with the rendered network graph and a
    command panel (node type filter, structural modifications, selective
    index dropping, node properties, zoom, save and quit).  Graph layout is
    delegated to ``DotWriter.PlainWriter``; its Tk output is sourced into
    the Tcl interpreter and shown on the canvas.
    """
    def __init__(self, net):
        self.writer = DotWriter.PlainWriter(net)
        self.root = Tkinter.Tk()
        self.root.title('BBNet Visualiser')
        Pmw.initialise(self.root, size = 14, fontScheme = 'pmw1')

        # Supported zoomlevels and related font sizes etc.
        self.zoomlevels = ['10 %', '25 %', '50 %', '100 %', '200 %', '400 %']
        self.zoomvals = {'10 %': .1, '25 %': .25, '50 %': .5, '100 %': 1.0,
                         '200 %': 2.0, '400 %': 4.0}
        self.zoomfonts = {.1: 2, .25: 5, .5: 9, 1.0: 14,
                          2.0: 28, 4.0: 56}
        self.textitems = []

        # Included node types and other active modifications
        self.selectedtypes = copy.deepcopy(self.writer.types)
        self.mods = {}
        # Supported modifications and their handlers
        self.modhandlers = {'Combine sum trees': self.CombineSumtrees,
                            'Drop all indices': self.DropIndices}

        # Add all the necessary GUI elements
        self.canvas = Tkinter.Canvas(self.root)
        self.commandframe = Tkinter.Frame(self.root)

        self.typeframe = Tkinter.Frame(
            self.commandframe, borderwidth=2, relief='ridge')
        self.typeframe.pack(side=TOP, padx = 8, pady = 8, fill=X)
        self.typebutton = Tkinter.Button(
            self.typeframe, text='Select types',
            command=self.RunTypeSelector)
        self.typebutton.pack(side=TOP)

        self.treeselect = Pmw.RadioSelect(self.commandframe,
                                          buttontype = 'checkbutton',
                                          orient = 'vertical',
                                          labelpos = 'n',
                                          label_text = 'Modifications',
                                          command = self.ApplyMod,
                                          hull_borderwidth = 2,
                                          hull_relief = 'ridge',
                                          selectmode = 'multiple')
        self.treeselect.pack(side=TOP, padx = 8, pady = 8, fill=X)
        for t in ['Combine sum trees', 'Drop all indices']:
            self.treeselect.add(t)

        self.dropframe = Tkinter.Frame(
            self.commandframe, borderwidth=2, relief='ridge')
        self.dropframe.pack(side=TOP, padx = 8, pady = 8, fill=X)
        self.dropbutton = Tkinter.Button(
            self.dropframe, text='Selectively drop indices',
            command=self.RunIndexSelector)
        self.dropbutton.pack(side=TOP)

        self.propframe = Tkinter.Frame(
            self.commandframe, borderwidth=2, relief='ridge')
        self.propframe.pack(side=TOP, padx = 8, pady = 8, fill=X)
        self.propbutton = Tkinter.Button(
            self.propframe, text='Select node properties',
            command=self.RunPropertySelector)
        self.propbutton.pack(side=TOP)
        self.propmenu = Pmw.OptionMenu(self.propframe,
                                       labelpos = 'n',
                                       label_text = 'Node type',
                                       items = self.writer.types
                                       )
        self.propmenu.pack(side=TOP, padx = 8, pady = 8, expand = 0)

        self.zoombutton = Pmw.OptionMenu(self.commandframe,
                                         labelpos = 'n',
                                         label_text = 'Zoom level',
                                         items = self.zoomlevels,
                                         command = self.ZoomToLevel,
                                         hull_borderwidth = 2,
                                         hull_relief = 'ridge',
                                         initialitem='100 %'
                                         )
        self.zoombutton.pack(side=TOP, padx = 8, pady = 8, fill=X)
        self.zoomlevel = 1.0
        self.savebutton = Tkinter.Button(self.commandframe, text='Save graph',
                                         command=self.SaveToFile)
        self.savebutton.pack(side=TOP, padx = 40)
        self.quitbutton = Tkinter.Button(self.commandframe, text='Quit',
                                         command=self.root.quit)
        self.quitbutton.pack(side=TOP, padx = 40)

        if len(self.writer.d_act.keys()) > LARGENETLIMIT:
            tkMessageBox.showwarning("Too large net", "Large net - dropping indices.\nTry to simplify the structure before returning to full view.")
            self.treeselect.setvalue(['Drop all indices'])
            self.mods['Drop all indices'] = 1
            self.PerformOps()
        val = self.GenerateFigure()
        if not val:
            # No figure was shown (user declined), so ShowFigure did not
            # pack the GUI; pack the empty canvas and panel ourselves.
            self.PackGUI()


    def _AdjustScrollregion(self):
        """Internal function to adjust the scroll region of the canvas.

        The region is the bounding box of all items tagged 'everything',
        padded by 50 pixels on each side.  An empty canvas (bbox is None)
        leaves the scroll region untouched instead of raising TypeError.
        """
        self.bbox = self.canvas.bbox('everything')
        if self.bbox is None:
            return
        left, top, right, bottom = self.bbox
        self.scrollregion = (left - 50, top - 50, right + 50, bottom + 50)
        self.canvas.config(scrollregion=self.scrollregion)

    def _TagText(self):
        """Internal function to record the ids of all text items.

        An item counts as text if it has a 'font' option; ids are stored
        in ``self.textitems`` for later resizing by ResizeText.
        """
        self.textitems = []
        for i in self.canvas.find_all():
            try:
                self.canvas.itemcget(i, 'font')
            except Tkinter.TclError:
                # Tk raises TclError for items without a font option
                # (lines, polygons, ...); those are simply skipped.
                continue
            self.textitems.append(i)

    def ResizeText(self, newsize):
        """Resize all text items to given size.

        The font of each item is split on '-' and field 7 is replaced with
        *newsize*.  This presumably targets X logical font descriptions
        (as emitted in dot's Tk output), where that field is the pixel
        size -- TODO confirm against DotWriter's output.
        """
        for i in self.textitems:
            components = self.canvas.itemcget(i, 'font').split('-')
            components[7] = str(newsize)
            self.canvas.itemconfigure(i, font='-'.join(components))

    def Zoom(self, factor):
        """Scale every item tagged 'everything' by *factor* about (1, 1)
        and refit the scroll region."""
        self.canvas.scale('everything', 1.0, 1.0, float(factor), float(factor))
        self._AdjustScrollregion()

    def ZoomToLevel(self, factor):
        """Handler for GUI zoom operation.

        *factor* is one of the labels in ``self.zoomlevels`` (e.g. '200 %').
        The canvas is scaled relative to the current zoom level and text
        fonts are switched to the size listed for the new level.
        """
        newlevel = self.zoomvals[factor]
        if newlevel != self.zoomlevel:
            self.Zoom(newlevel / self.zoomlevel)
            self.ResizeText(self.zoomfonts[newlevel])
            self.zoomlevel = newlevel

    def PerformOps(self):
        """Perform all requested modifications to the graph.

        The writer's filter is reset, then the type filter, the selective
        index drops and every enabled modification are re-applied.
        """
        self.writer.ResetFilter()
        self.writer.FilterButType(self.selectedtypes)
        self.writer.SelectivelyDropIndices()
        for name, enabled in self.mods.items():
            if enabled:
                self.modhandlers[name]()

    def GenerateFigure(self):
        """(Re)generate the figure.

        Returns 1 when the figure was rendered, 0 when the user declined
        to render a net larger than LARGENETLIMIT nodes.
        """
        nnodes = len(self.writer.d_act.keys())
        # Issue a warning if the net seems too large
        if nnodes > LARGENETLIMIT and \
               not tkMessageBox.askyesno("Large net", "Visualising this large graph (" + str(nnodes) + " nodes) may take very long time.  Are you sure you want to proceed?"):
            return 0
        # mkstemp creates the file atomically (mktemp is open to a
        # name-reuse race); the writer then writes the Tk source into it.
        fd, fname = tempfile.mkstemp('tk')
        os.close(fd)
        try:
            self.writer.WriteToTk(fname)
            # Destroy the old figure and show the new one
            self.canvas.destroy()
            self.ShowFigure(fname)
        finally:
            # Always remove the temporary file, even if rendering failed
            os.unlink(fname)
        # The new figure is drawn at 100 %; bring it to the current zoom
        self.Zoom(self.zoomlevel)
        self.ResizeText(self.zoomfonts[self.zoomlevel])
        return 1

    def ShowFigure(self, file):
        """Load a graph from a Tk source file and display it."""
        # Run the Tk source
        self.root.tk.call('source', file)
        # The generated script is expected to create a canvas named '.c';
        # re-point the existing Python Canvas wrapper at it so the rest of
        # the class can keep using self.canvas.
        self.canvas._name = 'c'
        self.canvas._w = '.c'
        self.canvas.addtag_all('everything')
        self._AdjustScrollregion()
        self._TagText()
        self.canvas.config(width=500, height=500)

        self.PackGUI()

    def PackGUI(self):
        """Create the canvas scrollbars and pack canvas and command panel.

        The scrollbars are children of the canvas, so they are recreated
        here every time a new canvas has been shown.
        """
        self.xbar = Tkinter.Scrollbar(self.canvas)
        self.xbar.config(orient='horizontal')
        self.ybar = Tkinter.Scrollbar(self.canvas)
        self.xbar.config(command=self.canvas.xview)
        self.ybar.config(command=self.canvas.yview)
        self.canvas.config(xscrollcommand=self.xbar.set)
        self.canvas.config(yscrollcommand=self.ybar.set)
        self.xbar.pack(side=BOTTOM, fill=X)
        self.ybar.pack(side=RIGHT, fill=Y)
        self.canvas.pack(side=RIGHT, expand=YES, fill=BOTH)
        self.canvas.config(width=500, height=500)
        self.commandframe.pack(side=LEFT, fill=Y)

    def RunTypeSelector(self):
        """Open the node type selection popup."""
        TypeSelector(self.writer.types, self.writer.variables,
                     self.selectedtypes, self)

    def ApplyMod(self, button, val):
        """Apply a new modification to the graph.

        Called by the 'Modifications' RadioSelect with the button label and
        its new checked state.
        """
        self.mods[button] = val
        self.PerformOps()
        self.GenerateFigure()

    def CombineSumtrees(self):
        """Modification handler: combine sum trees in the writer."""
        self.writer.CombineSumtrees()

    def DropIndices(self):
        """Modification handler: drop all label indices in the writer."""
        self.writer.DropIndices()

    def SaveToFile(self):
        """Save the dot source of the current net"""
        fname = tkFileDialog.asksaveasfilename(
            filetypes=[('dot source files', '*.dot')])
        if fname:
            self.writer.WriteGraph(fname)

    def RunPropertySelector(self):
        """Open the property popup for the node type chosen in the menu."""
        PropertySelector(self,
                         self.writer.properties[self.propmenu.getvalue()])

    def RunIndexSelector(self):
        """Open the popup for selectively dropping label indices."""
        IndexSelector(self, self.writer.GetAllLabels(), self.writer.indexdrops)

    def ApplyTypeSelections(self, selected=None):
        """Show only the types in *selected* and re-render if it changed.

        An empty or missing selection is ignored.
        """
        if selected and self.selectedtypes != selected:
            self.selectedtypes = selected
            self.PerformOps()
            self.GenerateFigure()

    def ApplyPropertySelections(self):
        """Re-render after node properties were changed."""
        self.PerformOps()
        self.GenerateFigure()

    def ApplyIndexSelections(self, sel):
        """Adopt the index-drop mapping *sel* and re-render if it changed.

        A deep copy is stored so later edits in the popup do not leak in.
        """
        if self.writer.indexdrops != sel:
            self.writer.indexdrops = copy.deepcopy(sel)
            self.PerformOps()
            self.GenerateFigure()

class TypeSelector(Tkinter.Toplevel):
    """A popup for selecting node types to display.

    Arguments:
      types    -- all node types, shown as checkbuttons in two columns
      vartypes -- the types checked by the 'Select variables' button
      selected -- the types currently displayed (restored by Reset/Cancel)
      parent   -- object whose ApplyTypeSelections(list) receives the choice
    """
    def __init__(self, types, vartypes, selected, parent):
        self.vartypes = vartypes
        # The first column gets the extra type when the count is odd.  Floor
        # division keeps the slice bound an int on every Python version.
        halflen = (len(types) + 1) // 2
        self.halflen = halflen
        self.parent = parent
        self.origselected = selected
        self.selected = selected
        self.all = types
        Tkinter.Toplevel.__init__(self)
        self.title('Select node types to display')
        self.frame = Tkinter.Frame(self)
        self.frame.pack(side=TOP, fill=X)
        self.select1 = Pmw.RadioSelect(self.frame,
                                       buttontype = 'checkbutton',
                                       orient = 'vertical',
                                       labelpos = None,
                                       selectmode = 'multiple')
        self.select2 = Pmw.RadioSelect(self.frame,
                                       buttontype = 'checkbutton',
                                       orient = 'vertical',
                                       labelpos = None,
                                       selectmode = 'multiple')
        self.select1.pack(side=LEFT, anchor=N, fill = 'x', padx = 10)
        self.select2.pack(side=RIGHT, anchor=N, fill = 'x', padx = 10)
        for t in types[:halflen]:
            self.select1.add(t)
        for t in types[halflen:]:
            self.select2.add(t)
        # Check the initially displayed types in both columns
        self.new_selection(selected)

        self.rbox = Pmw.ButtonBox(self, labelpos = None)
        self.rbox.add('Reset', command = self.reset)
        self.rbox.add('Select all', command = self.selall)
        self.rbox.add('Clear', command = self.clear)
        self.rbox.alignbuttons()
        self.rbox.pack(side=TOP, fill = 'x', expand = 1,
                       padx = 12, pady = 6)

        self.vbox = Pmw.ButtonBox(self, labelpos = None)
        self.vbox.add('Select variables', command = self.selvars)
        self.vbox.alignbuttons()
        self.vbox.pack(side=TOP, fill = 'x', expand = 1,
                       padx = 12, pady = 6)

        self.bbox = Pmw.ButtonBox(self, labelpos = None)
        self.bbox.add('   OK   ', command = self.finish)
        self.bbox.add('  Apply ', command = self.apply)
        self.bbox.add(' Cancel ', command = self.cancel)
        self.bbox.alignbuttons()
        self.bbox.pack(side=TOP, fill = 'x', expand = 1,
                       padx = 12, pady = 6)

    def new_selection(self, selected):
        """Check exactly the types in *selected*, across both columns."""
        self.selected = selected
        left = self.all[:self.halflen]
        right = self.all[self.halflen:]
        self.select1.setvalue([t for t in selected if t in left])
        self.select2.setvalue([t for t in selected if t in right])

    def _current_selection(self):
        """Return the list of types currently checked in both columns."""
        return list(self.select1.getcurselection()) + \
               list(self.select2.getcurselection())

    def reset(self):
        """Restore the selection the popup was opened with."""
        self.new_selection(self.origselected)

    def selall(self):
        """Check every type."""
        self.new_selection(self.all)

    def selvars(self):
        """Check only the variable types."""
        self.new_selection(self.vartypes)

    def clear(self):
        """Uncheck every type."""
        self.new_selection([])

    def apply(self):
        """Send the checked types to the parent without closing."""
        self.selected = self._current_selection()
        self.parent.ApplyTypeSelections(self.selected)

    def finish(self):
        """Send the checked types to the parent and close the popup."""
        self.apply()
        self.destroy()

    def cancel(self):
        """Re-apply the original selection and close the popup."""
        self.selected = self.origselected
        self.parent.ApplyTypeSelections(self.selected)
        self.destroy()

class PropertySelector(Tkinter.Toplevel):
    """A popup for querying properties of node(s).

    The option menus and colour buttons call the setters of *props*
    (SetShape, SetStyle, SetPeripheries, SetColor, SetFillcolor) directly,
    so edits take effect on *props* immediately.  OK/Apply ask *parent* to
    re-render; Cancel restores the values captured when the popup opened.
    """
    def __init__(self, parent, props):
        self.parent = parent
        self.props = props
        # Snapshot for Cancel
        self.origprops = props.PropertyDict()
        Tkinter.Toplevel.__init__(self)
        self.title('Select properties for the node(s)')

        self.shapebutton = Pmw.OptionMenu(
            self, labelpos = 'w', label_text = 'Shape',
            items = DotWriter.all_properties['shape'],
            initialitem = self.origprops['shape'],
            command = props.SetShape
            )
        self.shapebutton.pack(side=TOP, padx = 8, pady = 8, expand = 0)

        self.stylebutton = Pmw.OptionMenu(
            self, labelpos = 'w', label_text = 'Style',
            items = DotWriter.all_properties['style'],
            initialitem = self.origprops['style'],
            command = props.SetStyle
            )
        self.stylebutton.pack(side=TOP, padx = 8, pady = 8, expand = 0)

        self.peripbutton = Pmw.OptionMenu(
            self, labelpos = 'w', label_text = 'Peripheries',
            items = DotWriter.all_properties['peripheries'],
            # NOTE(review): presumably turns the periphery count (1-based)
            # into an index into the items list -- confirm in DotWriter.
            initialitem = self.origprops['peripheries']-1,
            command = props.SetPeripheries
            )
        self.peripbutton.pack(side=TOP, padx = 8, pady = 8, expand = 0)

        self.colorframe = Tkinter.Frame(self)
        self.colorframe.pack(side=TOP, fill=X)
        self.colorbutton = Tkinter.Button(self.colorframe, text='Change color',
                                          command=self.SelectColor)
        self.colorbutton.pack(side=LEFT, padx = 5)
        self.colorlabel = Tkinter.Label(self.colorframe, bg=props.color)
        self.colorlabel.pack(side=RIGHT, fill=X, expand=YES)

        self.fillcframe = Tkinter.Frame(self)
        self.fillcframe.pack(side=TOP, fill=X)
        self.fillcbutton = Tkinter.Button(self.fillcframe, text='Change fillcolor',
                                          command=self.SelectFillcolor)
        self.fillcbutton.pack(side=LEFT, padx = 5)
        self.fillclabel = Tkinter.Label(self.fillcframe, bg=props.fillcolor)
        self.fillclabel.pack(side=RIGHT, fill=X, expand=YES)

        self.bbox = Pmw.ButtonBox(self, labelpos = None)
        self.bbox.add('   OK   ', command = self.finish)
        self.bbox.add('  Apply ', command = self.apply)
        self.bbox.add(' Cancel ', command = self.cancel)
        self.bbox.alignbuttons()
        self.bbox.pack(side=TOP, fill = 'x', expand = 1,
                       padx = 12, pady = 6)

    def _PickColor(self, setter, label):
        """Ask for a colour; if one is chosen, pass its '#rrggbb' string to
        *setter* and show it as the background of *label*."""
        newcol = tkColorChooser.askcolor()
        # askcolor returns (rgb, hexstring); hexstring is None on cancel
        if newcol[1]:
            setter(newcol[1])
            label.configure(bg=newcol[1])

    def SelectFillcolor(self):
        """Let the user pick the node fill colour."""
        self._PickColor(self.props.SetFillcolor, self.fillclabel)

    def SelectColor(self):
        """Let the user pick the node (outline) colour."""
        self._PickColor(self.props.SetColor, self.colorlabel)

    def finish(self):
        """Re-render with the current properties and close the popup."""
        self.parent.ApplyPropertySelections()
        self.destroy()

    def apply(self):
        """Re-render with the current properties, keeping the popup open."""
        self.parent.ApplyPropertySelections()

    def cancel(self):
        """Restore the original properties (re-rendering if they changed)
        and close the popup."""
        if self.props.PropertyDict() != self.origprops:
            self.props.SetAll(self.origprops)
            self.parent.ApplyPropertySelections()
        self.destroy()

class IndexSelector(Tkinter.Toplevel):
    """A popup for selecting number of indices to drop from labels.

    Arguments:
      parent -- object whose ApplyIndexSelections(dict) receives the choice
      nodes  -- mapping label -> maximum number of droppable indices; labels
                with a maximum of 0 get no counter
      cursel -- current mapping label -> number of indices dropped
    """
    def __init__(self, parent, nodes, cursel):
        self.parent = parent
        self.nodes = nodes
        self.origsel = cursel
        # Work on a private copy so edits do not touch cursel until applied
        self.sel = copy.deepcopy(cursel)
        Tkinter.Toplevel.__init__(self)
        self.title('Select label indices to drop')

        self.sf = Pmw.ScrolledFrame(
            self,
            usehullsize = 1,
            hull_width = 400,
            hull_height = 500,
        )

        self.frame = self.sf.interior()

        self.counters = []
        self.counterdict = {}
        for n in sorted(nodes.keys()):
            if nodes[n] > 0:
                c = Pmw.Counter(self.frame,
                                labelpos = 'w',
                                label_text = '%s, max %d:' % (n, nodes[n]),
                                orient = 'horizontal',
                                entry_width = 2,
                                entryfield_value = cursel.get(n, 0),
                                entryfield_validate = {'validator' : 'integer',
                                                       'min' : 0, 'max' : nodes[n]},
                                # Default argument binds the current label
                                # (avoids the late-binding closure pitfall)
                                entryfield_modifiedcommand = lambda x=n: self.ChangeValue(x)
                                )
                self.counters.append(c)
                self.counterdict[n] = c
        Pmw.alignlabels(self.counters)
        for counter in self.counters:
            counter.pack(side=TOP, fill=X, expand=1, padx=10, pady=5)
        self.sf.pack(padx = 5, pady = 3, fill = 'both', expand = 1)

        self.bbox = Pmw.ButtonBox(self, labelpos = None)
        self.bbox.add('   OK   ', command = self.finish)
        self.bbox.add('  Apply ', command = self.apply)
        self.bbox.add(' Cancel ', command = self.cancel)
        self.bbox.alignbuttons()
        self.bbox.pack(side=TOP, fill = 'x', expand = 1,
                       padx = 12, pady = 6)

    def ChangeValue(self, l):
        """Record the counter value for label *l* in the pending selection.

        Runs on every edit of the counter's entry field.  Transient
        non-numeric contents (e.g. an empty field while the user retypes)
        are ignored instead of raising ValueError from the Tk callback.
        """
        try:
            value = int(self.counterdict[l].getvalue())
        except ValueError:
            return
        self.sel[l] = value

    def finish(self):
        """Apply the pending selection and close the popup."""
        self.parent.ApplyIndexSelections(self.sel)
        self.destroy()

    def apply(self):
        """Apply the pending selection, keeping the popup open."""
        self.parent.ApplyIndexSelections(self.sel)

    def cancel(self):
        """Restore the original selection (if edited) and close the popup."""
        if self.sel != self.origsel:
            self.parent.ApplyIndexSelections(self.origsel)
        self.destroy()

if __name__ == '__main__':
    import sys

    # Load a given pickled network
    if len(sys.argv) < 2:
        print("Usage: python %s savefile.pickle" % sys.argv[0])
        sys.exit(1)
    import PickleHelpers
    # NOTE: LoadWithPickle presumably unpickles the file; unpickling can
    # execute arbitrary code, so only open files from trusted sources.
    net = PickleHelpers.LoadWithPickle(sys.argv[1])

    # The saved object is either the net itself or a wrapper holding it in
    # a ``net`` attribute.  Unwrap explicitly rather than catching
    # AttributeError, which would also hide errors raised inside BBVis.
    gui = BBVis(getattr(net, 'net', net))
    # Run the GUI
    gui.root.mainloop()
