# -*- coding: iso-8859-1 -*-

#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: plot.py 6 2006-10-26 09:54:58Z ah $
#

"""Frontend to plotting with biggles.

Example usage:

pl = plot.Plotter()

t = arange(0, 2*pi, 0.1)
pl.plot(sin(t))
pl.plot(t, sin(t))
pl.sethold(1)
pl.plot(t, sin(2*t), "red")
pl.sethold(0)

S = array([sin(t), sin(2*t)])
pl.plot(S)
pl.multiplot(S)

pl.save("art.eps")
"""

try:
    import biggles
    foundbiggles = 1
except ImportError:
    foundbiggles = 0

import os
import Numeric
from Numeric import *

def setpersistent(state):
    """Set biggles' "screen" "persistent" configuration option.

    A true *state* sets the option to "yes", a false one to "no".
    Requires that biggles was imported successfully.
    """
    if state:
        value = "yes"
    else:
        value = "no"
    biggles.configure("screen", "persistent", value)

# Enable biggles' "screen" "persistent" option by default.  Guarded because
# setpersistent calls biggles.configure, which only exists when the biggles
# import above succeeded.
if foundbiggles:
    setpersistent(True)

class Plotter:
    """Stateful frontend around a biggles plot object.

    The current plot lives in ``self.p``.  Most plotting methods start from
    a fresh ``biggles.FramedPlot`` unless hold mode is on (see ``sethold``
    and ``clear``).  Afterwards they call ``refresh``, which shows the plot
    when the plotter is visible (see ``setvisible``).
    """

    def __init__(self):
        """Create a plotter with an empty FramedPlot, hold off, visible on.

        Raises EnvironmentError if biggles could not be imported or if the
        DISPLAY environment variable is not set.
        """
        if not foundbiggles:
            raise EnvironmentError(
                "Biggles not found, cannot initialize Plotter!")
        if os.getenv("DISPLAY") is None:
            raise EnvironmentError(
                "No DISPLAY set, cannot initialize Plotter!")
        self.p = biggles.FramedPlot()
        self.hold = 0
        self.visible = 1

    def sethold(self, state):
        """Set hold mode.

        While *state* is true, ``clear`` keeps the current plot, so
        subsequent plotting calls draw on top of it.
        """
        self.hold = state

    def setvisible(self, state):
        """Set whether ``refresh`` shows the plot on screen."""
        self.visible = state

    def plot(self, x, y=None, color="black"):
        """Plot one or more curves.

        If only *x* is given, it is used as the y data and plotted against
        0, 1, ..., n-1.  If y is two-dimensional, each row is drawn as a
        separate curve.  *color* applies to every curve drawn.
        """
        self.clear()
        x, y = self.__convertxy(x, y)
        if len(y.shape) == 2:
            for i in range(y.shape[0]):
                # Pass color through so 2-D input honours it as well.
                self.__plot(x, y[i,:], color=color)
        else:
            self.__plot(x, y, color=color)
        self.refresh()

    def histogram(self, data, numbins=10, scale=1):
        """Draw a histogram of *data* with *numbins* equal-width bins.

        1-D *data* gives a single histogram.  2-D *data* gives one
        histogram per row, stacked vertically in a ``biggles.Table``.  When
        *scale* is 1, bar heights are divided by (bin width * sample count).

        Raises ValueError if *data* is neither 1-D nor 2-D, if a row is
        empty, or if *numbins* is less than 1.
        """
        data = self.__force_array(data)

        if len(data.shape) == 1:
            self.clear()
            self.__histogram(data, numbins, scale=scale)
        else:
            # An explicit check instead of assert, which -O would strip.
            if len(data.shape) != 2:
                raise ValueError("histogram data must be 1-D or 2-D")
            rows = data.shape[0]
            t = biggles.Table(rows, 1)
            for i in range(rows):
                self.p = biggles.FramedPlot()
                self.__histogram(data[i,:], numbins, scale=scale)
                t[i,0] = self.p
            self.p = t

        self.refresh()

    def _bin_counts(self, sorted_data, numbins):
        """Count ascending values into *numbins* equal-width bins.

        Returns ``(counts, start, binsize)``.  *counts* is a list of ints,
        *start* is the left edge of the first bin (the minimum value), and
        *binsize* is the bin width.  Each bin includes its right limit.
        The last bin also collects everything above the previous limit, so
        floating-point rounding of ``start + numbins * binsize`` can never
        drop the maximum value.  If all values are equal, a width of 1.0
        is used instead of zero.

        Raises ValueError for empty input or ``numbins < 1``.
        """
        n = len(sorted_data)
        if n == 0:
            raise ValueError("cannot build a histogram of empty data")
        if numbins < 1:
            raise ValueError("numbins must be at least 1")
        start = sorted_data[0]
        binsize = (sorted_data[-1] - start) / float(numbins)
        if binsize == 0:
            # A zero range would give zero-width bins and a later 0-division.
            binsize = 1.0
        counts = [0] * numbins
        i = 0
        for b in range(numbins - 1):
            limit = start + (b + 1) * binsize
            while i < n and sorted_data[i] <= limit:
                counts[b] += 1
                i += 1
        # The last bin is closed: take every value not yet counted.
        counts[-1] += n - i
        return counts, start, binsize

    def __histogram(self, data, numbins, scale=1):
        """Add a ``biggles.Histogram`` of the 1-D array *data* to self.p."""
        counts, start, binsize = self._bin_counts(sort(data), numbins)
        count = array(counts, Float)
        if scale == 1:
            # Every sample is counted, so the bar areas now sum to one.
            count /= (binsize * len(data))
        self.p.add(biggles.Histogram(count, start, binsize))

    def __force_array(self, seq):
        """Return *seq* as a Float array (an existing array is cast)."""
        if isinstance(seq, type(array([]))):
            return seq.astype(Float)
        else:
            return array(seq, Float)

    def __plot(self, x, y=None,
               color="black", linetype="solid", linewidth=1):
        """Add one ``biggles.Curve`` to self.p without clearing or refreshing.

        If only *x* is given, it is plotted against its index.
        """
        x, y = self.__convertxy(x, y)
        self.p.add(biggles.Curve(x, y, color=color, linetype=linetype,
                                 linewidth=linewidth))

    def plotpoints(self, data, type="dot"):
        """Plot *data* as unconnected points of symbol *type*.

        For 2-D *data*, row 0 gives the x and row 1 the y coordinates.
        Otherwise the values are plotted against their index.
        """
        self.clear()

        arr = self.__force_array(data)

        if len(arr.shape) == 2:
            x = arr[0,:]
            y = arr[1,:]
        else:
            x = arange(len(arr))
            y = arr

        self.p.add(biggles.Points(x, y, type=type))
        self.refresh()

    def __max_variation(self, ary):
        """Return the largest (max - min) taken along the last axis."""
        v = Numeric.maximum.reduce(ary, -1) - Numeric.minimum.reduce(ary, -1)
        return Numeric.maximum.reduce(v)

    def __equal_scale(self, ary):
        """Give every panel of self.p a y-range of the same height.

        Each range is centred on its row's midpoint and is 1.1 times the
        largest row variation, so all rows share one vertical scale.
        """
        maxvar = self.__max_variation(ary)
        for i in range(ary.shape[0]):
            b = Numeric.maximum.reduce(ary[i])
            a = Numeric.minimum.reduce(ary[i])
            m = 0.5 * (a + b)
            bhat = m + 0.55 * maxvar
            ahat = m - 0.55 * maxvar
            self.p[i,0].yrange = (ahat, bhat)

    def __convertxy(self, x, y):
        """Normalise plot arguments to Float arrays.

        If *y* is None, *x* becomes the y data and x is set to
        ``arange(y.shape[-1])``.
        """
        x = self.__force_array(x)
        if y is None:
            y = x
            x = Numeric.arange(y.shape[-1])
        y = self.__force_array(y)
        return x, y

    def multiplot(self, x, y=None, scale="", title=None):
        """Plot each row of y in its own panel of a ``biggles.FramedArray``.

        1-D y is treated as a single row.  With ``scale="equal"`` every
        panel gets a y-range of the same height.  *title*, if given, is
        passed to the FramedArray.
        """
        x, y = self.__convertxy(x, y)
        if len(y.shape) == 1:
            # A single series: treat it as a one-row matrix instead of failing.
            y = reshape(y, (1, y.shape[0]))
        rows = y.shape[0]
        if title is None:
            a = biggles.FramedArray(rows, 1)
        else:
            a = biggles.FramedArray(rows, 1, title=title)
        for i in range(rows):
            # __plot draws into self.p, so point it at the current panel.
            self.p = a[i,0]
            self.__plot(x, y[i,:])
        self.p = a
        if scale == "equal":
            self.__equal_scale(y)
        self.refresh()

    def multiplotwithvar(self, means, vars, x=None, labels=None):
        """Plot each row of *means* with +/- one standard deviation band.

        Each row gets its own FramedPlot in a ``biggles.Table``.  The mean
        is drawn with a thick line and mean +/- sqrt(vars) with dot-dashed
        lines.  *labels* (default: row indices as strings) sets each
        panel's y label.  If *x* is None, curves are plotted against their
        index.
        """
        means = self.__force_array(means)
        stds = sqrt(self.__force_array(vars))
        if labels is None:
            # A list (not map) so it is indexable on every Python version.
            labels = [str(i) for i in range(means.shape[0])]
        table = biggles.Table(means.shape[0], 1)
        for i in range(means.shape[0]):
            self.p = biggles.FramedPlot()
            if x is None:
                self.__plot(means[i] + stds[i], linetype="dotdashed")
                self.__plot(means[i] - stds[i], linetype="dotdashed")
                self.__plot(means[i], linewidth=4)
            else:
                self.__plot(x, means[i] + stds[i], linetype="dotdashed")
                self.__plot(x, means[i] - stds[i], linetype="dotdashed")
                self.__plot(x, means[i], linewidth=4)
            self.p.ylabel = labels[i]
            table[i,0] = self.p
        self.p = table
        self.refresh()

    def contours(self, z, x = None, y = None):
        """Draw a contour plot of the 2-D array *z*.

        *x* and *y* default to ``arange(z.shape[0])`` and
        ``arange(z.shape[1])`` respectively.
        """
        self.clear()
        z = self.__force_array(z)
        if x is None:
            x = arange(0, z.shape[0])
        else:
            x = self.__force_array(x)
        if y is None:
            y = arange(0, z.shape[1])
        else:
            y = self.__force_array(y)
        self.p.add(biggles.Contours(z, x, y))
        self.refresh()

    def hinton_diagram(self, matrix):
        """Draw a Hinton diagram of *matrix* (1-D input is shown as a column).

        Each entry gets a cell holding a centred square whose side is
        sqrt(|value| / max|value|).  Squares for negative entries are
        filled black.
        """
        matrix = self.__force_array(matrix)
        maxval = max(abs(reshape(matrix, (product(matrix.shape), ))))
        if maxval == 0:
            # All-zero matrix: avoid 0/0; every square then has zero size.
            maxval = 1.0
        if len(matrix.shape) == 1:
            matrix = reshape(matrix, (matrix.shape[0], 1))
        self.p = biggles.Table(matrix.shape[0], matrix.shape[1])
        for i in range(matrix.shape[0]):
            for j in range(matrix.shape[1]):
                self.p[i, j] = biggles.Plot()
                self.p[i, j].xrange = (0, 1)
                self.p[i, j].yrange = (0, 1)
                val = sqrt(abs(matrix[i, j]) / maxval)
                self.p[i, j].add(biggles.Box((.5*(1-val), .5*(1-val)),
                                             (.5*(1+val), .5*(1+val))))
                if matrix[i, j] < 0:
                    self.p[i, j].add(biggles.FillBetween(
                        (.5*(1-val), .5*(1+val)),
                        (.5*(1-val), .5*(1-val)),
                        (.5*(1-val), .5*(1+val)),
                        (.5*(1+val), .5*(1+val)), color="black"))
        self.refresh()

    def save(self, filename):
        """Write the current plot to *filename* in EPS format."""
        self.p.write_eps(filename)

    def refresh(self):
        """Show the current plot if the plotter is visible."""
        if self.visible:
            self.p.show()

    def clear(self):
        """Start a new empty FramedPlot unless hold mode is on."""
        if not self.hold:
            self.p = biggles.FramedPlot()
