#
# Elementwise Sparse Group Factor Analysis
# Suleiman Ali Khan, suleiman.khan@aalto.fi
# Seppo Virtanen, seppo.j.virtanen@aalto.fi
# v 0.1, December 2013
#

#
# Copyright 2013 Seppo Virtanen and Suleiman Ali Khan. All rights reserved.
# The software is licensed under the FreeBSD license;
#

#
# Method:
# Y[[m]] \approx XW[[m]]^T, m=1,..,M (M views in total)
# X, latent variables
# W, the projections (list)
# Z, M by K binary matrix indicating group activity
# alpha, list of matrices with the same sizes as W: element-wise (factored) ARD precisions of the projection weights
#


gfa <- function(Y,K,opts){
  #
  # Main function for Bayesian group factor analysis with an element-wise
  # sparse prior, fitted by Gibbs sampling.
  #
  # Inputs:
  #   Y    : List of M data matrices. Y[[m]] is a matrix with
  #          N rows (samples) and D_m columns (features). The
  #          samples need to be co-occurring.
  #          NOTE: All of these should be centered, so that the mean
  #                of each feature is zero
  #   K    : The number of components
  #   opts : List of options (see function getDefaultOpts()). Read here:
  #          init.tau, iter.max, iter.sampling, iter.thining,
  #          prior.alpha_0, prior.beta_0, prior.alpha_0t, prior.beta_0t
  #          and, optionally,
  #            verbose     : >= 2 (also used when missing) prints the
  #                          iteration number, rowSums(Z) and the log
  #                          likelihood every iteration; < 2 is silent
  #            iter.zstart : Z is held at 1 (all components active) for
  #                          iterations <= iter.zstart; default 500
  #
  # Output:
  # The trained model, which is a list that contains the following elements:
  #
  #   cost : Vector of the log likelihood of every iteration
  #          (iter.max burn-in iterations followed by iter.sampling ones)
  #
  #   SS   : The state of the sampler after iteration iter.max (the last
  #          burn-in iteration); an empty list if iter.max is 0:
  #    Z      : The binary activation matrix; M times K matrix
  #    X      : The latent variables; N times K matrix
  #    W      : List of projections; W[[m]] is a D_m times K matrix
  #    tau    : List of noise precisions (inverse variances, so 1/tau
  #             gives the variances); tau[[m]] has D_m elements
  #    alpha  : List of precisions of the projection weights;
  #             alpha[[m]] is a D_m times K matrix
  #    cost   : cost[1:iter.max]
  #    Zactive: sum(Z) after each of the iterations 1..iter.max
  #
  #   posterior: Gibbs samples taken after burn-in, keeping every
  #          iter.thining-th iteration: lists W, X, Z, alpha and tau with
  #          one entry per kept sample, and the numeric vector cost.
  #

  verbose <- if (is.null(opts$verbose)) 2 else opts$verbose
  zstart <- if (is.null(opts$iter.zstart)) 500 else opts$iter.zstart

  #
  # Store dimensionalities of data sets
  #
  M <- length(Y)
  D <- unlist(lapply(Y,ncol))
  N <- nrow(Y[[1]])

  # Per-feature sums of squares, reused by the noise update
  tau_tmp <- lapply(Y,function(x){colSums(x^2)})
  const <- - N*sum(D)*log(2*pi)/2
  id <- rep(1,K)
  X <- matrix(rnorm(N*K),N,K)
  XX <- crossprod(X)

  # ARD precisions of the weights, initialised to K everywhere
  alpha <- vector("list", M)
  for(m in seq_len(M))
    alpha[[m]] <- matrix(K,D[m],K)
  alpha_0 <- opts$prior.alpha_0 # ARD hyperparameters
  beta_0 <- opts$prior.beta_0
  Z <- matrix(1,M,K) # binary mask
  alpha_0t <- opts$prior.alpha_0t # noise precision hyperparameters
  beta_0t <- opts$prior.beta_0t

  tau <- vector("list", M)
  b_tau <- vector("list", M)
  for(m in seq_len(M))
  {
    tau[[m]] <- rep(opts$init.tau,D[m])
    b_tau[[m]] <- rep(1,D[m])
  }
  a_tau <- N/2
  SS <- list()

  W <- vector("list",length=M)
  for(m in seq_len(M))
    W[[m]] <- matrix(0,D[m],K)

  maxiter <- opts$iter.max
  samples <- opts$iter.sampling
  thining <- opts$iter.thining

  # the number of active (view, component) pairs and the loglikelihood
  cost <- Zactive <- rep(0,maxiter+samples)

  # Prior probability for Z_{m,k}=1.
  # Preference to choose spike in case of tie
  r <- rep(0.5,K) #non-informative initialization

  posterior <- list()
  if(samples>0)
  {
    # A sample is stored when (iter - maxiter) %% thining == 0, which
    # happens floor(samples/thining) times; ceiling() would leave trailing
    # NULL entries whenever thining does not divide samples.
    nsave <- floor(samples/thining)
    posterior$W <- vector("list", nsave)
    posterior$X <- vector("list", nsave)
    posterior$Z <- vector("list", nsave)
    posterior$cost <- rep(0, nsave)
    posterior$alpha <- vector("list", nsave)
    posterior$tau <- vector("list", nsave)
  }

  #
  # The main loop
  #
  for(iter in seq_len(maxiter+samples)){
    if(verbose >= 2) print(iter)
    #
    # sample Z and W
    #
    if(iter > 1){
      XXdiag <- diag(XX)
      # zero the diagonal so that W[[m]] %*% XX[k,] excludes component k
      diag(XX) <- 0
      for(m in seq_len(M)){
        for(k in seq_len(K)){
          lambda <- tau[[m]]*XXdiag[k] + alpha[[m]][,k]
          mu <- tau[[m]]/lambda*(crossprod(Y[[m]],X[,k]) - W[[m]]%*%XX[k,])
          # log odds of Z[m,k]=1 versus Z[m,k]=0
          logpr <- 0.5*( sum(log(alpha[[m]][,k]) - log(lambda)) + crossprod(mu*sqrt(lambda)) ) + log(r[k]) - log(1-r[k])
          zone <- 1/(1+exp(-logpr))
          # keep all components active during the first zstart iterations
          if(iter > zstart){
            Z[m,k] <- as.double((runif(1) < zone))
          }
          if(Z[m,k]==1){
            W[[m]][,k] <- mu + 1/sqrt(lambda)*rnorm(D[m])
          }else{
            W[[m]][,k] <- 0
          }
        }
      }
      zm <- colSums(Z)
      for(k in seq_len(K))
        r[k] <- rbeta(1,shape1=zm[k]+1,shape2=M-zm[k]+1)
    }else{
      # First iteration: draw W given the random X with a common noise
      # precision init.tau (the covariance does not depend on m).
      covW <- diag(1,K) + opts$init.tau*XX
      eS <- eigen(covW)
      covW <- tcrossprod(eS$vectors*outer(id,1/eS$values),eS$vectors)
      for(m in seq_len(M)){
        W[[m]] <- crossprod(Y[[m]],X)%*%covW*opts$init.tau + matrix(rnorm(D[m]*K),D[m],K)%*%t( eS$vectors*outer(id,1/sqrt(eS$values)) )
      }
    }
    #
    # sample X (latent variables)
    #
    covZ <- diag(K)
    for(m in seq_len(M)){
      covZ <- covZ + crossprod(W[[m]]*sqrt(tau[[m]]))
    }
    eS <- eigen(covZ)
    # posterior covariance V diag(1/lambda) V^T
    covZ <- tcrossprod( eS$vectors*outer(id,1/sqrt(eS$values)) )
    X[] <- 0
    for(m in seq_len(M)){
      X <- X + Y[[m]]%*%(W[[m]]*tau[[m]])
    }
    X <- X%*%covZ + matrix(rnorm(N*K),N,K)%*%t( eS$vectors*outer(id,1/sqrt(eS$values)) )
    # Scale column k of X by 1/R[k] and of every W[[m]] by R[k]; the
    # product X W[[m]]^T is unchanged.
    # NOTE(review): this leaves colMeans(X^2) at 1/R, not 1; unit scaling
    # would use sqrt(R) -- confirm intent.
    R <- colMeans(X^2)
    X <- X*outer(rep(1,N),1/R)
    for(m in seq_len(M))
      W[[m]] <- W[[m]]*outer(rep(1,D[m]),R)

    XX <- crossprod(X)

    #
    # sample alpha: posterior Gamma for active weights, prior Gamma for
    # inactive (exactly zero) ones. A logical mask is used because negative
    # indexing with an empty index (x[-integer(0)]) selects nothing, which
    # would leave alpha at 0 when a whole view is switched off.
    #
    for(m in seq_len(M)){
      alpha[[m]] <- W[[m]]^2/2
      active <- alpha[[m]] != 0
      alpha[[m]][active] <- rgamma(sum(active), shape=alpha_0[m]+0.5,
                                   rate=beta_0[m]+alpha[[m]][active]) + 1e-7
      alpha[[m]][!active] <- rgamma(sum(!active), shape=alpha_0[m],
                                    rate=beta_0[m]) + 1e-7
    }

    #
    # sample noise precisions (one Gamma draw per feature)
    #
    for(m in seq_len(M)){
      b_tau[[m]] <- ( tau_tmp[[m]] - 2*rowSums( W[[m]]*crossprod(Y[[m]],X) ) +
                      rowSums((W[[m]]%*%XX)*W[[m]]) )/2
      tau[[m]] <- rgamma(D[m], shape=alpha_0t[m]+a_tau, rate=beta_0t[m]+b_tau[[m]])
    }

    # calculate log likelihood
    cost[iter] <- const
    for(m in seq_len(M))
      cost[iter] <- cost[iter] + N*0.5*sum(log(tau[[m]])) - sum(b_tau[[m]]*tau[[m]])

    # record before the SS snapshot below so SS$Zactive includes iter.max
    Zactive[iter] <- sum(Z)

    if(verbose >= 2){
      print(rowSums(Z))
      print(cost[iter])
    }

    if(iter > maxiter)
    {
      siter <- iter - maxiter
      if((siter %% thining) == 0)
      {
        ind <- siter/thining
        posterior$W[[ind]] <- W
        posterior$X[[ind]] <- X
        posterior$Z[[ind]] <- Z
        posterior$cost[ind] <- cost[iter]
        posterior$alpha[[ind]] <- alpha
        posterior$tau[[ind]] <- tau
      }
    }

    if(iter == maxiter)
    {
      SS$X <- X
      SS$W <- W
      SS$Z <- Z
      SS$alpha <- alpha
      SS$tau <- tau
      SS$cost <- cost[1:maxiter]
      SS$Zactive <- Zactive[1:maxiter]
    }
  }

  return(list(cost=cost,SS=SS,posterior=posterior))
}
# Byte-compile gfa. cmpfun() lives in the 'compiler' package, which is not
# attached by default, so it is called through its namespace.
gfa <- compiler::cmpfun(gfa)

getDefaultOpts <- function(Y){
  #
  # A function for generating a default set of parameters.
  #
  # Input:
  #   Y : The list of data matrices that will be given to gfa(). Only
  #       length(Y), the number of views, is used: it sets the length of
  #       the per-view hyperparameter vectors.
  #
  # Returns a named list of options (each name appears exactly once).
  #
  # To run the algorithm with other values:
  #   opts <- getDefaultOpts(Y)
  #   opts$iter.sampling <- 2000
  #   model <- gfa(Y,K,opts)
  #
  M <- length(Y)

  #
  # Whether to use the rotation explained in the ICML'11 paper.
  # Using the rotation is strongly recommended, only turn this
  # off if it causes problems. Only in VB case.
  #  - TRUE|FALSE
  #
  rotate <- TRUE

  #
  # Parameters for controlling how the rotation is solved (VB only)
  #  - opt.method chooses the optimization method and
  #    takes values "BFGS" or "L-BFGS". The former
  #    is typically faster but takes more memory, so the latter
  #    is the default choice. For small K may use BFGS instead.
  #  - opt.iter is the maximum number of iterations
  #  - lbfgs.factr is convergence criterion for L-BFGS; smaller
  #    values increase the accuracy (10^7 or 10^10 could be tried
  #    to speed things up)
  #  - bfgs.crit is convergence criterion for BFGS; smaller
  #    values increase the accuracy (10^-7 or 10^-3 could also be used)
  #
  opt.method <- "L-BFGS"
  opt.iter <- 10^5
  lbfgs.factr <- 10^3
  bfgs.crit <- 10^-5

  #
  # Initial value for the noise precisions. Should be large enough
  # so that the real structure is modeled with components
  # instead of the noise parameters (see Luttinen&Ilin, 2010)
  #  Values: Positive numbers, but generally should use values well
  #          above 1
  #
  init.tau <- 10^3

  #
  # Iteration counts.
  #  - iter.max: in the Gibbs sampler gfa() this is the number of
  #    burn-in iterations run before posterior samples are collected.
  #    In the VB case the algorithm stops after iter.max iterations or
  #    when the relative change of the lower bound falls below iter.crit.
  #  - iter.crit: the VB convergence criterion; not read by gfa().
  #
  iter.crit <- 10^-6
  iter.max <- 5e3

  #
  # Additive noise level for latent variables. The latent variables
  # of inactive components (those with very large alpha) occasionally
  # show some structure in the mean values, even though the distribution
  # matches very accurately the prior N(0,I). This structure disappears
  # if a tiny amount of random noise is added on top of the
  # mean estimates. Setting the value to 0 will make the predictions
  # deterministic. Only in VB case.
  #
  addednoise <- 1e-6

  #
  # Hyperparameters (one value per view)
  # - alpha_0, beta_0 for the ARD precisions
  # - alpha_0t, beta_0t for the residual noise precisions
  #
  prior.alpha_0 <- prior.beta_0 <- rep(1e-3,M)
  prior.alpha_0t <- prior.beta_0t <- rep(1,M)

  #
  # Verbosity level
  #  0: Nothing
  #  1: Final cost function value for each run
  #  2: Cost function values for each iteration
  #
  verbose <- 2

  #
  # Sampling after burn-in: gfa() runs iter.sampling further iterations
  # and stores every iter.thining-th state in model$posterior.
  #
  iter.sampling <- 1000
  iter.thining <- 5

  list(rotate=rotate, init.tau=init.tau, iter.crit=iter.crit,
       iter.max=iter.max, iter.sampling=iter.sampling, iter.thining=iter.thining,
       opt.method=opt.method,
       lbfgs.factr=lbfgs.factr, bfgs.crit=bfgs.crit, opt.iter=opt.iter,
       addednoise=addednoise,
       prior.alpha_0=prior.alpha_0, prior.beta_0=prior.beta_0,
       prior.alpha_0t=prior.alpha_0t, prior.beta_0t=prior.beta_0t,
       verbose=verbose)
}

#
# Computes the posterior expected value (the mean over the stored Gibbs
# samples) of X, W and Z from a model returned by gfa().
#
# Input:
#   model : list returned by gfa(); only model$posterior is read.
#
# Output:
#   A list with elements
#     Z : M times K posterior mean of Z rounded to 0/1 (round() sends an
#         exact 0.5 to 0); columns named 1..K
#     X : N times K posterior mean of the latent variables
#     W : list of D_m times K posterior means of the projections
#   Returns 0 (after printing a message) when consecutive samples of X
#   show a negative self-correlation, an ambiguous match or label
#   switching, since averaging would then mix different components.
#   Stops with an error if the posterior holds no samples.
#
getPosteriorEV <- function(model)
{
  post <- model$posterior
  # Ignore empty slots: older gfa() versions sized the posterior lists
  # with ceiling() and left trailing NULLs when iter.thining did not
  # divide iter.sampling.
  valid <- which(!vapply(post$W, is.null, logical(1)))
  if(length(valid) == 0)
    stop("model$posterior contains no samples", call. = FALSE)

  aC <- seq_len(ncol(post$W[[valid[1]]][[1]]))
  M <- length(post$W[[valid[1]]])
  S <- length(valid)

  Zsum <- 0
  Xsum <- 0
  Wsum <- rep(list(0), M)
  Xprev <- NULL
  for(j in seq_len(S))
  {
    i <- valid[j]
    # drop = FALSE keeps matrices even when K, M, N or D_m equals 1
    Z <- post$Z[[i]][, aC, drop = FALSE]
    X <- post$X[[i]][, aC, drop = FALSE]
    W <- lapply(post$W[[i]], function(w) w[, aC, drop = FALSE])

    ## Check for label switching against the previous sample
    if(j == 1)
    {
      labels <- seq_along(aC)
    }
    else
    {
      cr <- cor(Xprev, X)
      if(any(diag(cr) < 0))
      {
        print(paste("Negative Self Correlations",i,sep=""))
        return(0)
      }
      w.mcr <- apply(cr, 1, which.max)
      if(anyDuplicated(w.mcr) > 0)
      {
        print(paste("Error Check Correlations",i,sep=""))
        return(0)
      }
      if(any(w.mcr != seq_along(w.mcr)))
      {
        print(paste("Label Switching Occured - Check Manually",i,sep=""))
        return(0)
      }
      labels <- w.mcr #all ok
    }

    # Accumulate running sums (same order as summing stored samples)
    Zsum <- Zsum + Z[, labels, drop = FALSE]
    Xprev <- X[, labels, drop = FALSE]
    Xsum <- Xsum + Xprev
    for(m in seq_len(M))
      Wsum[[m]] <- Wsum[[m]] + W[[m]][, labels, drop = FALSE]
  }

  Z <- Zsum/S
  colnames(Z) <- aC
  W <- lapply(Wsum, function(w) w/S)

  list(Z = round(Z, 0), X = Xsum/S, W = W)
}
