# -*- coding: iso-8859-1 -*-

#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: DecayCounters.py 5 2006-10-26 09:44:54Z ah $
#

import Net
import math


class LinearDecayCounter:
    """Decay counter with linearly increasing memory length.

    On construction the counter registers itself with ``net`` through a
    ``Net.PythonDecayCounter`` wrapper around :meth:`DoDecay`.

    Attributes (all plain instance attributes):
        samples -- number of calls to DoDecay so far, plus one (float).
        ratio   -- multiplier applied to ``samples`` when deriving the decay.
        cumsum  -- running sum that is incremented and then scaled by the
                   decay on every call.
        decay   -- the decay value returned by the latest DoDecay call.
    """

    def __init__(self, net, ratio = 0.5):
        """Initialise the counter state and install it into ``net``."""
        self.samples = 1.0
        self.ratio = ratio
        self.cumsum = 0.0
        self.decay = 0.0
        self.Reinstall(net)

    def Reinstall(self, net):
        """Create a fresh PythonDecayCounter wrapper and set it on ``net``."""
        wrapper = Net.PythonDecayCounter(self.DoDecay)
        self.pythoncounter = wrapper
        net.SetDecayCounter(wrapper)

    def DoDecay(self):
        """Advance the counter by one sample and return the new decay.

        The raw decay is ``(samples * ratio - 1) / cumsum`` evaluated after
        both counters have been incremented.  Raw values below 0.99 are
        replaced by ``exp(raw - 1)``; other values are used unchanged.
        """
        self.samples += 1
        self.cumsum += 1
        raw = (self.samples * self.ratio - 1) / self.cumsum
        factor = math.exp(raw - 1) if raw < 0.99 else raw
        self.decay = factor
        self.cumsum *= factor
        return float(factor)

class SaturatingDecayCounter:
    """Decay counter with constant decay.

    Every call to :meth:`DoDecay` yields the same decay value,
    ``(limit - 2) / (limit - 1)``.  On construction the counter registers
    itself with ``net`` through a ``Net.PythonDecayCounter`` wrapper around
    :meth:`DoDecay`.

    Attributes (all plain instance attributes):
        samples -- number of calls to DoDecay so far, plus one (float).
        cumsum  -- running sum that is incremented and then scaled by the
                   decay on every call.
        decay   -- the decay value returned by the latest DoDecay call.
        limit   -- the constant that determines the decay; must be > 1.
    """

    def __init__(self, net, limit = 100):
        """Initialise the counter state and install it into ``net``.

        ``limit`` must be greater than 1 (checked with an assertion).
        """
        assert limit > 1
        self.samples = 1.0
        self.cumsum = 0.0
        self.decay = 0.0
        self.limit = limit
        self.pythoncounter = Net.PythonDecayCounter(self.DoDecay)
        net.SetDecayCounter(self.pythoncounter)

    def Reinstall(self, net):
        """Create a fresh PythonDecayCounter wrapper and set it on ``net``."""
        self.pythoncounter = Net.PythonDecayCounter(self.DoDecay)
        net.SetDecayCounter(self.pythoncounter)

    def DoDecay(self):
        """Advance the counter by one sample and return the constant decay."""
        self.samples += 1
        self.cumsum += 1
        # Force true division: with an integer limit (such as the default
        # 100) the Python 2 "/" operator floors the quotient, which made the
        # decay 0 instead of 98/99.
        x = float(self.limit - 2) / (self.limit - 1)
        self.cumsum *= x
        self.decay = x
        return float(x)

class FunctionalDecayCounter:
    """Decay counter whose memory length is given by a user function.

    On every :meth:`DoDecay` call the counter evaluates
    ``memfunc(samples - 1, *memparams)`` and divides the result by the
    running sum to obtain the decay.  On construction the counter registers
    itself with ``net`` through a ``Net.PythonDecayCounter`` wrapper around
    :meth:`DoDecay`.

    Attributes (all plain instance attributes):
        samples   -- number of calls to DoDecay so far, plus one (int).
        cumsum    -- running sum that is incremented and then scaled by the
                     decay on every call.
        decay     -- the decay value returned by the latest DoDecay call.
        memfunc   -- callable returning the memory length.
        memparams -- tuple of extra positional arguments for ``memfunc``.
        oldmem    -- memory length computed by the latest DoDecay call.
    """

    def __init__(self, net, memfunc, memparams = ()):
        """Initialise the counter state and install it into ``net``.

        ``memparams`` may be a tuple of extra arguments for ``memfunc`` or a
        single non-tuple value, which is wrapped into a one-element tuple.
        """
        self.samples = 1
        self.cumsum = 0.0
        self.decay = 0.0
        self.memfunc = memfunc
        if not isinstance(memparams, tuple):
            memparams = (memparams, )
        self.memparams = memparams
        self.oldmem = 0.0
        self.pythoncounter = Net.PythonDecayCounter(self.DoDecay)
        net.SetDecayCounter(self.pythoncounter)

    def Reinstall(self, net):
        """Create a fresh PythonDecayCounter wrapper and set it on ``net``."""
        self.pythoncounter = Net.PythonDecayCounter(self.DoDecay)
        net.SetDecayCounter(self.pythoncounter)

    def DoDecay(self):
        """Advance the counter by one sample and return the new decay.

        The decay is ``memfunc(samples - 1, *memparams) / cumsum`` evaluated
        after both counters have been incremented.  A decay above 1 is
        reported on standard output but is still applied and returned.
        """
        self.samples += 1
        self.cumsum += 1
        memlen = self.memfunc(self.samples - 1, *self.memparams)
        # cumsum is a float, so this is true division on Python 2 as well.
        x = memlen / self.cumsum
        if x > 1:
            # Deliberately a report rather than an exception.  The
            # parenthesised single-argument form prints the same text under
            # both the Python 2 print statement and the Python 3 function.
            print("DecayCounter evaluated decay > 1")
        # Remember the memory length of this step; nothing in this class
        # reads it back.
        self.oldmem = memlen
        self.cumsum *= x
        self.decay = x
        return float(x)
