#! /usr/bin/env python
# -*- coding: iso-8859-1 -*-

#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#

from bblocks.Label import Label, Unlabel
import bblocks.PickleHelpers as PickleHelpers
import bblocks.Helpers as Helpers
import hnfa
import sys
import os
import getopt
import signal
import re
import math
try:
    import numpy.oldnumeric as Numeric
except:
    import Numeric

datadir=os.environ.get('DATADIR', 'data')

def usage(out=sys.stdout):
    """Write the command-line help text to ``out``.

    A fixed header is followed by one entry per element of the module-level
    ``options`` table: the option names joined by ", " and then the help
    string (field 4 of the entry).  When the joined names are longer than 19
    characters, or the help text is longer than 57 characters, the help is
    written on its own line indented by 10 spaces; otherwise the names are
    padded to 22 columns so that the help texts line up.

    out -- object with a ``write`` method (defaults to sys.stdout)
    """
    out.write('Usage: python ' + os.path.basename(sys.argv[0]) +
              ' [OPTION]... [FILE]' + '\n')
    out.write('Makes a hnfanet using given data and runs it\n\n')
    out.write('  FILE   File to save generated network in,\n')
    out.write('         defaults to filename generated from options\n\n')
    for option in options:
        out.write("  ")
        # str.join replaces the former reduce()-based concatenation; it
        # gives the same text and does not rely on reduce being a builtin.
        oname = ', '.join(option[0])
        if len(oname) > 19 or len(option[4]) > 57:
            out.write(oname + '\n' + " " * 10)
        else:
            out.write(oname.ljust(22))
        out.write(option[4] + '\n')
            

# Command-line option table.  Each entry is a 5-tuple:
#   [0] tuple of option names; the LAST one must be the long "--name" form,
#       because its text after "--" is used as the variable name that holds
#       the option's value
#   [1] converter applied to the option argument (int, str), or None for a
#       flag that takes no argument
#   [2] "tofilename": slot in the generated output filename.
#       0   -> option is not reflected in the filename
#       n>0 -> slot n holds the option letter followed by the value
#       n<0 -> slot abs(n) holds the bare value
#   [3] default value
#   [4] help text printed by usage()
options = [
    (('-t', '--sourcetype',), str, -1, 'fa',
     "type of sources to use, '[d][i]fa' ('fa' is the default)"),
    (('--ignorectrlc',),      None,0, 0,
     'if set runhnfa ignores SIGINT'),
    (('-n', '--numsources',), int, 2, None,
     'number of sources to use'),
    (('-s', '--seed',),       int, 3, 0,
     'random seed to use'),
    (('-d','--datapoints',),  int, 4, None,
     'number of data points to use from the data set'),
    (('-l','--numpasses',),   int, 6, 9,
     'number of passes to do 0=linear mapping 9=default'),
    (('-p','--fileprefix',),  str, 0, 'nl1',
     'prefix to use in filenames and in data to load'),
    (('-i','--iterations',),  int, 7, 5000,
     'number of iterations to run'),
    (('-x','--onlyhj',),      int, 8, 1000,
     'number of iterations to use nothing but h-j in at end'),
    (('-r','--randhidden',),  int, 9, 50,
     'How many nodes to add randomly at first addition'),
    (('-y','--doprunehidden',), None,10,0,
     'If set, prunes also hidden nodes'),
    (('--hiddentest',),       int, 0, 1000,
     'How many nodes to test at subsequent addition of hidden nodes'),
    (('--hiddenaccept',),     int, 0, 5,
     'How many nodes to accept at subsequent addition of hidden nodes'),
    (('--printfilename',),    None,0, 0,
     'prints filename which will be used for results and exit')]

# Build the getopt specifications from the ``options`` table and install
# every option's default value as a module-level variable.
values = globals()      # option values live directly in the module namespace
optionsmap = {}         # option name as typed ('-n', '--numsources') -> entry
maxtofilename = 0       # highest filename slot used by any option
short_opts = "h"
long_opts = ['help']

for option in options:
    # getopt marks options that take an argument with a trailing '=' (long
    # form) or ':' (short form); flags get no marker.
    if option[1] is None:
        long_suffix, short_suffix = '', ''
    else:
        long_suffix, short_suffix = '=', ':'
    for name in option[0]:
        assert(len(name) >= 2)
        assert(name[0] == '-')
        if name.startswith('--'):
            long_opts.append(name[2:] + long_suffix)
        else:
            # Short options are exactly a dash plus one letter.
            assert(len(name) == 2)
            short_opts += name[1] + short_suffix
        optionsmap[name] = option
        maxtofilename = max(maxtofilename, abs(option[2]))
    # The default is stored under the last name with its leading '--' removed.
    values[option[0][-1][2:]] = option[3]

# One slot per filename position; slots stay None until an option fills them.
filenameopts = [None] * maxtofilename

# Parse the command line.  Every recognised option is stored in the module
# namespace (through ``values``) under its long name, and options with a
# non-zero "tofilename" field also record a tag in ``filenameopts``.  At most
# one positional argument is accepted; it overrides the generated output
# filename.  Any GetoptError prints the message and the usage text to stderr
# and exits with status 2.
try:
    opts, args = getopt.getopt(sys.argv[1:], short_opts, long_opts)
    for o, a in opts:
        if o in ('-h', '--help'):
            usage()
            sys.exit()
        try:
            option = optionsmap[o]
        except KeyError:
            raise getopt.GetoptError('Unknown option: ' + o,'')
        if option[1] is None:
            # Flag without an argument: its presence switches it on.
            values[option[0][-1][2:]] = 1
        else:
            try:
                values[option[0][-1][2:]] = option[1](a)
            except ValueError:
                # option[1] is a type object, so its name has to be used
                # when building the message (str + type raises TypeError).
                raise getopt.GetoptError('In option ' + o + ' parameter ' +
                                         a + ' is not of type: ' +
                                         option[1].__name__,'')
        if option[2] > 0:
            # Tag the value with the option's letter taken from its first
            # declared name, so '-s 3' and '--seed 3' both yield 's3'
            # (o[1] would be '-' for the long form).
            filenameopts[option[2]-1] = option[0][0].lstrip('-')[0] + a
        elif option[2] < 0:
            filenameopts[abs(option[2])-1] = a

    if len(args) == 1:
        # The positional argument itself, not the last option's value.
        forcefile = args[0]
    elif len(args) > 1:
        raise getopt.GetoptError('Extra filenames given','')
    else:
        forcefile = None
except getopt.GetoptError:
    sys.stderr.write(sys.argv[0] + ": " + str(sys.exc_info()[1]) + "\n")
    usage(sys.stderr)
    sys.exit(2)

# Select how an interrupt is handled while learning.  With --ignorectrlc the
# process ignores SIGINT altogether and the learner is given SIGNORE;
# otherwise the learner is given SRAISE.
if not ignorectrlc:
    raisekbd = hnfa.Learner.SRAISE
else:
    oldsiginthandler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    raisekbd = hnfa.Learner.SIGNORE

# Output file: either the name given on the command line, or the base name of
# the file prefix followed by every filename tag that was set, joined by '_'.
if forcefile is not None:
    filename = forcefile
else:
    nameparts = [os.path.basename(fileprefix)]
    nameparts.extend([tag for tag in filenameopts if tag is not None])
    filename = '_'.join(nameparts) + '.pickle.gz'

# --printfilename: show the result name without its pickle suffix and stop.
if printfilename:
    print(re.sub('\.pickle(\.gz)?$','',filename))
    sys.exit(0)
        
def load(filename):
    """Load a pickled object if ``filename`` or ``filename + '.gz'`` exists.

    Returns whatever PickleHelpers.load_compat(filename) returns when either
    file is present, and None when neither is.
    """
    # NOTE(review): the file is read through a pickle helper; presumably it
    # should only be pointed at trusted data files.
    present = os.path.isfile(filename) or os.path.isfile(filename+'.gz')
    if not present:
        return None
    return PickleHelpers.load_compat(filename)

# Look for "<fileprefix>-data.pickle" relative to the current directory
# first and then under DATADIR; the first location that yields data wins.
data = None
for candidate in (fileprefix + '-data.pickle',
                  os.path.join(datadir, fileprefix + '-data.pickle')):
    data = load(candidate)
    if data is not None:
        break

if data is None:
    sys.stderr.write(sys.argv[0] + ': Datafile not found\n')
    sys.exit(2)

# Fill in the defaults that depend on the loaded data.
# (assumes data is a 2-D array whose second axis indexes data points,
# matching the data[:,:datapoints] slice used later -- TODO confirm)
if numsources is None:
    if os.path.basename(fileprefix) == 'helix':
        numsources = 1
    else:
        numsources = min(8, data.shape[0])

if datapoints is None:
    datapoints = data.shape[1]

print(sys.argv)

def prune():
    """Call hnfanet.TryPruning() and record how many nodes are left.

    Builds a dict holding the TryPruning() result under 'Pruned' plus the
    number of variables matching each of four name patterns, keyed by the
    pattern turned into a plain label ('s(1)', 'A(2, 1)', 'A(1, 0)',
    'A(2, 0)').  The dict is added to the network history under
    "Nodes left" and printed.
    """
    pruned = hnfanet.TryPruning()
    # Result is not used afterwards; the call is kept as in the original.
    _numhid = len(hnfanet.net.GetVariables('s\(1, '))
    counts = {'Pruned': pruned}
    for pattern in ('s\(1, ','A\(2, 1, ','A\(1, 0, ','A\(2, 0, '):
        # 's\(1, ' -> 's(1, ' -> 's(1' -> 's(1)'
        label = pattern.replace('\\','')[:-2]+')'
        counts[label] = len(hnfanet.net.GetVariables(pattern))
        # NOTE(review): the history entry is added once per pattern with the
        # same, still-growing dict -- confirm this is not meant to run once
        # after the loop.
        hnfanet.HistoryAdd("Nodes left", counts)
    print(counts)

def prunehidden(cdiff=500, num=1, lazy=0):
    """Call hnfanet.TryPruneHidden() and record how many nodes are left.

    cdiff, num, lazy -- forwarded unchanged as keyword arguments to
                        hnfanet.TryPruneHidden

    Builds a dict holding the TryPruneHidden() result under 'HiddenPruned'
    plus the number of variables matching each of four name patterns, keyed
    by the pattern turned into a plain label ('s(1)', 'A(2, 1)', 'A(1, 0)',
    'A(2, 0)').  The dict is added to the network history under
    "Nodes left" and printed.
    """
    pruned = hnfanet.TryPruneHidden(cdiff=cdiff,num=num,lazy=lazy)
    # Result is not used afterwards; the call is kept as in the original.
    _numhid = len(hnfanet.net.GetVariables('s\(1, '))
    counts = {'HiddenPruned': pruned}
    for pattern in ('s\(1, ','A\(2, 1, ','A\(1, 0, ','A\(2, 0, '):
        # 'A\(2, 1, ' -> 'A(2, 1, ' -> 'A(2, 1' -> 'A(2, 1)'
        label = pattern.replace('\\','')[:-2]+')'
        counts[label] = len(hnfanet.net.GetVariables(pattern))
        # NOTE(review): the history entry is added once per pattern with the
        # same, still-growing dict -- confirm this is not meant to run once
        # after the loop.
        hnfanet.HistoryAdd("Nodes left", counts)
    print(counts)

def iter(iters):
    """Run hnfanet.LearnNet with the iteration argument negated.

    iters -- passed to LearnNet as ``iters=-iters`` (presumably a negative
             value asks for that many additional iterations -- verify
             against hnfa.HnfaNet.LearnNet)

    Note: this module-level name shadows the builtin ``iter``.
    """
    learn = hnfanet.LearnNet
    learn(iters=-iters, printcost=10,
          hooke=10, printhooke=1,
          raisekbd=raisekbd)

def addw_10_prune_10(n,prob,doprunehidden=0,cdiff=500, num=1, lazy=0):
    """Repeat "add weights, learn 10, prune, learn 10" ``n`` times.

    n             -- number of rounds
    prob          -- value passed to hnfanet.AddAllWeights; multiplied by the
                     module-level ``probfactor`` after every round
    doprunehidden -- 0: always use prune(); k > 0: use prunehidden() in the
                     rounds where ``round % k == k // 2`` and prune() in the
                     others; a negative value is replaced by ``n``
    cdiff, num, lazy -- forwarded to prunehidden()

    Returns the value of ``prob`` after the last round.
    """
    if doprunehidden < 0:
        doprunehidden = n
    for round_no in range(n):
        hnfanet.AddAllWeights(prob)
        hnfanet.HistoryAdd('AddAllWeights', prob)
        prob *= probfactor
        iter(10)
        # The leading test keeps the modulo from dividing by zero; integer
        # division is spelled out (same result as '/' on Python 2 ints).
        hidden_round = (doprunehidden and
                        round_no % doprunehidden == doprunehidden // 2)
        if not hidden_round:
            prune()
        else:
            prunehidden(cdiff=cdiff,num=num,lazy=lazy)
        iter(10)
    return prob

# Main run: build the network, follow the learning schedule below and save
# the result to ``filename``.
try:
    # Only the first ``datapoints`` columns of the data are used.
    hnfanet=hnfa.HnfaNet(data[:,:datapoints], numsources, seed,
                         sourcetype=sourcetype)

    hnfanet.HistoryAdd("params", sys.argv)
    hnfanet.HistoryAdd("times", os.times())
    hnfanet.LearnNet(printcost=10, iters=100, raisekbd=raisekbd)

    # ``prob`` is handed to AddAllWeights and is multiplied by ``probfactor``
    # after every round of addw_10_prune_10.
    prob = .2
    finalprob = .01
    # Iteration budget left after subtracting the iterations done so far, the
    # final ``onlyhj`` iterations, 40 per pass and a further 20.
    withaddweights=iterations-hnfanet.learner.iter-onlyhj-numpasses*40-20
    if withaddweights > 0:
        # NOTE(review): the exponent is log(finalprob/prob) divided by
        # withaddweights and then by 20, i.e. by withaddweights*20 -- confirm
        # that this, and not division by (withaddweights/20), is intended.
        probfactor = math.exp((math.log(finalprob)-math.log(prob))/
                              (withaddweights)/20)
    else:
        # No budget left: prob becomes 0 after the first round.
        probfactor = 0
    if numpasses > 0:
        # First pass: add ``randhidden`` hidden nodes, learn, prune, then run
        # 7 add-weights/prune rounds.
        hnfanet.AddHidden(num=randhidden, vsdecay = 500)
        iter(50)
        prune()
        iter(10)
        prob=addw_10_prune_10(7,prob)

    # Remaining passes (numpasses - 1 of them): add hidden nodes with
    # AddHiddenBest, then run 3 add-weights/prune rounds each.
    for i in range(1, numpasses):
        hnfanet.AddHiddenBest(num=hiddenaccept, numtest=hiddentest,
                              vsdecay = 500)
        iter(30)
        prune()
        iter(10)
        # -doprunehidden is 0 or -1; a negative value makes the helper use
        # its round count as the hidden-pruning period.
        prob=addw_10_prune_10(3,prob,-doprunehidden,lazy=10,num=3)

    prune()
    iter(10)
    # Round count derived from the iterations left before the final
    # ``onlyhj`` stretch, at 20 per round (integer division under Python 2;
    # a negative count yields no rounds).
    prob=addw_10_prune_10((iterations-onlyhj-hnfanet.learner.iter)/20,
                          prob,doprunehidden*5,cdiff=1000)

    # Final stretch, called with a positive ``iters`` (the iter() helper
    # passes negative values) -- presumably this runs up to an absolute
    # iteration count; verify against LearnNet.
    hnfanet.LearnNet(printcost=10, iters=iterations,
                     printhooke=1,hooke=10, raisekbd=raisekbd)

    hnfanet.HistoryAdd("times", os.times())
    hnfanet.SaveWithPickle(filename)

except KeyboardInterrupt:
    # An interrupt ends the run quietly; SaveWithPickle above is skipped, so
    # nothing is written in that case.
    pass
