"""
================================================================================
Pipeline for the case where the training and test sets are separate.
If you are only interested in cross-validation on the training set, use the
pipeline in shen_ISMB2014.py.
================================================================================
"""

import sys
import numpy
import multiprocessing
import warnings; warnings.filterwarnings('ignore')
sys.path.append("../../../fingerid") # path to fingerid package         

from fingerid.preprocess.massbankparser import MassBankParser
from fingerid.kernel.twodgaussiankernel import TwoDGaussianKernel
from fingerid.kernel.mskernel import Kernel
from fingerid.model.internalCV_mp import internalCV_mp
from fingerid.model.trainSVM import trainModels
from fingerid.model.predSVM import predModels
from fingerid.preprocess.util import writeIDs
from fingerid.preprocess.util import centerTestKernel

def main():
    """Run the training stage of the pipeline.

    All paths are relative to the current working directory:

    1. Parse the MassBank spectra in ``LC-ESI-QTOF_train_data``.
    2. Load the label matrix ``util/train_output.txt`` and keep only the
       columns whose indices are listed in ``util/used_fp_index``.
    3. Write the IDs of the parsed spectra to ``util/train_spectras.txt``.
    4. Compute the 2D Gaussian training kernel and save it with
       ``numpy.save`` as ``util/train_km`` (numpy appends ``.npy``).
    5. Train models on the kernel and labels with 4 processes, storing them
       in the ``MODELS`` folder.
    6. Predict on the training kernel with the stored models and save the
       per-column training accuracy to ``util/train_acc.txt``.
    """
    # Training data folder and the folder where trained models are stored.
    train_dir = "LC-ESI-QTOF_train_data"
    models_dir = "MODELS"

    # ---- parse data ----
    # Single-argument parenthesized print behaves the same in Python 2 and 3.
    print("parse data\n")
    mbparser = MassBankParser()
    train_ms = mbparser.parse_dir(train_dir)

    # ndmin=2 keeps labels 2-D even if the file has a single row or a single
    # column; ndmin=1 keeps the index list 1-D even if it holds one index.
    # Without these, a one-element file collapses the arrays and the
    # ``labels.shape`` unpacking below fails.
    labels = numpy.loadtxt("util/train_output.txt", ndmin=2)
    inds = numpy.loadtxt("util/used_fp_index", dtype=int, ndmin=1)
    labels = labels[:, inds]
    n_fp = labels.shape[1]

    # Output the IDs of the spectra and fragmentation trees.
    writeIDs("util/train_spectras.txt", train_ms)

    # ---- compute train kernel ----
    # The original message contained an unfilled "%s" placeholder.
    print("computing train kernel")
    # Kernel parameters selected by cross validation. Presumably these are
    # the mass and intensity widths of the 2D Gaussian; confirm against
    # TwoDGaussianKernel.
    sm = 0.00001
    si = 100000

    kernel = TwoDGaussianKernel(sm, si)
    train_km = kernel.compute_train_kernel(train_ms)
    numpy.save("util/train_km", train_km)

    # ---- train with 4 processes ----
    print("train models")
    trainModels(train_km, labels, models_dir, select_c=False, n_p=4)

    # ---- compute training accuracy ----
    # Fraction of training rows where the prediction equals the label,
    # computed per column. Assumes predModels returns an array shaped like
    # ``labels`` (n_train, n_fp); the original code made the same assumption.
    preds = predModels(train_km, n_fp, models_dir)
    acc = numpy.mean(preds == labels, axis=0)
    numpy.savetxt("util/train_acc.txt", acc)


if __name__ == "__main__":
    main()

    






