"""
================================================================================
Pipeline when train and test are separated.
If only interested in cross validation on train set, use the pipeline of 
shen_ISMB2014.py.
================================================================================
"""

import sys
import commands
import numpy
import multiprocessing
import warnings; warnings.filterwarnings('ignore')
sys.path.append("../../../fingerid") # path to fingerid package         

from fingerid.preprocess.massbankparser import MassBankParser
from fingerid.kernel.twodgaussiankernel import TwoDGaussianKernel
from fingerid.kernel.mskernel import Kernel
from fingerid.model.internalCV_mp import internalCV_mp
from fingerid.model.trainSVM import trainModels
from fingerid.model.predSVM import predModels
from fingerid.preprocess.util import writeIDs
from fingerid.preprocess.util import centerTestKernel
from search import search

if __name__ == "__main__":
    # Pipeline for a separate train/test split (no cross validation):
    # parse the spectra, compute the test-vs-train kernel, predict the
    # fingerprints with the stored models, search KEGG with the predictions
    # and write one ranked result file per test spectrum.
    #
    # The single-argument print(...) calls below print exactly what the
    # Python 2 print statement would, and also parse under Python 3.
    import os  # local import: only needed for the result-writing step

    # parse data
    train_dir = "APCI-ITFT-HCD_train_data"
    test_dir = "test_data"
    result_dir = "result"

    print("parse data\n")
    mbparser = MassBankParser()
    train_ms = mbparser.parse_dir(train_dir)
    test_ms = mbparser.parse_dir(test_dir)

    # Keep only the fingerprint columns listed in util/used_fp_index; the
    # number of remaining columns is the number of models to apply.
    labels = numpy.loadtxt("util/train_output.txt")
    inds = numpy.loadtxt("util/used_fp_index", dtype=int)
    labels = labels[:, inds]
    n_fp = labels.shape[1]

    # compute train and test kernels
    print("computing test kernels")
    # presumably the mass / intensity widths of the 2-D Gaussian kernel --
    # confirm against TwoDGaussianKernel
    sm = 0.00001
    si = 100000
    kernel = TwoDGaussianKernel(sm, si)
    # NOTE(review): train_km is loaded but never used below, and
    # centerTestKernel is imported but never called. Confirm whether the
    # test kernel has to be centered against the train kernel before
    # prediction, or drop this load.
    train_km = numpy.load("util/train_km.npy")
    test_km = kernel.compute_test_kernel(test_ms, train_ms)

    # make prediction
    prob = False
    preds = predModels(test_km, n_fp, "MODELS", prob=prob)
    numpy.savetxt("fp_predictions.txt", preds, fmt="%d")

    # search kegg
    ppm = 10  # set ppm for mass filtering
    res = search("util/kegg_mass", "util/kegg_fp.dict", test_ms, preds, ppm)

    # write results in folder
    # Create the output directory directly instead of shelling out to
    # `mkdir`, so real failures (e.g. permissions) are not silently ignored.
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)

    # Number of leading characters to drop from each spectrum file name:
    # everything up to and including "<test_dir>/".
    prefix_len = len(test_dir) + 1
    for i, spectrum in enumerate(test_ms):
        stem = spectrum.f_name
        rel_name = stem[stem.find(test_dir) + prefix_len:]
        fname = result_dir + "/" + rel_name + ".res"
        # res[i] is written as one "<first item>\t<second item, 4 decimals>"
        # line per entry, with no trailing newline.
        lines = ["%s\t%.4f" % (tup[0], tup[1]) for tup in res[i]]
        # The context manager closes the file even if the write fails.
        with open(fname, "w") as w:
            w.write("\n".join(lines))
        
