# -*- coding: iso-8859-1 -*-

#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#

import bblocks.PickleHelpers as PickleHelpers
import bblocks.Helpers as Helpers
try:
    import numpy.oldnumeric as Numeric
    from numpy.linalg import inv
except:
    import Numeric
    from LinearAlgebra import inverse as inv
import os
import sys

# Fallback directory that loadsig() searches when a file is not found
# as given; it can be overridden through the DATADIR environment variable.
datadir = os.getenv('DATADIR', 'data')

def load(filename):
    """Return the unpickled contents of ``filename``, or None if it is absent.

    The file counts as present if either ``filename`` itself or a
    compressed companion ``filename + '.gz'`` exists; the actual reading is
    delegated to ``PickleHelpers.load_compat`` (always with the plain name).
    """
    present = os.path.isfile(filename) or os.path.isfile(filename + '.gz')
    if not present:
        return None
    return PickleHelpers.load_compat(filename)

def loadsig(datafile):
    """Load a pickled signal array and return it transposed.

    ``datafile`` is first tried as given (``load`` also accepts a
    ``.gz`` companion); if that yields nothing it is retried relative to
    the module-level ``datadir``.

    Returns ``Numeric.transpose`` of the loaded object.

    Raises:
        RuntimeError: if the file is found in neither location.
    """
    sig = load(datafile)
    # Use identity, not ``==``: when ``sig`` is an array, ``sig == None``
    # compares elementwise and the ``if`` becomes ambiguous or wrong.
    if sig is None:
        sig = load(os.path.join(datadir, datafile))
    if sig is None:
        # Call form of raise works on both Python 2 and Python 3.
        raise RuntimeError('File not found ' + datafile)
    return Numeric.transpose(sig)
    
# Short module-level aliases for the Numeric routines used heavily by
# corrs(): matrix product, transpose and sum.
m, T, su = Numeric.matrixmultiply, Numeric.transpose, Numeric.sum
def _db(ratio):
    """Convert a fraction r (scalar or array) to 10*log10(r / (1 - r)).

    Only meaningful for r in (0, 1); r == 1 is a division by zero and
    r <= 0 takes the log of a non-positive number.
    """
    return Numeric.log(ratio / (1 - ratio)) / Numeric.log(10) * 10


def corrs(net, sig, nodes=None):
    """Compare estimated source nodes against reference signals.

    Parameters:
      net   -- model wrapper; only ``net.net.GetVariables('s.2')`` is used,
               and only when ``nodes`` is None or empty.
      sig   -- reference signals indexed [sample, signal]; its row count
               must equal the number of samples in the node means.
               NOTE(review): ``c`` below is a true correlation coefficient
               only if every column of ``sig`` is zero-mean with unit mean
               square -- confirm that callers guarantee this.
      nodes -- nodes whose mean values (via ``Helpers.GetMeanV``) serve as
               the estimates; defaults to the 's.2' variables.

    Returns the tuple (c, best, snr, snr2, snr3, snr4):
      c     -- (n_estimates, n_signals) matrix  s2 . sig / n_samples, where
               s2 holds the estimates normalized per row
      best  -- list with, for each signal, the largest squared entry of its
               column of ``c``
      snr   -- dB value of the mean of ``best``
      snr2  -- per-signal dB values of ``best``
      snr3  -- per-signal dB values of sum(psig * sig) / n_samples, where
               psig is the least-squares projection of ``sig`` onto the row
               space of s2
      snr4  -- dB value of the mean of those projection fractions
    "dB value" of r means 10*log10(r / (1 - r)).

    Raises:
      ValueError: if ``sig`` and the node means disagree on the number of
      samples.
    """
    if not nodes:
        # NOTE: an empty sequence also falls back to the default node set.
        nodes = net.net.GetVariables('s.2')
    s2 = Numeric.array(Helpers.GetMeanV(nodes), Numeric.Float)
    (a, b) = s2.shape        # a estimates, each with b samples
    (b2, a2) = sig.shape     # b2 samples of a2 reference signals
    if b2 != b:
        raise ValueError('sig has %d samples but the node means have %d'
                         % (b2, b))
    # Normalize every estimate to zero mean and unit mean square.  A
    # constant row has zero spread and is divided by zero here.
    for i in range(a):
        s2[i, :] -= su(s2[i, :]) / b
        s2[i, :] /= Numeric.sqrt(su(s2[i, :] * s2[i, :]) / b)
    c = m(s2, sig) / b
    # For each reference signal keep its best squared match over estimates.
    best = [max(c[:, i] * c[:, i]) for i in range(a2)]
    snr = _db(su(best) / a2)
    snr2 = _db(Numeric.array(best))
    # Least-squares projection of each reference signal onto the row space
    # of the normalized estimates: s2^T (s2 s2^T)^-1 s2 sig.
    psig = m(T(s2), m(inv(m(s2, T(s2))), m(s2, sig)))
    frac = su(psig * sig) / b
    snr3 = _db(frac)
    snr4 = _db(su(frac) / a2)
    return c, best, snr, snr2, snr3, snr4

def corrs1(net, sig):
    """Run corrs() on the variables from ``GetVariables('s.1')``.

    This replaces corrs()'s default 's.2' node set.
    """
    selected = net.net.GetVariables('s.1')
    return corrs(net, sig=sig, nodes=selected)

def corrs12(net, sig):
    """Run corrs() on the 's.1' variables followed by the 's.2' variables."""
    first = net.net.GetVariables('s.1')
    second = net.net.GetVariables('s.2')
    return corrs(net, sig=sig, nodes=first + second)
