## Demo pipeline script for running sparse group factor analysis (GFA) on the TG-GATEs data set
# Author: Tommi Suvitaival, tommi.suvitaival@aalto.fi
# Based on group factor analysis code from Seppo Virtanen, seppo.j.virtanen@aalto.fi

## License

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

## Set working directory.

# NOTE: this is a placeholder. setwd() has no default for its directory argument,
# so the call fails until the path to the package is filled in. The relative
# paths used further down ("source" and "data") are resolved against it.
setwd() # Set the path to the package here.

## Load packages.

library(abind)
library(compiler)
library(GOstats)
library(Matrix)
library(mutoss)

## Load source code.

source(file.path("source", "sGFA.R"))

##
## DATA
##

## Load the collapsed TGP data.

load(file.path("data", "data_camda_collapsed-intersecting_drug-dose-time_as_samples-intersecting_genes_with_nonzero_variance_as_variables-categorical_findings-130427-covariates_included.RData"))

# Four data views: pathological findings, human in vitro gene expression (GE),
# rat in vitro GE, rat in vivo GE.
Y <- data.camda.collapsed
Y$category.samples <- NULL # Remove this element before the remaining list elements are concatenated.

# Column ranges of the views in the concatenated data matrix:
# view m occupies columns Ds[m]:De[m].
D <- vapply(Y, ncol, integer(1))
De <- cumsum(D)
Ds <- c(1, De + 1)
gr <- lapply(seq_along(De), function(m) Ds[m]:De[m])

# Concatenate data list Y, where each element contains Y[[m]] N \times D_m data matrix,
# into N \times \sum_m D_m matrix.
Y <- abind(Y, along = 2)
Y <- scale(Y)

##
## MODEL
##

## Model settings

opts <- getDefaultOptsGFA()

# The view structure has to be stored after the default options are created:
# getDefaultOptsGFA() replaces 'opts' as a whole, so anything assigned to it
# earlier would be lost (and 'opts' would not even exist yet).
opts$groups <- gr
rm(gr)

K <- 100 # number of factors

opts$iter.max <- 1000 # Total number of iterations
opts$iter.burnin <- 500 # Number of burn-in iterations out of all iterations.
opts$iter.saved.max <- 0 # Maximum number of iterations after burn-in that are saved.

# Sparsity parameters -- values from 1 to 1e-7 are ok.

opts$prior.alphaZ <- 1e-7
opts$prior.alpha_0 <- 1e-7
opts$prior.beta_0 <- 1e-7
# The W and X priors are parameterized separately so that they can be tuned independently.
a <- 1e-2
b <- 0.5
opts$prior.betaW1 <- a * b
opts$prior.betaW2 <- a * (1 - b)
a <- 1e-2
b <- 0.5
opts$prior.betaX1 <- b * a
opts$prior.betaX2 <- (1 - b) * a

## Inference

model <- sGFA(Y, K, opts)
# Keep the view structure with the model; the analysis below reads model$groups.
model$groups <- opts$groups


##
## ANALYSIS
##

## Analysis settings

# Level at which the false discovery rate is controlled with the
# Benjamini-Hochberg correction.
opts$fdr.level <- 0.05
# Number of top-ranking genes used for the GO enrichment analysis in each
# factor; 0 means all genes active in the factor.
opts$n.top.genes.GO <- 0
# Cut-off level for the non-corrected p-values of the GO terms; 1 means all
# terms are included.
opts$p.cutoff.GO <- 1
# Maximum number of enriched GO terms printed per factor and gene expression view.
opts$max.GO.terms.print <- 10

## Load the gene annotations

# HUGO and ENTREZ are expected to be provided by this file.
load(file.path("data", "MappingFromEnsemblTo.RData"))
# Join the two annotation tables on their common Ensembl gene identifier column.
HUGOENTREZ <- merge(HUGO, ENTREZ, by = "ensembl_gene_id")
# Look up the Entrez identifier for the row name of every variable in view 2;
# names without a match in the annotation table yield NA.
annotation.rows <- match(rownames(model$W)[model$groups[[2]]], HUGOENTREZ$external_gene_id)
gene.names.Entrez <- HUGOENTREZ$entrezgene[annotation.rows]
rm(annotation.rows)

## Find factors that are active in all the views.

# A factor is active in a view when at least one of the view's loadings in W is
# non-zero. The check runs over every view in model$groups instead of a fixed
# set of four, and drop = FALSE keeps single-variable views as matrices so that
# apply() still works on them.
factor.active.in.view <- lapply(model$groups, function(view.rows) {
	apply(X = model$W[view.rows, , drop = FALSE] != 0, MARGIN = 2, FUN = any)
})
# Indices of the factors that are active in all the views.
factors.pathology.association.shared.GE <- which(Reduce(`&`, factor.active.in.view))
rm(factor.active.in.view)
print(paste(length(factors.pathology.association.shared.GE), "factors shared across all the views."))

## Construct ranked list of the genes active in the shared factors.
# A list with dimensions GE views by shared factors

# One element per view; element 1 stays NULL because only views 2 and up are
# processed here. Each processed view holds one gene vector per shared factor.
ranked.genes.shared.factors <- vector(mode = "list", length = length(model$groups))
for (mi in seq_along(model$groups)[-1]) { # Go through all gene expression views.
	ranked.genes.shared.factors[[mi]] <- lapply(factors.pathology.association.shared.GE, function(factor.index) {
		loadings <- model$W[model$groups[[mi]], factor.index]
		# Rank the variables of the view by decreasing absolute loading.
		ranking <- order(abs(loadings), na.last = NA, decreasing = TRUE)
		# Keep only the variables with a non-zero loading, i.e. the active genes.
		# NOTE(review): gene.names.Entrez is derived from view 2 only; using it
		# for every GE view assumes that all GE views list the same genes in the
		# same order -- confirm against the data.
		gene.names.Entrez[ranking[which(loadings[ranking] != 0)]]
	})
	# Label the entries by factor index. sprintf() returns an empty vector when
	# there are no shared factors, so the assignment stays valid in that case.
	names(ranked.genes.shared.factors[[mi]]) <- sprintf("Factor %d", factors.pathology.association.shared.GE)
}

## GO enrichment test: for each shared factor and GE view

# One element per view; element 1 stays NULL because only views 2 and up are
# tested. Each tested view holds one hyperGTest() summary per shared factor.
result.GO.factors <- vector(mode = "list", length = length(model$groups))
names(result.GO.factors) <- c(NA, "Human in vitro", "Rat in vitro", "Rat in vivo") # names of the GE views
print("Computing gene set enrichments (gene expression view : factor)")
for (mi in seq_along(model$groups)[-1]) { # Go through all GE views.
	result.GO.factors[[mi]] <- lapply(seq_along(factors.pathology.association.shared.GE), function(ci) {
		cat(mi, ":", ci, "... ")
		# Drop genes without an Entrez identifier BEFORE truncating to the top
		# genes, so that the gene set never gets padded with NA.
		genes.mi.ci <- ranked.genes.shared.factors[[mi]][[ci]]
		genes.mi.ci <- genes.mi.ci[!is.na(genes.mi.ci)]
		if (opts$n.top.genes.GO > 0) {
			# head() returns all genes when fewer than the requested number are available.
			genes.mi.ci <- head(genes.mi.ci, opts$n.top.genes.GO)
		}
		params.GO.mi.ci <- new(
			"GOHyperGParams",
			geneIds = genes.mi.ci,
			universeGeneIds = gene.names.Entrez[!is.na(gene.names.Entrez)],
			pvalueCutoff = opts$p.cutoff.GO,
			annotation = "hgu133a",
			ontology = "BP",
			conditional = TRUE,
			testDirection = "over"
		)
		summary(hyperGTest(params.GO.mi.ci))
	})
	# sprintf() returns an empty vector when there are no shared factors, which
	# keeps this assignment valid for an empty result list.
	names(result.GO.factors[[mi]]) <- sprintf("Factor %d", factors.pathology.association.shared.GE)
}

## FDR correction for the GO enrichment test

# Gather the uncorrected p-values of all tests into one vector, ordered by view
# and then by factor. The same order is used below to hand the corrected values
# back, so no per-test bookkeeping of view and factor indices is needed.
p.values.GO.vec <- list()
p.values.GO.vec$p.original <- as.numeric(unlist(
	lapply(seq_along(model$groups)[-1], function(mi) {
		lapply(seq_along(factors.pathology.association.shared.GE), function(ci) {
			result.GO.factors[[mi]][[ci]]$Pvalue
		})
	}),
	use.names = FALSE
))
# Compute Benjamini-Hochberg correction for the p-values over all the tests.
if (length(p.values.GO.vec$p.original) > 0) {
	p.values.GO.vec$p.corrected <- BH(pValues = p.values.GO.vec$p.original, alpha = opts$fdr.level, silent = TRUE)$adjPValues
} else {
	p.values.GO.vec$p.corrected <- numeric(0)
}
# Walk through the tests in the same order and assign each one its slice of the
# corrected p-values.
p.values.GO.vec$n.assigned <- 0
for (mi in seq_along(model$groups)[-1]) { # Go through all GE views.
	for (ci in seq_along(factors.pathology.association.shared.GE)) { # Go through all shared factors.
		p.values.GO.vec$n.terms <- length(result.GO.factors[[mi]][[ci]]$Pvalue)
		result.GO.factors[[mi]][[ci]]$Pvalue.corrected <- p.values.GO.vec$p.corrected[p.values.GO.vec$n.assigned + seq_len(p.values.GO.vec$n.terms)]
		p.values.GO.vec$n.assigned <- p.values.GO.vec$n.assigned + p.values.GO.vec$n.terms
	}
}
rm(p.values.GO.vec)

## Pick GO terms that are significantly enriched with the threshold.

# Same layout as result.GO.factors: one element per view, one entry per shared
# factor. An entry stays NULL when the factor has no significant GO term in the
# view; otherwise it is a list of the result columns restricted to the
# significant terms.
result.GO.factors.corrected.significant <- vector(mode = "list", length = length(model$groups))
names(result.GO.factors.corrected.significant) <- names(result.GO.factors)
for (mi in seq_along(model$groups)[-1]) { # Go through all GE views.
	result.GO.factors.corrected.significant[[mi]] <- vector(mode = "list", length = length(factors.pathology.association.shared.GE))
	for (ci in seq_along(factors.pathology.association.shared.GE)) { # Go through all shared factors.
		# Positions of the terms whose corrected p-value is below the FDR level.
		tmp <- which(result.GO.factors[[mi]][[ci]]$Pvalue.corrected < opts$fdr.level)
		if (length(tmp) > 0) {
			result.GO.factors.corrected.significant[[mi]][[ci]] <- lapply(
				X = result.GO.factors[[mi]][[ci]],
				FUN = function(column) column[tmp]
			)
		}
	}
}

## Find pathology variables that are active in the shared factors.

# Absolute loadings of the view-1 (pathology) variables in the shared factors,
# scaled by the largest absolute loading. drop = FALSE keeps the result a matrix
# even when there is only one shared factor or one pathology variable, so the
# column indexing below keeps working.
W.shared.pathology.abs.normalized <- abs(model$W[model$groups[[1]], factors.pathology.association.shared.GE, drop = FALSE])
W.shared.pathology.abs.normalized <- W.shared.pathology.abs.normalized / max(W.shared.pathology.abs.normalized)

## Print the factor-wise pathological findings and associated enriched GO terms in each gene expression view.

for (ci in seq_along(factors.pathology.association.shared.GE)) { # Go through all shared factors.
	print(paste0("SHARED FACTOR ", ci, "/", length(factors.pathology.association.shared.GE)))
	print("Pathological findings:")
	print(rownames(W.shared.pathology.abs.normalized)[which(W.shared.pathology.abs.normalized[, ci] > 0)])
	for (mi in seq_along(model$groups)[-1]) { # Go through all GE views.
		view.name <- names(result.GO.factors.corrected.significant)[[mi]]
		n.genes <- length(ranked.genes.shared.factors[[mi]][[ci]])
		significant.terms <- result.GO.factors.corrected.significant[[mi]][[ci]]$Term
		if (length(significant.terms) == 0) {
			print(paste0(view.name, " - ", n.genes, " genes, no significantly enriched GO terms"))
			next
		}
		print(paste0(view.name, " - ", n.genes, " genes, ", length(significant.terms), " significantly enriched GO terms:"))
		if (length(significant.terms) > opts$max.GO.terms.print) {
			# Print only the terms with the smallest corrected p-values.
			tmp <- order(result.GO.factors.corrected.significant[[mi]][[ci]]$Pvalue.corrected, na.last = NA)[seq_len(opts$max.GO.terms.print)]
			print(significant.terms[tmp])
			print(paste("and", length(significant.terms) - opts$max.GO.terms.print, "more"))
		} else {
			print(significant.terms)
		}
	}
	print("")
}
print(paste0("(False discovery rate of GO term enrichment controlled at the level ", opts$fdr.level, ")"))
