#
# Tools for experiments with Tensor GFA
#
# For questions and bug reports please contact
# suleiman.khan@aalto.fi
#
# Copyright 2013 Suleiman Ali Khan. All rights reserved.
# The software is licensed under the FreeBSD license; see LICENSE
# included in the package.

library(tensor)
source("tools_plot.R")
source("tools_data.R")
source("getRandomInits.R")
################################################################################
########################### Computation Functions - Tensor/GibbsGFA ############
################################################################################

#
# Compute Mean Square Error of Projection matrices
# Input: W is the list of the original (ground-truth) W matrices. This MUST be the original one: the distance metric is not symmetric.
# Input: Wlearned is the list of W matrices estimated by a GFA variant
#
computeMSE.W <- function(W,Wlearned)
{
	# Mean squared error between reference and estimated projection matrices.
	#
	# W        : list of reference matrices (one per view); they are stacked row-wise.
	# Wlearned : list of estimated matrices, stacked the same way (only the first
	#            length(W) entries are used).
	# Returns  : mean over the active reference components of the MSE to the best
	#            matching (greedy, sign-invariant) estimated component. A reference
	#            component with no estimated counterpart left contributes mean(w^2).
	#            Returns 0 when the reference has no active component.
	#
	# Columns are first filtered (inactive = all-zero for a single row, zero sd
	# otherwise) and rescaled so that max(|column|) == 1.

	M <- length(W)
	Worig <- do.call(rbind, unname(W))
	Westim <- do.call(rbind, unname(Wlearned[seq_len(M)]))
	stopifnot(nrow(Worig) == nrow(Westim))

	# Keep only active columns; drop=FALSE keeps the matrix shape even when a
	# single column survives (the vector form used to break apply()).
	dropInactive <- function(mat)
	{
		if(nrow(mat) == 1)
			active <- as.vector(mat != 0)
		else
			active <- apply(mat,2,sd) != 0
		mat[,active,drop=FALSE]
	}

	# Scale every column to a maximum absolute value of 1.
	scaleColumns <- function(mat)
	{
		if(ncol(mat) == 0)
			return(mat)
		sweep(mat,2,apply(abs(mat),2,max),"/")
	}

	Worig <- scaleColumns(dropInactive(Worig))
	Westim <- scaleColumns(dropInactive(Westim))

	# Pad with Inf columns so that Westim always has one more column than Worig;
	# an Inf column can never be the best match unless nothing else is left.
	nPad <- ncol(Worig) + 1 - ncol(Westim)
	if(nPad > 0)
		Westim <- cbind(Westim, matrix(Inf,nrow(Worig),nPad))

	## Find optimal match by greedy reordering and sign
	mse <- 0
	for(k in seq_len(ncol(Worig)))
	{
		target <- Worig[,k]
		errPos <- colMeans((target - Westim)^2)
		errNeg <- colMeans((target + Westim)^2)
		if(min(errPos) < min(errNeg))
		{
			rc <- which.min(errPos)
			v <- min(errPos)
		}
		else
		{
			rc <- which.min(errNeg)
			v <- min(errNeg)
		}

		if(v != Inf)
		{
			mse[k] <- v
			# each estimated component may be matched only once
			Westim <- Westim[,-rc,drop=FALSE]
		}
		else
			mse[k] <- mean(Worig[,k]^2)
	}
	return(mean(mse))
}#EndFunction

#
# Reshaped Khatri Rao Product
#
khatriRao.Reshaped <- function(Z,U)
{
	# Khatri-Rao product of Z (N x K) and U (L x K), returned as an N x K x L
	# array whose l-th slice is Z with column k multiplied by U[l,k].
	# Prints a message and returns NULL when the column counts differ.
	if(ncol(U) != ncol(Z))
	{
		print("Columns not same. Khatri Rao Product not possible")
		return()
	}

	nSlices <- nrow(U)
	out <- array(0, dim=c(nrow(Z),ncol(U),nSlices))
	for(s in 1:nSlices)
		out[,,s] <- sweep(Z,2,U[s,],"*")
	out
}

#
# Compute Rank/Active Factors/Cardinality of the model
#
computeRank <- function(model)
{
	# Placeholder: not implemented yet, always returns NULL.
	NULL
}#EndFunction


#
# Compute Log Likelihood
#
computeLL <- function(TestData)
{
	# Placeholder: not implemented yet, always returns NULL.
	NULL
}#EndFunction

################################################################################
########################## Matrix Normalization ################################
#
# Normalize the Data with Mean=0 
#
norm.mean <- function(mat)
{
  # Centre every column of `mat` by subtracting its column mean.
  colMeansVec <- apply(mat,2,"mean")
  sweep(mat,2,colMeansVec,"-")
}#EndFunction

#
# Normalize the Data with Mean=0 and Column-wise Variance = 1
#
norm.ztransform <- function(mat)
{
  # Column-wise z-transform: subtract the column mean, then divide by the
  # column standard deviation.
  centred <- sweep(mat,2,apply(mat,2,"mean"),"-")
  sweep(centred,2,apply(centred,2,"sd"),"/")
}#EndFunction

#
# Normalize the Data with Mean=0 and Total Variance = 1
#
norm.total.var.mat <- function(mat)
{
  # Centre the columns, then divide the whole matrix by the square root of the
  # mean squared entry, so that sum(result^2)/(nrow*ncol) equals 1.
  nCells <- nrow(mat)*ncol(mat)
  centred <- sweep(mat,2,apply(mat,2,"mean"),"-")
  # divide by the SD (not the variance): scaling is linear in the SD
  totalSd <- sqrt(sum(centred^2)/nCells)
  centred/totalSd
}#EndFunction

################################################################################
########################## Tensor Normalization ################################
#
# None
#
norm.none <- function(ten)
{
	# Identity "normalization": data is returned untouched with an empty
	# preprocessing record.
	noPre <- list()
	list(data=ten, pre=noPre)
}

#
# Scale Total RMS
#
norm.total.rss <- function(ten)
{
	# Divide the tensor by the scalar returned by get.total.rss.tensor();
	# the scalar is kept in `pre`.
	scaleFactor <- get.total.rss.tensor(ten)
	list(data = ten/scaleFactor, pre = scaleFactor)
}
norm.total.rms <- function(ten)
{
	# Divide the tensor by its total root-mean-square value; the divisor is
	# kept in `pre` so that it can be undone.
	scaleFactor <- get.total.rms.tensor(ten)
	list(data = ten/scaleFactor, pre = scaleFactor)
}
undo.norm.total.rms <- function(CP,sten)
{
	# Undo total-RMS scaling by multiplying the W factor by the stored divisor.
	rescaledW <- CP$W * sten
	CP$W <- rescaledW
	CP
}
norm.total.var <- function(ten)
{
	# Divide the tensor by the square root of get.total.var.tensor(ten);
	# the divisor is kept in `pre`.
	totalSd <- sqrt(get.total.var.tensor(ten))
	list(data = ten/totalSd, pre = totalSd)
}
undo.norm.total.var <- function(CP,sten)
{
	# Undo total-variance scaling by multiplying the W factor by the stored divisor.
	rescaledW <- CP$W * sten
	CP$W <- rescaledW
	CP
}
#
# Scale Slab
#
norm.slabscaling <- function(ten,o=2)
{
	# Slab scaling of a 3-way array: every slab along mode `o` is divided by the
	# square root of its variance (as returned by get.slab.var).
	# Returns list(data = scaled array, pre = per-slab divisors).
	# For `o` outside 1..3 no branch fires and `data` stays all zero.
	scaled <- array(0,dim=dim(ten),dimnames=dimnames(ten))
	slabSd <- sqrt(get.slab.var(ten,o)) # alternative: get.slab.rms(ten,o)

	if(o == 1)
	{
		for(idx in 1:dim(ten)[1])
			scaled[idx,,] <- ten[idx,,]/slabSd[idx]
	}
	else if(o == 2)
	{
		for(idx in 1:dim(ten)[2])
			scaled[,idx,] <- ten[,idx,]/slabSd[idx]
	}
	else if(o == 3)
	{
		for(idx in 1:dim(ten)[3])
			scaled[,,idx] <- ten[,,idx]/slabSd[idx]
	}

	list(data = scaled, pre = slabSd)
}
undo.norm.slabscaling <- function(CP,sten,o)
{
	# Undo slab scaling: the rows of the factor belonging to mode `o`
	# (1 -> X, 2 -> W, 3 -> U) are multiplied by the stored divisors `sten`.
	rowRescale <- function(fac)
		fac * matrix(rep(sten,ncol(fac)),nrow(fac),ncol(fac),byrow=FALSE)

	if(o == 1)
		CP$X <- rowRescale(CP$X)
	else if(o == 2)
		CP$W <- rowRescale(CP$W)
	else if(o == 3)
		CP$U <- rowRescale(CP$U)

	CP
}
mean.fun <- function(x)
{
	# Mean that ignores missing values.
	return(mean(x = x, na.rm = TRUE))
}

#
# Center Fibre
#
norm.fibrecentering <- function(ten,o)
{
	# Fibre centring of a 3-way array along mode `o`: the NA-ignoring mean over
	# mode `o` is computed for every fibre and subtracted from it.
	# Returns list(data = centred array, pre = matrix of fibre means).
	# For `o` outside 1..3 nothing is done: `data` is all zero and `pre` is 0.
	centred <- array(0,dim=dim(ten),dimnames=dimnames(ten))
	fibreMeans <- 0

	if(o == 1)
	{
		fibreMeans <- apply(ten,c(2,3),mean.fun)
		for(idx in 1:dim(ten)[1])
			centred[idx,,] <- ten[idx,,] - fibreMeans
	}
	else if(o == 2)
	{
		fibreMeans <- apply(ten,c(1,3),mean.fun)
		for(idx in 1:dim(ten)[2])
			centred[,idx,] <- ten[,idx,] - fibreMeans
	}
	else if(o == 3)
	{
		fibreMeans <- apply(ten,c(1,2),mean.fun)
		for(idx in 1:dim(ten)[3])
			centred[,,idx] <- ten[,,idx] - fibreMeans
	}

	list(data = centred, pre = fibreMeans)
}
undo.norm.fibrecentering <- function(CP,mten,o)
{
	# Not implemented: fibre centring is currently not undone and the
	# decomposition is handed back unchanged.
	##TODO: see page 233 for details
	# can be solved as a system of equations - jkr * r = jk; need to find r
	CP
}

#
# Center Slab
#
norm.slabcentring <- function(ten,o=2)
{
	# Slab centring of a 3-way array: the mean of every slab along mode `o`
	# (from get.slab.mean) is subtracted from that slab.
	# Returns list(data = centred array, pre = per-slab means).
	centred <- array(0,dim=dim(ten),dimnames=dimnames(ten))
	slabMeans <- get.slab.mean(ten,o)

	if(o == 1)
	{
		for(idx in 1:dim(ten)[1])
			centred[idx,,] <- ten[idx,,] - slabMeans[idx]
	}
	else if(o == 2)
	{
		for(idx in 1:dim(ten)[2])
			centred[,idx,] <- ten[,idx,] - slabMeans[idx]
	}
	else if(o == 3)
	{
		for(idx in 1:dim(ten)[3])
			centred[,,idx] <- ten[,,idx] - slabMeans[idx]
	}

	list(data = centred, pre = slabMeans)
}
undo.norm.slabcentring <- function(CP,mten,o)
{
	# Slab centring is not undone: the decomposition is returned as is.
	CP
}

norm.slab <- function(ten)
{
	# Slab centring followed by slab scaling, both along mode 2.
	# Slab centring should not work, see Harshman, Chapter 6 - Data preprocessing
	# and the extended PARAFAC model, 1984,
	# http://www.psychology.uwo.ca/faculty/harshman/lawch6.pdf, page 238.
	# Returns list(data, pre) where `pre` records the means, the divisors and
	# the mode each step was applied to.
	slabMode <- 2
	centred <- norm.slabcentring(ten,slabMode)
	scaled <- norm.slabscaling(centred$data,slabMode)

	pre <- list(slabC1 = centred$pre, slabC1id = slabMode,
	            slabS1 = scaled$pre, slabS1id = slabMode)
	list(data = scaled$data, pre = pre)
}
undo.norm.slab <- function(CP,pre)
{
	# Inverse of norm.slab(): undo the slab scaling first, then the slab
	# centring (reverse order of application).
	# CP  : decomposition with factors X, W, U
	# pre : the `pre` record returned by norm.slab()
	#
	# This definition used to be named undo.norm.DC.SC, which is defined again
	# further down the file (so this version was overwritten on source()), while
	# undo.Preprocessing() calls undo.norm.slab(), which did not exist.
	CP <- undo.norm.slabscaling(CP,pre$slabS1,pre$slabS1id)
	CP <- undo.norm.slabcentring(CP,pre$slabC1,pre$slabC1id)
	return(CP)
}

norm.singlecentring <- function(ten,o=1)
{
	# Single fibre centring along mode `o`; `pre` stores the fibre means and
	# the mode they belong to.
	step <- norm.fibrecentering(ten,o)
	list(data = step$data, pre = list(fibreC1 = step$pre, fibreC1id = o))
}

norm.triplecentring <- function(ten)
{
	# Triple fibre centring (modes 1, 2, 3 in that order) as described by
	# Harshman, Chapter 6 - Data preprocessing and the extended PARAFAC model,
	# 1984, http://www.psychology.uwo.ca/faculty/harshman/lawch6.pdf, pages 234/235.
	# Each step works on the output of the previous one.
	first <- norm.fibrecentering(ten,1)
	second <- norm.fibrecentering(first$data,2)
	third <- norm.fibrecentering(second$data,3)

	pre <- list(fibreC1 = first$pre, fibreC1id = 1,
	            fibreC2 = second$pre, fibreC2id = 2,
	            fibreC3 = third$pre, fibreC3id = 3)
	list(data = third$data, pre = pre)
}

norm.doublecentring <- function(ten,o1=1,o2=3)
{
	# Double fibre centring (mode `o1`, then mode `o2`) as described by
	# Harshman, Chapter 6 - Data preprocessing and the extended PARAFAC model,
	# 1984, http://www.psychology.uwo.ca/faculty/harshman/lawch6.pdf, pages 235/236.
	first <- norm.fibrecentering(ten,o1)
	second <- norm.fibrecentering(first$data,o2)

	pre <- list(fibreC1 = first$pre, fibreC1id = o1,
	            fibreC2 = second$pre, fibreC2id = o2)
	list(data = second$data, pre = pre)
}
undo.norm.doublecentring <- function(CP,pre)
{
	# Undo double fibre centring: second centring first, then the first one.
	afterSecond <- undo.norm.fibrecentering(CP,pre$fibreC2,pre$fibreC2id)
	undo.norm.fibrecentering(afterSecond,pre$fibreC1,pre$fibreC1id)
}

norm.DC.SC <- function(ten)
{
	# Double fibre centring (modes 1 and 3) followed by slab scaling along
	# mode 2. Centring follows Harshman, Chapter 6 - Data preprocessing and the
	# extended PARAFAC model, 1984,
	# http://www.psychology.uwo.ca/faculty/harshman/lawch6.pdf, pages 235/236.
	centre1 <- norm.fibrecentering(ten,1)
	centre2 <- norm.fibrecentering(centre1$data,3)
	scaled <- norm.slabscaling(centre2$data,2)

	pre <- list(fibreC1 = centre1$pre, fibreC1id = 1,
	            fibreC2 = centre2$pre, fibreC2id = 3,
	            slabS1 = scaled$pre, slabS1id = 2)
	list(data = scaled$data, pre = pre)
}
undo.norm.DC.SC <- function(CP,pre)
{
	# Undo norm.DC.SC() in reverse order: slab scaling, second fibre centring,
	# first fibre centring.
	unscaled <- undo.norm.slabscaling(CP,pre$slabS1,pre$slabS1id)
	afterSecond <- undo.norm.fibrecentering(unscaled,pre$fibreC2,pre$fibreC2id)
	undo.norm.fibrecentering(afterSecond,pre$fibreC1,pre$fibreC1id)
}


norm.tensor.K2009 <- function(ten)
{
	# Normalization as described in the Preprocessing section of:
	# http://www.phaenex.uwindsor.ca/ojs/leddy/index.php/AMR/article/viewFile/2833/2271
	# Single fibre centring (mode 1) followed by slab scaling (mode 2).
	centred <- norm.fibrecentering(ten,1)
	scaled <- norm.slabscaling(centred$data,2)

	pre <- list(fibreC1 = centred$pre, fibreC1id = 1,
	            slabS1 = scaled$pre, slabS1id = 2)
	list(data = scaled$data, pre = pre)
}
undo.norm.tensor.K2009 <- function(CP,pre)
{
	# Undo norm.tensor.K2009(): slab scaling first, then the fibre centring.
	unscaled <- undo.norm.slabscaling(CP,pre$slabS1,pre$slabS1id)
	undo.norm.fibrecentering(unscaled,pre$fibreC1,pre$fibreC1id)
}

#
# Adjusts only X, U and W
# TODO: Does not work for the VB case
undo.Preprocessing <- function(model,Preprocessing)
{
	# Revert the preprocessing named in Preprocessing$method on the factors of
	# `model`. For view m the record Preprocessing[[m]] is handed to the matching
	# undo.* function; unknown method names leave the factors untouched.
	# X and U are shared by all views and are threaded through every iteration.
	CP <- list(X = model$X, U = model$U)

	for(m in 1:length(model$W))
	{
		CP$W <- model$W[[m]]
		CP <- switch(Preprocessing$method,
			"norm.tensor.K2009"   = undo.norm.tensor.K2009(CP,Preprocessing[[m]]),
			"norm.DC.SC"          = undo.norm.DC.SC(CP,Preprocessing[[m]]),
			"norm.doublecentring" = undo.norm.doublecentring(CP,Preprocessing[[m]]),
			"norm.total.rms"      = undo.norm.total.rms(CP,Preprocessing[[m]]),
			"norm.slab"           = undo.norm.slab(CP,Preprocessing[[m]]),
			CP)
		model$W[[m]] <- CP$W
	}

	model$X <- CP$X
	model$U <- CP$U
	model
}

################################################################################
## functions for examining the data
#
# Get Total Variance in the Data
#
get.total.var <- function(mat)
{
  # Total variance of a matrix: centre the columns (NA-ignoring means) and
  # return the sum of squares divided by the number of cells.
  nRow <- nrow(mat)
  nCol <- ncol(mat)
  centred <- mat - matrix(apply(mat,2,mean.fun),nRow,nCol,byrow=TRUE)
  sum(centred^2,na.rm=TRUE)/(nRow*nCol)
}

#
# Get Total Variance in the Data Tensor
#
get.total.var.tensor <- function(ten)
{
  # Sample variance over all entries of the tensor, ignoring NAs.
  # Flattening makes this slow for large tensors.
  flat <- as.vector(ten)
  var(flat,na.rm=TRUE)
}
get.total.rss.tensor <- function(ten)
{
  # Sum of sqrt(entry^2) over the tensor (NAs ignored), divided by the product
  # of the first and third dimension (the shared dimensions) to keep the scale
  # reasonable.
  sharedCells <- dim(ten)[1]*dim(ten)[3]
  magnitudes <- sqrt(ten^2)
  sum(magnitudes,na.rm=TRUE)/sharedCells
}
rms <- function(ten)
{
	# Root mean square of all entries, ignoring NAs.
	meanSquare <- mean(ten^2,na.rm=TRUE)
	sqrt(meanSquare)
}
get.total.rms.tensor <- function(ten)
{
  # Total root mean square of the tensor; thin wrapper around rms().
  totalRms <- rms(ten)
  totalRms
}
get.total.ss.tensor <- function(ten)
{
  # Total sum of squares of the tensor entries, ignoring NAs.
  squared <- ten^2
  sum(squared,na.rm=TRUE)
}
get.slab.var <- function(ten,d)
{
	# Sample variance (NAs ignored) of every slab of `ten` along margin `d`.
	slabVariance <- function(slab) var(as.vector(slab),na.rm=TRUE)
	apply(ten,d,slabVariance)
}

get.slab.rms <- function(ten,d)
{
	# Root mean square (NAs ignored) of every slab of `ten` along margin `d`.
	slabRms <- function(slab) sqrt(mean(slab^2,na.rm=TRUE))
	apply(ten,d,slabRms)
}

get.slab.mean <- function(ten,d)
{
	# NA-ignoring mean of every slab of `ten` along margin `d`.
	slabMeans <- apply(ten,d,mean.fun)
	slabMeans
}

get.slab.range <- function(ten,d)
{
	# Range width (max - min, NAs ignored) of every slab along margin `d`.
	slabWidth <- function(slab)
	{
		hi <- max(slab,na.rm=TRUE)
		lo <- min(slab,na.rm=TRUE)
		hi - lo
	}
	apply(ten,d,slabWidth)
}

get.slab.max <- function(ten,d)
{
	# Maximum (NAs ignored) of every slab of `ten` along margin `d`.
	slabMax <- function(slab) max(slab,na.rm=TRUE)
	apply(ten,d,slabMax)
}

getWrescales <- function(ten)
{
	# Root-mean-square (NAs ignored) of every mode-2 slab ten[,j,] of a 3-way
	# array. Returns a plain numeric vector of length dim(ten)[2]; an empty
	# second mode yields numeric(0).
	#
	# The previous version called rep(1, dim = ...); rep() has no `dim`
	# argument, so the result vector was never preallocated, and the 1:n loop
	# failed when the second mode was empty.
	nSlabs <- dim(ten)[2]
	vapply(seq_len(nSlabs),
	       function(j) sqrt(mean((ten[,j,])^2,na.rm=TRUE)),
	       numeric(1))
}

#
# Estimate Projection Variable W for GFA - when run on Tensor Data
#
estimWforGFA <- function(model,data)
{
	# GFA run on matricised tensor data stores one W per (view, slice) pair at
	# index (m-1)*L + l. For every view m the L slice-wise matrices are averaged
	# and the list of averages is stored as model$Westim.
	nViews <- data$M
	nSlices <- data$L
	averaged <- list()
	for(m in 1:nViews)
	{
		sliceW <- lapply(1:nSlices, function(l) model$W[[(m-1)*nSlices + l]])
		averaged[[m]] <- Reduce(`+`, sliceW, 0)/nSlices
	}
	model$Westim <- averaged
	model
}

#
# Estimate Latent Variables U for GFA - when run on Tensor Data
#
estimUWforGFA <- function(model,data)
{
	# Estimate the slice loadings U (L x K) for a GFA model run on tensor data:
	# for slice l the column sums of (W of view m at slice l - averaged W of
	# view m) are accumulated over all views. model$Westim is computed first
	# when it is missing. The result is stored as model$Uestim.
	if(!length(model$Westim))
		model <- estimWforGFA(model,data)

	nViews <- data$M
	nSlices <- data$L
	nComp <- model$K # alternative: ncol(model$Westim[[1]])
	loadings <- matrix(0,nSlices,nComp)
	for(l in 1:nSlices)
		for(m in 1:nViews)
		{
			deviation <- model$W[[(m-1)*nSlices + l]] - model$Westim[[m]]
			loadings[l,] <- loadings[l,] + colSums(deviation)
		}
	model$Uestim <- loadings
	model
}

#
# Estimate Latent Variables for Test Data - Tensor GFA
#
estimZtensor <- function(Y,model)
{
	# Estimate the latent variables X for (test) tensor data Y under a fitted
	# tensor model (factors W, U and noise precisions tau).
	# Y     : list of 3-way arrays, one per view
	# model : list with X (only its column count is used), U, W (list), tau (list;
	#         the first row of each entry is used)
	# Returns list(X = estimate with a small random jitter, covX).
	if(any(is.na(Y[[1]])))
		print("Function estimZtensor does not handles missing values. see TensorGFA2.R for handling them")

	jitter <- 1e-5 # scale of the noise added to the estimate
	nSamples <- nrow(Y[[1]])
	nComp <- ncol(model$X)
	nViews <- length(Y)
	nSlices <- nrow(model$U)
	viewDims <- unlist(lapply(model$W,nrow))
	ones <- rep(1,nComp)
	UtU <- crossprod(model$U)

	# identity plus the precision-weighted contribution of every view
	covX <- diag(nComp)
	for(v in 1:nViews)
		covX <- covX + crossprod(model$W[[v]]*sqrt(model$tau[[v]][1,]))*UtU

	# invert through the eigen-decomposition
	eig <- eigen(covX)
	invRoot <- eig$vectors*outer(ones,1/sqrt(eig$values))
	covX <- tcrossprod(invRoot)

	proj <- 0
	for(v in 1:nViews)
	{
		for(s in 1:nSlices)
		{
			tauRows <- matrix(rep(model$tau[[v]][1,],nSamples),nSamples,viewDims[v],byrow=TRUE)
			sliceW <- model$W[[v]]*matrix(rep(model$U[s,],viewDims[v]),viewDims[v],ncol(model$U),byrow=TRUE)
			proj <- proj + ((Y[[v]][,,s]*tauRows) %*% sliceW)
		}
	}

	proj <- proj%*%covX
	proj <- proj + jitter*matrix(rnorm(nSamples*nComp),nSamples,nComp)%*%t(invRoot)

	list(X=proj,covX=covX)
}#EndFunction


#
# Estimate Latent Variables for Test Data - GFA
#
estimZgfa <- function(Y,model) {
  # Estimate the latent variables X for (test) data Y under a fitted GFA model.
  # Y     : list of data matrices, one per view
  # model : list with X (only its column count is used), W (list) and tau (list)
  # Returns list(X = estimate, covX).
  noiseScale <- 0 # 0.01
  nSamples <- nrow(Y[[1]])
  nComp <- ncol(model$X)
  nViews <- length(Y)
  ones <- rep(1,nComp)

  # Estimate the covariance of the latent variables
  covZ <- diag(nComp)
  for(v in 1:nViews)
    covZ <- covZ + crossprod(model$W[[v]]*sqrt(model$tau[[v]]))
  eig <- eigen(covZ) # symmetric=TRUE
  invRoot <- eig$vectors*outer(ones,1/sqrt(eig$values))
  covZ <- tcrossprod(invRoot)

  # Estimate the latent variables
  proj <- 0
  for(v in 1:nViews)
    proj <- proj + Y[[v]]%*%(model$W[[v]]*model$tau[[v]])
  proj <- proj%*%covZ

  # Add a tiny amount of noise on top of the latent variables,
  # to suppress possible artificial structure in components that
  # have effectively been turned off (the scale is currently 0, but the
  # random numbers are still drawn)
  proj <- proj + noiseScale*matrix(rnorm(nSamples*nComp),nSamples,nComp)%*%t(invRoot)

  list(X=proj,covX=covZ)
}

#
#trim and align columns of Rec with those of Orig
#
trimAndAlignColumns <- function(Rec, Orig)
{
	# For every column of Orig pick the column of Rec with the largest absolute
	# correlation, reorder Rec accordingly (dropping the rest) and flip the sign
	# of columns that are negatively correlated with their reference column.
	# Prints a message and returns NULL when Orig has more columns than Rec.
	if(ncol(Orig) > ncol(Rec))
	{
		print("Mat 2 is larger than mat 1. ERROR, RETURNING")
		return()
	}

	bestMatch <- apply(abs(cor(Orig,Rec)),1,which.max)
	Rec <- Rec[,bestMatch]

	aligned <- cor(Orig,Rec)
	for(col in 1:ncol(Orig))
	{
		# switch sign if the correlations are opposite
		if(aligned[col,col] < 0)
			Rec[,col] <- -Rec[,col]
	}
	Rec
}
################################################################################
############################ Printing Functions ################################
################################################################################

#
# Print Activity of a Single Model Run
#
print.activity <- function(model,text=NULL)
{
	# Print the row sums of the activity matrix Z, Z itself and the last cost
	# value of a single model run. `text` is not used.
	activity <- model$Z
	print(rowSums(activity))
	print(activity)
	lastCost <- tail(model$cost,1)
	print(lastCost)
}#EndFunction


#
# Print Activity of Multiple Model Runs
#
print.activity.multirun <- function(res,mse)
{
	# Print one summary line (see get.print.string) per model run in `res`.
	for(run in 1:length(res))
	{
		summaryLine <- get.print.string(run,res,mse)
		print(summaryLine)
	}
}#EndFunction

get.print.string <- function(i,res,mse=NULL)
{
  # One-line summary of run `i`: "SS:" followed by the number of components
  # active in all rows of Z and the number active in some but not all rows,
  # "C:" the rounded last cost value and "MSE:" mse[i] rounded to 3 digits
  # (NA when no mse is supplied).
  if(length(mse)==0)
    mse <- rep(NA,i)

  Z <- res[[i]]$Z
  perComponent <- colSums(Z)
  nFull <- sum(perComponent == nrow(Z))
  nPartial <- sum((perComponent < nrow(Z)) & (perComponent > 0))
  lastCost <- round(tail(res[[i]]$cost,1))

  paste0(format(i,width=2),"   SS:",nFull,",",nPartial,"   C:",format(lastCost,width=5),"   MSE:",round(mse[i],3))
}

# Print Activity for VB
# 
print.activity.multirun.VB <- function(res,mse=NULL)
{
	# Print one summary line (see get.print.string.VB) per VB run in `res`.
	for(run in 1:length(res))
	{
		summaryLine <- get.print.string.VB(run,res,mse)
		print(summaryLine)
	}
}#EndFunction

#
# get String of Activity for VB
# 
get.print.string.VB <- function(i,res,mse=NULL)
{
  # One-line summary of VB run `i`. A component counts as active in a row when
  # 1/alpha exceeds 1e-3; otherwise the line has the same layout as the one
  # produced for Gibbs runs.
  if(length(mse)==0)
    mse <- rep(NA,i)

  active <- round(1/res[[i]]$alpha > 1e-3)
  perComponent <- colSums(active)
  nFull <- sum(perComponent == nrow(active))
  nPartial <- sum((perComponent < nrow(active)) & (perComponent > 0))
  lastCost <- round(tail(res[[i]]$cost,1))

  paste0(format(i,width=2),"   SS:",nFull,",",nPartial,"   C:",format(lastCost,width=5),"   MSE:",round(mse[i],3))
}#EndFunction

################################################################################
################ Count Component Activities - Always use these functions or derivatives of them
getActive <- function(model,type)
{
	# Activity matrix of a model: model$Z when type is "GB", otherwise the
	# indicator of 1/alpha > 1e-3.
	if(type=="GB")
		return(model$Z)
	round(1/model$alpha > 1e-3)
}
countTotalComponents.Type <- function(model,type="GB")
{
	# Total number of active (row, component) entries of the activity matrix.
	sum(getActive(model,type))
}
countSharedComponents.Type <- function(model,type="GB")
{
	# Number of components active in every row of the activity matrix;
	# 0 by definition when the matrix has a single row.
	act <- getActive(model,type)
	if(nrow(act) != 1)
		sum(colSums(act) == nrow(act))
	else
		0
}
countSpecificComponents.Type <- function(model,type="GB")
{
	# Number of components active in exactly one row of the activity matrix.
	perComponent <- colSums(getActive(model,type))
	sum(perComponent == 1)
}
countCombinationComponents.Type <- function(model,type="GB")
{
	# Number of components active in more than one row but not in all rows.
	act <- getActive(model,type)
	perComponent <- colSums(act)
	sum(perComponent < nrow(act) & perComponent > 1)
}
countTotalComponents.list.Type <- function(res,type="GB")
{
	# countTotalComponents.Type applied to every model in the list `res`.
	unlist(lapply(res,countTotalComponents.Type,type=type))
}
countSharedComponents.list.Type <- function(res,type="GB")
{
	# countSharedComponents.Type applied to every model in the list `res`.
	unlist(lapply(res,countSharedComponents.Type,type=type))
}
countSpecificComponents.list.Type <- function(res,type="GB")
{
	# countSpecificComponents.Type applied to every model in the list `res`.
	unlist(lapply(res,countSpecificComponents.Type,type=type))
}
countCombinationComponents.list.Type <- function(res,type="GB")
{
	# countCombinationComponents.Type applied to every model in the list `res`.
	unlist(lapply(res,countCombinationComponents.Type,type=type))
}
################################################################################
make.matrix <- function(model,kk="all")
{
	# Reconstruct the data matrix of every view from the factors:
	#   Yestim[[m]] = sum over k in kk of outer(X[,k], W[[m]][,k])
	# model : list with X (N x K) and W (list of D_m x K matrices)
	# kk    : "all" (default) or a vector of component indices to include
	# Returns a list with one reconstruction per view.
	K = ncol(model$X)
	# identical() instead of ==: with an index vector `kk == "all"` has
	# length > 1, which is an error inside if() since R 4.2.
	if(identical(kk,"all"))
		kk = seq_len(K)

	M <- length(model$W)
	Yestim <- list()

	for(m in seq_len(M))
	{
		Yestim[[m]] <- 0
		for(k in kk)
			Yestim[[m]] <- Yestim[[m]] + outer(model$X[,k],model$W[[m]][,k])
	}

	return(Yestim)
}
make.tensor <- function(model,kk="all")
{
	# Reconstruct the data tensor of every view from the factors:
	#   Yestim[[m]] = sum over k in kk of X[,k] o W[[m]][,k] o U[,k]
	# model : list with X (N x K), W (list of D_m x K matrices) and U (L x K)
	# kk    : "all" (default) or a vector of component indices to include
	# Returns a list with one N x D_m x L array per view.
	K = ncol(model$U)
	# identical() instead of ==: with an index vector `kk == "all"` has
	# length > 1, which is an error inside if() since R 4.2.
	if(identical(kk,"all"))
		kk = seq_len(K)

	M <- length(model$W)
	Yestim <- list()
	## Can be speeded up - see proof tensor
	for(m in seq_len(M))
	{
		Yestim[[m]] <- 0
		for(k in kk)
			Yestim[[m]] <- Yestim[[m]] + outer(outer(model$X[,k],model$W[[m]][,k]),model$U[,k])
	}

	return(Yestim)
}
remake.tensor <- function(model,kk="all")
{
	# Alias of make.tensor(), kept for callers that use the old name.
	rebuilt <- make.tensor(model,kk)
	rebuilt
}

#
# Remake the tensor for each posterior sample and then take the expected value. Uses all posterior samples.
#
remake.tensor.samples.EV <- function(model)
{
	# Expected reconstruction: rebuild the tensors from every posterior sample
	# and average them view by view. A dot is printed every 100 samples.
	nSamp <- length(model$posterior$X)
	total <- list()
	for(s in 1:nSamp)
	{
		current <- remake.tensor(getPosteriorSample(model$posterior,s))
		if(s == 1)
			total <- current
		else
			for(m in 1:length(total))
				total[[m]] <- total[[m]] + current[[m]]

		if((s %% 100) == 0)
			cat(".",append=TRUE)
	}
	for(m in 1:length(total))
		total[[m]] <- total[[m]]/nSamp
	total
}

remake.matrix.samples.EV <- function(model)
{
	# Expected reconstruction for matrix models: rebuild the matrices from every
	# posterior sample and average them view by view. A dot is printed every
	# 100 samples.
	nSamp <- length(model$posterior$X)
	total <- list()
	for(s in 1:nSamp)
	{
		current <- make.matrix(getPosteriorSampleMatrix(model$posterior,s))
		if(s == 1)
			total <- current
		else
			for(m in 1:length(total))
				total[[m]] <- total[[m]] + current[[m]]

		if((s %% 100) == 0)
			cat(".",append=TRUE)
	}
	for(m in 1:length(total))
		total[[m]] <- total[[m]]/nSamp
	total
}

make.tensor.withoutK <- function(model,kk)
{
	# Reconstruct the data tensor of every view while leaving out the
	# component(s) in `kk`.
	# model : list with X (N x K), W (list of D_m x K matrices) and U (L x K)
	# kk    : index, or vector of indices, of the components to exclude
	# Returns a list with one reconstruction per view (the scalar 0 when every
	# component is excluded).
	M <- length(model$W)
	K = ncol(model$U)
	Yestim <- list()
	## Can be speeded up - see proof tensor
	for(m in seq_len(M))
	{
		Yestim[[m]] <- 0
		for(k in seq_len(K))
			# %in% instead of != so that several components can be excluded;
			# `k != kk` with a vector is an error inside if() since R 4.2
			if(!(k %in% kk))
				Yestim[[m]] <- Yestim[[m]] + outer(outer(model$X[,k],model$W[[m]][,k]),model$U[,k])
	}

	return(Yestim)
}
remake.tensor.withoutK <- function(model,kk)
{
	# Alias of make.tensor.withoutK(), kept for callers that use the old name.
	rebuilt <- make.tensor.withoutK(model,kk)
	rebuilt
}

##################################################################################
############################### Posterior Functions ##############################
listmean <- function(ll)
{
	# Element-wise mean of a list of equally shaped matrices.
	##TODO: should be based on component matching in permutations and sign!!
	total <- Reduce(`+`, ll[-1], ll[[1]])
	total/length(ll)
}

listmeanW <- function(ll)
{
	# Mean over samples of a list of per-view lists: ll[[sample]][[view]].
	# Returns one averaged matrix per view.
	##TODO: should be based on component matching in permutations and sign!!
	nSamp <- length(ll)
	nViews <- length(ll[[1]])
	lapply(1:nViews, function(m)
	{
		laterSamples <- lapply(ll[-1], function(samp) samp[[m]])
		Reduce(`+`, laterSamples, ll[[1]][[m]])/nSamp
	})
}

getPosteriorMean <- function(model)
{
	# Posterior mean of the model parameters over all stored samples, with Z
	# rounded to 0/1. The result is passed through trimModel() and stored in
	# model$rs.mean. A model without posterior samples is returned unchanged.
	post <- model$posterior
	if(length(post) < 1)
		return(model)

	avg <- list(Z = round(listmean(post$Z),0),
	            U = listmean(post$U),
	            X = listmean(post$X),
	            W = listmeanW(post$W),
	            tau = listmeanW(post$tau),
	            cost = mean(post$cost))
	model$rs.mean <- trimModel(avg)
	model
}

getPosteriorStd <- function(model,rs.mean=NULL)
{
	# Posterior standard deviation of the model parameters around `rs.mean`,
	# computed element-wise as sqrt(mean((sample - mean)^2)) over all stored
	# samples (Z is rounded). The result is stored in model$rs.std.
	# model   : model with a `posterior` list of samples (Z, U, X, W, tau, cost)
	# rs.mean : posterior mean; computed with getPosteriorMean() when missing
	# A model without posterior samples is returned unchanged.
	if(length(rs.mean)==0)
	{
		print("Getting Mean")
		model = getPosteriorMean(model)
		rs.mean = model$rs.mean
	}

	p <- model$posterior
	if(length(p)<1)
		return(model)
	rs <- list()
	rs$Z <- round(sqrt(listmean(lapply(p$Z,function(x) { (x - rs.mean$Z)^2 } ))),0)
	rs$U <- sqrt(listmean( lapply(p$U,function(x) { (x - rs.mean$U)^2 } ) ))
	rs$X <- sqrt(listmean( lapply(p$X,function(x) { (x - rs.mean$X)^2 } ) ))

	rs$tau <- list(); rs$W <- list()
	# number of views; this used to be read from an undefined (global) M
	M <- length(rs.mean$W)
	for(m in seq_len(M))
	{
		rs$W[[m]] <- sqrt(listmean( lapply(p$W,function(x) { (x[[m]] - rs.mean$W[[m]])^2 } ) ))
		rs$tau[[m]] <- sqrt(listmean( lapply(p$tau,function(x) { (x[[m]] - rs.mean$tau[[m]])^2 } ) ))
	}
	rs$cost <- sd(p$cost)
	model$rs.std <- rs
	#model$rs.std <- trimModel(rs)
	return(model)
}
###############################################################################
## Consensus posterior methods NOT TO BE USED BEFORE CONFIRMATION OF CORRECTNESS
listmeanConsensus <- function(ll,smp)
{
	# Mean over the samples ll[smp]. A sample is only added when every column
	# of the running sum has a column in the sample correlated above 0.7; its
	# columns are then reordered to the best correlated ones before adding.
	# The number of samples actually used is printed.
	running <- ll[[smp[1]]]
	nUsed <- 1
	for(idx in smp[-1])
	{
		candidate <- ll[[idx]]
		cmat <- cor(running,candidate)
		if(sum(apply(cmat,1,max) > 0.7) == ncol(running))
		{
			running <- running + candidate[,apply(cmat,1,which.max)]
			nUsed <- nUsed + 1
		}
	}
	running <- running/nUsed
	print(nUsed)
	running
}

listmeanWConsensus <- function(ll,smp)
{
	# View-wise mean over the samples ll[smp] of a list of per-view lists.
	# No correlation check is applied here (every selected sample is used);
	# the number of samples is printed.
	firstSample <- ll[[ smp[1] ]]
	nViews <- length(firstSample)
	acc <- list()
	for(m in 1:nViews)
		acc[[m]] <- firstSample[[m]]

	nUsed <- 1
	for(idx in smp[-1])
	{
		for(m in 1:nViews)
			acc[[m]] <- acc[[m]] + ll[[idx]][[m]]
		nUsed <- nUsed + 1
	}
	for(m in 1:nViews)
		acc[[m]] <- acc[[m]]/nUsed

	print(nUsed)
	acc
}

getPosteriorMeanConsensus <- function(model)
{
	# Posterior mean restricted to "consensus" samples, i.e. samples whose Z
	# equals the rounded mean Z in every entry. The cost is averaged over all
	# samples. The result is passed through trimModel() and stored in
	# model$rs.mean. A model without posterior samples is returned unchanged.
	post <- model$posterior
	if(length(post) < 1)
		return(model)

	avg <- list()
	avg$Z <- round(listmean(post$Z),0)

	nCells <- prod(dim(avg$Z))
	agrees <- function(z) sum(z == avg$Z) == nCells
	consensus <- which(unlist(lapply(post$Z,agrees)))

	avg$U <- listmeanConsensus(post$U,consensus)
	avg$X <- listmeanConsensus(post$X,consensus)
	avg$W <- listmeanWConsensus(post$W,consensus)
	avg$tau <- listmeanWConsensus(post$tau,consensus)
	avg$cost <- mean(post$cost)
	model$rs.mean <- trimModel(avg)
	model
}

getPosteriorStdConsensus <- function(model,rs.mean=NULL)
{
	# Posterior standard deviation around `rs.mean`, restricted to "consensus"
	# samples (those whose Z equals rs.mean$Z in every entry). The result is
	# stored in model$rs.std.
	# model   : model with a `posterior` list of samples (Z, U, X, W, tau, cost)
	# rs.mean : posterior mean; computed with getPosteriorMean() when missing
	# NOTE(review): the fallback uses getPosteriorMean(), not
	# getPosteriorMeanConsensus() - confirm that this is intended.
	if(length(rs.mean)==0)
	{
		print("Getting Mean")
		model = getPosteriorMean(model)
		rs.mean = model$rs.mean
	}

	p <- model$posterior
	if(length(p)<1)
		return(model)

	Zsize = prod(dim(rs.mean$Z))
	smp = which(unlist(lapply(p$Z,function(x) { return(sum(x == rs.mean$Z) == Zsize) } ))) #consensus samples

	rs <- list()
	rs$Z <- round(sqrt(listmeanConsensus(lapply(p$Z,function(x) { (x - rs.mean$Z)^2 } ),smp)),0)
	rs$U <- sqrt(listmeanConsensus( lapply(p$U,function(x) { (x - rs.mean$U)^2 } ),smp ))
	rs$X <- sqrt(listmeanConsensus( lapply(p$X,function(x) { (x - rs.mean$X)^2 } ),smp ))

	rs$tau <- list(); rs$W <- list()
	# number of views; this used to be read from an undefined (global) M
	M <- length(rs.mean$W)
	for(m in seq_len(M))
	{
		rs$W[[m]] <- sqrt(listmeanConsensus(lapply(p$W,function(x) { (x[[m]] - rs.mean$W[[m]])^2 } ),smp))
		rs$tau[[m]] <- sqrt(listmeanConsensus(lapply(p$tau,function(x) { (x[[m]] - rs.mean$tau[[m]])^2 } ),smp))
	}
	rs$cost <- sd(p$cost)
	model$rs.std <- rs
	#model$rs.std <- trimModel(rs)
	return(model)
}
###############################################################################
trimModel <- function(model)
{
	# Zero out the columns of U, X and every W that belong to components which
	# are inactive in all rows of Z (column sum of Z equal to 0).
	dead <- which(colSums(model$Z) == 0)
	model$U[,dead] <- 0
	model$X[,dead] <- 0
	model$W <- lapply(model$W, function(w)
	{
		w[,dead] <- 0
		w
	})
	model
}

trimModelGFA <- function(model)
{
	# GFA variant of trimModel(): zero out the columns of X and every W that
	# belong to components inactive in all rows of Z. There is no U factor.
	dead <- which(colSums(model$Z) == 0)
	model$X[,dead] <- 0
	model$W <- lapply(model$W, function(w)
	{
		w[,dead] <- 0
		w
	})
	model
}

getPosteriorMeanGFA <- function(model)
{
	# Posterior mean for GFA models (no U factor), with Z rounded to 0/1.
	# The result is stored untrimmed in model$rs.mean. A model without
	# posterior samples is returned unchanged.
	post <- model$posterior
	if(length(post) < 1)
		return(model)

	model$rs.mean <- list(Z = round(listmean(post$Z),0),
	                      X = listmean(post$X),
	                      W = listmeanW(post$W),
	                      tau = listmeanW(post$tau),
	                      cost = mean(post$cost))
	#model$rs.mean <- trimModelGFA(rs)
	model
}

alignTwoRuns <- function(model1, model2)
{
	# Currently only computes the component order of model2 with respect to
	# model1 and returns it invisibly; the scale is not aligned yet.
	#work on scale as well
	invisible(permuteTwoRuns(model1, model2))
}

permuteTwoRuns <- function(model1, model2)
{
	# Greedy matching of the components of model2 to those of model1.
	# The score of a pair is |cor(X)| + |cor(U)| + |mean over views of cor(W)|,
	# where NA correlations of a view are skipped. Components of model1 are
	# visited in order and each takes the best still unmatched component of
	# model2.
	# Returns an integer vector `od`: od[i] is the column of model2 matched to
	# column i of model1 (NA when nothing is left to match).
	#
	# Fixes: nrows()/ncols() do not exist in R (nrow()/ncol()), and matched
	# columns are now masked instead of deleted - deleting shifted the indices
	# of the remaining columns, so later matches pointed to the wrong columns.
	M <- length(model1$W)
	corX <- cor(model1$X,model2$X)
	corU <- cor(model1$U,model2$U)
	corWl <- list()
	for(m in seq_len(M))
		corWl[[m]] <- cor(model1$W[[m]],model2$W[[m]])

	# average the W correlations over views, ignoring NA entries
	corW <- matrix(0,nrow(corWl[[1]]),ncol(corWl[[1]]))
	for(ki in seq_len(nrow(corW)))
	 for(kj in seq_len(ncol(corW)))
	 {
		ct = 0
		for(m in seq_len(M))
			if(!is.na(corWl[[m]][ki,kj]))
			{
				corW[ki,kj] = corW[ki,kj] + corWl[[m]][ki,kj]
				ct = ct + 1
			}
		corW[ki,kj] = corW[ki,kj]/ct
	 }

	corT <- abs(corX) + abs(corU) + abs(corW)
	od <- rep(NA_integer_,nrow(corT))
	taken <- rep(FALSE,ncol(corT))
	for(i in seq_len(nrow(corT)))
	{
		if(all(taken))
			break
		score <- corT[i,]
		score[taken] <- -Inf
		best <- which.max(score) # integer(0) when every score is NA
		if(length(best) == 1)
		{
			od[i] <- best
			taken[best] <- TRUE
		}
	}
	return(od) # return order of second with respect to first.
}

getPosteriorSample <- function(post,s)
{
	# Extract posterior sample `s` as a flat model list with the fields
	# X, U, W, tau, alpha, alphaU and Z. Fields missing from `post` are omitted.
	samp <- list()
	for(field in c("X","U","W","tau","alpha","alphaU","Z"))
		samp[[field]] <- post[[field, exact=FALSE]][[s]]
	samp
}

getPosteriorSampleMatrix <- function(post,s)
{
	# Extract posterior sample `s` of a matrix (GFA) model as a flat list with
	# the fields X, W, tau, alpha and Z. Fields missing from `post` are omitted.
	samp <- list()
	for(field in c("X","W","tau","alpha","Z"))
		samp[[field]] <- post[[field, exact=FALSE]][[s]]
	samp
}