#
# GFA with Gibbs sampling. Includes element-wise sparse version for biclustering (opts$spikeW="grouped")
# Authors: Seppo Virtanen, s.virtanen@warwick.ac.uk
#          Eemeli Leppaaho, eemeli.leppaaho@aalto.fi
#          Tommi Suvitaival, tommi.suvitaival@gmail.com

# (last modified by EL 1.10.2015)
#
# Description:
# Model is Y = XW^T + E, and model supports various kinds of sparsity
# (see getDefaultOpts()).

# Gibbs sampler for group factor analysis (GFA), model Y = XW^T + E.
#
# The function runs in one of two modes:
#  - Learning (default): the projections W, the latent variables X, the
#    spike-and-slab indicators (Z) and their probabilities (r, rz), the ARD
#    precisions (alpha, beta) and the noise precisions (tau) are all updated.
#  - Fixed projection: entered when W, rz, tau and beta are ALL non-NULL.
#    W, tau and rz are then kept as given and only the latent variables X
#    (and beta, when opts$ARDLatent=="element" with non-Gaussian latents) are
#    sampled. Views flagged TRUE in opts$prediction are predicted from X and W.
#
# Arguments:
#   Y:        List of data matrices (views). They are combined with cbind(),
#             so all views must have the same number of rows (samples).
#   K:        Initial number of components. Overwritten by ncol(W) in the
#             fixed-projection mode.
#   opts:     List of options, see getDefaultOpts(). In addition to the
#             fields set there, this function reads opts$prediction,
#             opts$save.posterior ($W, $X, $tau) and opts$iter.missing.
#   W, rz, tau, beta: Model parameters for the fixed-projection mode
#             (W is D x K, tau has length D; rz may be a length-K vector).
#   filename: Optional checkpoint file. If it exists, the complete sampler
#             state is loaded from it and sampling continues; every 100
#             iterations the state is saved to it.
#
# Value:
#   - If any(opts$prediction): a list with the averaged prediction of each
#     predicted view (at the index of the view) and the element 'cost'.
#   - Otherwise: a list with the last sample of the parameters (W, X, Z, r,
#     rz, tau, alpha, beta), the view indices 'groups', D, K, the traces
#     'cost' and 'aic', the saved samples 'posterior', 'opts', the
#     convergence estimate 'conv' and the elapsed 'time'.
#   - An empty list() if all components get shut down during sampling.
gibbsgfa <- function(Y, K, opts, W=NULL, rz=NULL, tau=NULL, beta=NULL, filename="") {
  
  # data: similarly to CCAGFA package, data sets as elements in list
  
  if(file.exists(filename)) {
    # Resume from a checkpoint: load() restores every local variable that was
    # stored with save(list=ls()) in the main loop, including 'iter' and 'time'.
    print("Part of the sampling done already.")
    load(filename)
    print(paste("Continuing from iteration",iter))
    ptm <- proc.time()[1] - time #Starting time - time used so far
    
  } else { #Else initialization
    ptm <- proc.time()[1]
    # Store dimensionalities of data sets.
    M <- length(Y)
    D <- sapply(Y,ncol)
    Ds <- c(1,cumsum(D)+1) ; De <- cumsum(D)
    # gr[[m]]: column indices of view m in the concatenated data matrix
    # (named by the variable names of the view, if available).
    gr <- vector("list")
    for(m in 1:M) {
      gr[[m]] <- Ds[m]:De[m]
      if (!is.null(colnames(Y[[m]]))) {
        names(gr[[m]]) <- colnames(Y[[m]])
      }
    }
    # From here on Y is a single N x D matrix and D the total number of variables.
    Y <- do.call(cbind,Y) #abind(Y,along=2)
    D <- ncol(Y)
    N <- nrow(Y)
    
    alpha_0 <- opts$prior.alpha_0
    beta_0 <- opts$prior.beta_0
    alpha_0t <- opts$prior.alpha_0t
    beta_0t <- opts$prior.beta_0t
    #Make the tau prior vector-valued, if it is not already
    if(length(alpha_0t)==1)
      alpha_0t <- rep(alpha_0t, D)
    if(length(beta_0t)==1)
      beta_0t <- rep(beta_0t, D)
    
    # The fixed-projection mode requires all four of W, rz, tau and beta.
    if (!is.null(W) & !is.null(rz) & !is.null(tau) & !is.null(beta)) { # Projection given as an argument.
      projection.fixed = TRUE
      K = ncol(W)
    } else {
      projection.fixed = FALSE
    }
    
    # Some constants for speeding up the computation
    id <- rep(1,K)              # Vector of ones for fast matrix calculations
    
    ##
    ## Initialize the model randomly
    ##
    # Other initializations could be done, but overdispersed random initialization is quite good.
    
    # Latent variables
    X <- matrix(rnorm(N*K,0,1),N,K)
    X <- scale(X)
    covZ <- diag(1,K) # Covariance
    XX <- crossprod(X) # Second moments
    
    # Precisions of the latent variables: a given 'beta' is used only with
    # spike-and-slab latents in the fixed-projection mode; otherwise all ones.
    if(!opts$normalLatents & projection.fixed) {
      if(opts$ARDLatent=="element")
        beta <- matrix(colMeans(beta),N,K,byrow=T) #Initialize as the mean over samples
    } else {
      beta <- matrix(1,N,K)
    }
    
    tmp <- matrix(0,K,K)
    
    i.pred <- 0 # Number of samples accumulated into the mean prediction
    # NOTE(review): 'mod.saved' and 'S' are defined only under this condition,
    # but 'mod.saved' is also used in the prediction branch of the main loop
    # without checking opts$iter.saved - confirm the options always satisfy it.
    if (opts$iter.saved>0 & opts$iter.burnin<opts$iter.max) {
      mod.saved <- ceiling((opts$iter.max-opts$iter.burnin)/opts$iter.saved) # Save every 'mod.saved'th Gibbs sample.
      S <- floor((opts$iter.max-opts$iter.burnin)/mod.saved) # Number of saved Gibbs samples.
      #     print(paste(mod.saved,S)
    }
    
    if (projection.fixed) { # Projection given as an argument to the sampler (fixed)
      # if normalLatents == FALSE, but if true?  
      obs = NULL # variable indices of the source views
      for (mi in 1:length(opts$prediction)) { # Go through all views.
        if (!opts$prediction[mi]) {
          obs = c(obs, gr[[mi]])
        }
      }
      # Posterior covariance of the latent variables given the observed views:
      # the inverse of I + W_obs^T diag(tau_obs) W_obs, computed through the
      # eigendecomposition. 'tmp' (eigenvectors with columns scaled by
      # 1/sqrt(eigenvalue)) is kept for drawing the noise of X later.
      covZ <- diag(1,K) + crossprod(W[obs,]*sqrt(tau[obs]))
      eS <- eigen(covZ)
      tmp <- eS$vectors*outer(id,1/sqrt(eS$values))
      covZ <- tcrossprod(tmp)
      if(!opts$normalLatents) {
        # Quantities for the component-wise update of spike-and-slab latents:
        # diagonal stored separately, off-diagonal part kept in WtW.
        WtW <- crossprod(W[obs,]*sqrt(tau[obs]))
        WtWdiag <- diag(WtW)
        diag(WtW) <- 0
        # A vector-valued rz (one value per component) is expanded to N x K.
        if (is.null(dim(rz))) {
          if (length(rz)==K) {
            rz = matrix(data=rz, nrow=N, ncol=K, byrow=TRUE)
          } else {
            stop("rz not of required length")
          }
        }
      }
      if (any(opts$prediction)) {
        # Accumulators for the predictions; the views to be predicted are set
        # to zero in Y. Y.true keeps a copy of the data before zeroing.
        prediction = vector(mode="list", length=length(gr))
        Y.true = Y
        for (m in which(opts$prediction)) {
          prediction[[m]] = matrix(0,N,length(gr[[m]]))
          Y[,gr[[m]]] <- 0
        }
      }
      #Added things
      cost <- aic <- rep(NA,opts$iter.max)
      # Constant term of the Gaussian log-likelihood over the observed views only
      const <- 0
      for(m in which(opts$prediction==F))
        const <- const - N*length(gr[[m]])*log(2*pi)/2
      
    } else {
      if (any(opts$prediction)) {
        stop("Prediction without projections given")
      }
      WtW <- matrix(0,K,K)
      WtWdiag <- rep(0,K)
      
      const <- - N*D*log(2*pi)/2 # Constant term of the Gaussian log-likelihood
      tau <- rep(opts$init.tau,D) # The mean noise precisions
      
      W <- matrix(0,D,K) # Projections
      Z <- matrix(1,D,K) # Spike-and-slab indicators of W (all active initially)
      covW <- diag(1,K)
      
      alpha <- matrix(1,D,K) # ARD precisions of W
      
      cost <- aic <- rep(NA,opts$iter.max) # For storing the lower bounds
      lambda <- mu <- zone <- rep(1,D)
      r <- matrix(0.5,D,K)  # Prior probabilities of the indicators of W
      rz <- matrix(0.5,N,K) # Prior probabilities of the indicators of X
    }
    
    ##Missing Values
    # Missing values are handled only in the learning mode.
    missingValues <- FALSE
    na.inds <- which(is.na(Y))
    if(length(na.inds)>0 & !projection.fixed) {
      missingValues <- TRUE
      if(opts$verbose>0)
        print("Missing Values Detected, Prediction using EM type approximation")
      #Update alpha_0t to take into account the observed data size
      a_tau <- alpha_0t + colSums(!is.na(Y))/2
      for(m in 1:M)
        alpha_0t[gr[[m]]] <- alpha_0t[gr[[m]]] + sum(!is.na(Y[,gr[[m]]]))/2
    } else {
      a_tau <- alpha_0t + N/2 #Elementwise noise
      for(m in 1:M)
        alpha_0t[gr[[m]]] <- alpha_0t[gr[[m]]] + N*length(gr[[m]])/2 #View-wise noise
    }
    # Collapse the view-wise shape parameters into a vector of length M
    # (one value per view, taken from the first variable of the view).
    for(m in 1:M)
      alpha_0t[m] <- alpha_0t[[gr[[m]][1]]]
    alpha_0t <- alpha_0t[1:M]
    #missingValues.InitIter 
    #0 - start prediction from mean of all values, and use updates of the model always
    #between 2 and maxiter - use updates of the model in a single imputation setting for  iterations past the set value.
    #maxiter+1 - ignore missing values. never update with the model's single imputation. 
    if(!is.null(opts$iter.missing))
      missingValues.InitIter <- opts$iter.missing
    else
      missingValues.InitIter <- round(opts$iter.burnin/2)
    
    if(missingValues & opts$verbose>0 & !projection.fixed)
      print(paste("missingValues.InitIter=",missingValues.InitIter))
    ## Missing Values end
    
    #Which X and W posterior samples to save in case of convergenceCheck?
    # 'start' and 'end' are the indices of the first and the last quarter of
    # the saved posterior samples.
    if(opts$convergenceCheck & opts$iter.saved>=8 & !projection.fixed) {
      ps <- floor(opts$iter.saved/4)
      start <- 1:ps
      end <- (-ps+1):0+opts$iter.saved
    } else {
      start <- end <- c()
    }
    
    posterior <- NULL
    iter <- 1
  }
  
  ##
  ## The main loop
  ##
  
  # Starts from 1, or from the iteration stored in the checkpoint.
  for(iter in iter:opts$iter.max) {
    
    ## Sample the projections W.
    if (!projection.fixed) {
      if(iter == 1){
        # First iteration: draw all of W jointly given the random X, with the
        # noise precision opts$init.tau and a unit prior precision. covW is
        # the inverse of I + init.tau*X^T X (via the eigendecomposition).
        covW <- diag(1,K) + opts$init.tau*XX
        eS <- eigen(covW)
        tmp <- eS$vectors*outer(id,1/sqrt(eS$values))
        covW <- tcrossprod(tmp)
        
        if(missingValues && (missingValues.InitIter < iter)) {
          ##Could give bad init when high missing values. therefore skip from here.
          Y[na.inds] <- mean(Y[-na.inds])
        }
        
        estimW = matrix(0,D,K) #equivalent of crossprod(Y,X)
        for(k in 1:K){
          if(missingValues && (missingValues.InitIter >= iter)) {
            # Missing entries contribute zero to the cross-product.
            tmpY <- Y
            tmpY[is.na(tmpY)] <- 0
            estimW[,k] <- crossprod(tmpY,X[,k])
          } else {
            estimW[,k] <- crossprod(Y,X[,k])
          }
        }
        W <- estimW%*%covW*opts$init.tau + matrix(rnorm(D*K),D,K)%*%tmp
        
      }else{
        # Later iterations: update W one component (column) at a time.
        # The diagonal of XX is stored separately and zeroed, so that
        # W%*%XX[k,] contains the contribution of the other components only.
        XXdiag <- diag(XX)
        diag(XX) <- 0
        
        for(k in 1:K){
          # lambda: conditional precision, mu: conditional mean of W[,k]
          lambda <- tau*XXdiag[k] + alpha[,k]
          if(missingValues && (missingValues.InitIter >= iter)) {
            # Residual without component k; missing entries set to zero.
            ss <- tcrossprod(X[,-k],W[,-k])
            tmp <- Y-ss
            tmp[na.inds] <- 0
            mu_sub <- crossprod(tmp,X[,k])
            mu <- mu_sub*tau/lambda
          } else {
            mu <- tau/lambda*as.vector( crossprod(Y,X[,k]) - W%*%XX[k,])
          }
          # The spike-and-slab indicators are sampled only after the first
          # opts$sampleZ iterations.
          if(iter > opts$sampleZ){
            if(opts$spikeW!="group"){
              # Element-wise indicators: log-odds -> probability -> Bernoulli draw
              zone <- 0.5*( log(alpha[,k]) - log(lambda) + lambda*mu^2) + log(r[,k]) - log(1-r[,k])
              zone <- 1/(1+exp(-zone))
              zone <- as.double(runif(D) < zone)
              Z[,k] <- zone
            } else {
              # Group sparsity: one indicator per view and component; the
              # log-odds are summed over the variables of the view.
              zone <- 0.5*(log(alpha[,k])-log(lambda) + lambda*mu^2)
              for(m in 1:M){
                logpr <- sum(zone[gr[[m]]]) + log(r[gr[[m]][1],k]) - log( 1-r[gr[[m]][1],k] )
                logpr <- 1/(1+exp(-logpr))
                Z[gr[[m]],k] <- as.double(runif(1)<logpr)
              }
            }
          }
          W[,k] <- mu + rnorm(D)*lambda^(-0.5)
        }
      }
      # Inactive elements are set to exactly zero.
      W <- W*Z
    }
    
    ## Check if components can be removed.
    # If Gibbs samples are going to be saved, components are dropped only during the burn-in.
    # If the projection has been given as an argument (e.g. for prediction), components will not be dropped.
    if ((iter<=opts$iter.burnin | opts$iter.saved==0) & !projection.fixed) {
      # A component is kept if it has an active indicator and a non-zero X.
      keep <- which(colSums(Z)!=0 & colSums(abs(X))>0)
      if(length(keep)==0) {
        print("All components shut down, stopping!")
        return(list())
      }
      if(length(keep)!=K){
        K <- length(keep)
        id <- rep(1,K)
        alpha <- alpha[,keep,drop=F]
        X <- X[,keep,drop=F]
        W <- W[,keep,drop=F]
        Z <- Z[,keep,drop=F]
        r <- r[,keep,drop=F]
        beta <- beta[,keep,drop=F]
        rz <- rz[,keep,drop=F]
      }
    }
    
    ## sample r
    # r is set to (active count + prior count of ones) / (total count + prior
    # counts); it is not drawn from a beta distribution (see the commented-out
    # rbeta below). The counting unit depends on opts$spikeW.
    
    if(iter>opts$sampleZ & !projection.fixed){
      if(opts$spikeW=="shared"){
        # Counts over all variables
        zm <- colSums(Z)
        for(k in 1:K)
          r[,k] <- (zm[k]+opts$prior.beta[1])/(D+sum(opts$prior.beta))
      }
      if(opts$spikeW=="grouped"){
        # Counts over the variables of each view separately
        for(m in 1:M){
          zm <- colSums(Z[gr[[m]],,drop=FALSE])
          for(k in 1:K)
            r[gr[[m]],k] <- (zm[k]+opts$prior.beta[1])/(length(gr[[m]])+sum(opts$prior.beta)) #rbeta(1,shape1=zm[k]+1,shape2=length(gr[[m]])-zm[k]+1)
        }
      }
      if(opts$spikeW=="group"){
        # Counts over views (the indicator of a view is read from its first variable)
        zm <- rep(0,K)
        for(m in 1:M){
          zm <- zm + Z[gr[[m]][1],]
        }
        for(m in 1:M){
          for(k in 1:K){
            r[gr[[m]],k] <- (zm[k]+opts$prior.beta[1])/(M+sum(opts$prior.beta))
          }
        }
      }
    }
    
    ## 
    ## Update the latent variables
    ##
    
    # set Y to zero for predicted data. does not affect X.
    if(iter>opts$sampleZ & !opts$normalLatents) {
      # Spike-and-slab latent variables: component-wise update.
      if (!projection.fixed) {
        WtW <- crossprod(W*sqrt(tau))
        WtWdiag <- diag(WtW)
        diag(WtW) <- 0
      }
      for(k in 1:K){
        # lambda: conditional precision, mu: conditional mean of X[,k]
        lambda <- WtWdiag[k] + beta[,k]
        if(missingValues && (missingValues.InitIter >= iter)) {
          tmpY <- Y
          tmpY[is.na(tmpY)] <- 0
          mu <- (tmpY%*%(W[,k]*tau) - X%*%WtW[k,])/lambda
        } else {
          mu <- (Y%*%(W[,k]*tau) - X%*%WtW[k,])/lambda
        }
        # Indicator of each sample: log-odds -> probability -> Bernoulli draw
        zone <- 0.5*( log(beta[,k]) - log(lambda) + lambda*mu^2) + log(rz[,k]) - log(1-rz[,k])
        zone <- 1/(1+exp(-zone))
        zone <- as.double(runif(N) < zone)
        X[,k] <- mu + rnorm(N)*lambda^(-0.5)
        X[,k] <- X[,k]*zone
        if (!projection.fixed) {
          zm <- sum(zone)
          rz[,k] <- (zm+opts$prior.betaX[1])/(N+sum(opts$prior.betaX)) #rbeta(1,shape1=zm+prior.betaX1,shape2=N-zm+prior.betaX2)
        }
      }
    } else {
      # Gaussian latent variables (or the first opts$sampleZ iterations):
      # joint draw of X. In the fixed-projection mode covZ and tmp were
      # computed once in the initialization.
      if (!projection.fixed) {
        covZ <- diag(1,K) + crossprod(W*sqrt(tau))
        eS <- eigen(covZ)
        tmp <- eS$vectors*outer(id,1/sqrt(eS$values))
        covZ <- tcrossprod(tmp)
      }
      if(missingValues && (missingValues.InitIter >= iter)) {
        tmpY <- Y
        tmpY[is.na(tmpY)] <- 0
        X <- tmpY%*%(W*tau)
      } else {
        X <- Y%*%(W*tau)
      }
      X <- X%*%covZ + matrix(rnorm(N*K),N,K)%*%tmp
    }
    XX <- crossprod(X)
    
    ##
    ## Update alpha and beta, the ARD parameters
    ##
    
    if(!opts$normalLatents){
      if(opts$ARDLatent=="element"){
        # NOTE(review): 'beta' here is X^2/2 + prior.beta_0X, so beta!=0 holds
        # for every element whenever prior.beta_0X > 0 and the 'beta[-keep]'
        # line then has nothing to update - presumably the test was meant to
        # separate the active (X != 0) elements; confirm.
        beta <- X^2/2 + opts$prior.beta_0X
        keep <- which(beta!=0)
        for(n in keep) beta[n] <- rgamma(1,shape=opts$prior.alpha_0X+0.5,rate=beta[n]) +1e-7
        beta[-keep] <- rgamma( length(beta[-keep]),shape=opts$prior.alpha_0X,rate=opts$prior.beta_0X) + 1e-7
      }
      if(opts$ARDLatent == "shared" & !projection.fixed){
        # One precision per component, shared by all samples
        xx <- colSums(X^2)/2 + opts$prior.beta_0X
        tmpxx <- colSums(X!=0)/2 + opts$prior.alpha_0X
        for(k in 1:K)
          beta[,k] <- rgamma(1,shape=tmpxx[k],rate=xx[k]) + 1e-7
      }
    }
    if (!projection.fixed) {
      if(opts$ARDW=="shared"){
        # One precision per component, shared by all variables
        ww <- colSums(W^2)/2 + beta_0
        tmpz <- colSums(Z)/2 + alpha_0
        for(k in 1:K)
          alpha[,k] <- rgamma(1,shape=tmpz[k],rate=ww[k]) + 1e-7
      }
      if(opts$ARDW=="grouped"){
        # One precision per component and view
        for(m in 1:M){
          ww <- colSums(W[gr[[m]],,drop=FALSE]^2)/2 + beta_0
          tmpz <- colSums(Z[gr[[m]],,drop=FALSE])/2 + alpha_0
          for(k in 1:K)
            alpha[gr[[m]],k] <- rgamma(1,shape=tmpz[k],rate=ww[k]) + 1e-7
        }
      }
      if(opts$ARDW=="element"){
        # One precision per element of W.
        # NOTE(review): as with beta above, alpha!=0 is true for all elements
        # when beta_0 > 0 - confirm the intended condition.
        alpha <- W^2/2 + beta_0
        keep <- which(alpha!=0)
        for(n in keep) alpha[n] <- rgamma(1,shape=alpha_0+0.5,rate=alpha[n]) +1e-7
        alpha[-keep] <- rgamma( length(alpha[-keep]),shape=alpha_0,rate=beta_0) + 1e-7
      }   
      
      ##
      ## Update tau, the noise precisions
      ##
      
      # b_tau: half of the squared reconstruction error of each variable
      XW <- tcrossprod(X,W)
      if(missingValues && (missingValues.InitIter > iter)) { #in the last iter of ignoring missing values, update Y[[m]] with the model updates
        # Missing entries are ignored in the error.
        XW <- Y - XW
        XW[is.na(XW)] <- 0
        b_tau <- colSums(XW^2)/2
      } else{
        if(missingValues) {
          # Impute the missing entries with the current reconstruction.
          Y[na.inds] <- XW[na.inds]
        }
        b_tau <- colSums((Y-XW)^2)/2
      }
      #Original (faster):
      #  tau <- ( Yconst + rowSums((W%*%XX)*W) - 2*rowSums(crossprod(Y,X)*W) )/2
      #  b_tau <- tau
      if(opts$tauGrouped){
        # One noise precision per view
        for(m in 1:M){
          tau[gr[[m]]] <- rgamma(1,shape=alpha_0t[m], rate=beta_0t[gr[[m]][1]]+sum(b_tau[gr[[m]]]))
        }
      }else{
        # One noise precision per variable
        for(d in 1:D)
          tau[d] <- rgamma(1,shape=a_tau[d], rate= beta_0t[d]+b_tau[d])
      }
      
      # calculate likelihood.
      cost[iter] <- const + N*sum(log(tau))/2 - crossprod(b_tau,tau)
      aic[iter] <- 2*cost[iter] - (D*(K+1)-K*(K-1)/2)*2 #Akaike information criterion 
    }
    
    # Checkpoint: all local variables are stored, which is what allows the
    # load() at the beginning of the function to restore the sampler state.
    if(filename != "" & iter%%100==0) { #Every 100 iterations, save the sampling
      time <- proc.time()[1]-ptm
      save(list=ls(),file=filename)
      if(opts$verbose>0)
        print(paste0("Iter ", iter, ", saved chain to '",filename,"'"))
    }
    
    # Calculate likelihood of observed views
    if(projection.fixed) {
      XW <- tcrossprod(X,W)
      b_tau <- colSums((Y-XW)^2)/2
      cost[iter] <- const
      for(m in which(opts$prediction==F))
        cost[iter] <- cost[iter] + N*sum(log(tau[gr[[m]]]))/2 - crossprod(b_tau[gr[[m]]],tau[gr[[m]]])
    }
    
    ##
    ## Prediction and collection of Gibbs samples
    ##
    
    if (any(opts$prediction)) { ## Prediction
      if (iter%%10==0 & opts$verbose>1) {
        print(paste0("Predicting: ",iter,"/",opts$iter.max))
      }
      for (m in which(opts$prediction)) { # Go through the views that will be predicted.
        if (iter>opts$iter.burnin & ((iter-opts$iter.burnin)%%mod.saved)==0) {
          # i.pred grows by one per saved iteration in total, as each of the
          # predicted views adds 1/(number of predicted views).
          i.pred <- i.pred+1/sum(opts$prediction)
          prediction[[m]] <- prediction[[m]] + tcrossprod(X, W[gr[[m]],]) # Update the mean prediction.
        }
      }
      
    } else {
      if (iter%%10==0 & opts$verbose>0) {
        print(paste0("Learning: ",iter,"/",opts$iter.max," - K=",sum(colSums(Z)!=0)," - ",Sys.time()))
      }
      if (opts$iter.saved>0) { ## Collection of Gibbs samples
        
        if (iter>opts$iter.burnin & ((iter-opts$iter.burnin)%%mod.saved)==0) { ## Save the Gibbs sample.
          if (is.null(posterior)) { # Initialize the list containing saved Gibbs samples.
            posterior <- list() # List containing saved Gibbs samples
            if ((opts$save.posterior$W | opts$convergenceCheck) & !any(opts$prediction)) {
              posterior$W <- array(dim=c(S, dim(W))) # SxDxK
              colnames(posterior$W) = colnames(Y)
            }
            posterior$rz <- array(dim=c(S, ncol(rz))) # SxK - Same 'rz' for all samples. Thus, save only a vector per iteration.
            posterior$r <- array(dim=c(S, M,K)) # SxMxK - Same 'r' for all variables of a view. Thus, save only a views x components matrix per iteration.
            if (opts$save.posterior$tau & !any(opts$prediction)) {
              posterior$tau <- array(dim=c(S, length(tau))) # SxD
              colnames(posterior$tau) = colnames(Y)
            }
            if ((opts$save.posterior$X | opts$convergenceCheck) & !any(opts$prediction)) {
              posterior$X = array(dim=c(S, dim(X))) # SxNxK
            }
            if(!opts$normalLatents)
              posterior$beta = array(dim=c(S,dim(beta)))
            gr.start = vector(mode="integer", length=length(gr))
            for (m in 1:length(gr)) { # Find the first index of each view. (Used for saving 'r'.)
              gr.start[m] = gr[[m]][1]
            }
          }
          s <- (iter-opts$iter.burnin)/mod.saved # Index of the saved sample
          if (!projection.fixed) {
            # With convergenceCheck only, W and X are stored just for the
            # samples needed by the check (indices in 'start' and 'end').
            if (opts$save.posterior$W | (opts$convergenceCheck & s%in%c(start,end))) {
              posterior$W[s,,] <- W
            }
            # NOTE(review): posterior$tau is allocated above only when
            # opts$save.posterior$tau is TRUE, but it is written here
            # unconditionally - verify the behavior with save.posterior$tau=FALSE.
            posterior$tau[s,] <- tau
            posterior$rz[s,] <- rz[1,] # In the current model, 'rz' is identical for all samples (rows are identical). Thus, save only a vector.
            posterior$r[s,,] <- r[gr.start,] # In the current model, 'r' is identical for all variables within a view. Thus, save only a views x components vector.
            if(!opts$normalLatents)
              posterior$beta[s,,] <- beta
          }
          if (opts$save.posterior$X | (opts$convergenceCheck & s%in%c(start,end))) {
            posterior$X[s,,] <- X
          }
        }
      }
    }
  } ## The main loop of the algorithm ends.
  
  if(opts$convergenceCheck & opts$iter.saved>=8 & !projection.fixed) {
    #Estimate the convergence of the data reconstruction, based on the Geweke diagnostic
    # For each sample i and variable j, the reconstructions from the first
    # quarter of the saved posterior samples are compared with those from the
    # last quarter using a t-test. 'conv' is the proportion of tests with
    # p < 0.05, averaged over the samples.
    if(opts$verbose>0)
      print("Starting convergence check")
    conv <- 0
    for(i in 1:N) {
      if(opts$verbose>1) {
        if(i%%10==0)
          cat(".")
        if(i%%100==0)
          print(paste0(i,"/",N))
      }
      # Rows of y: reconstruction of sample i from each posterior sample in
      # c(start,end); the first length(start) rows come from 'start'.
      y <- matrix(NA,0,ncol(Y))
      for(ps in c(start,end)) {
        y <- rbind(y, tcrossprod(posterior$X[ps,i,],posterior$W[ps,,]))
      }
      foo <- rep(NA,ncol(Y)) # p-values of the variables
      for(j in 1:ncol(Y)) {
        if(sd(y[start,j])>1e-10 & sd(y[-start,j])>1e-10) {
          foo[j] <- t.test(y[start,j],y[-start,j])$p.value
        } else { #Essentially constant reconstruction
          if(abs(mean(y[start,j])-mean(y[-start,j]))<1e-10)
            foo[j] <- 1 #Same constant
          else
            foo[j] <- 0 #Different constant
        }
      }
      conv <- conv + sum(foo<0.05)/length(foo)/N #check how many values are below 0.05
    }
    # X and W were stored for the check only; drop them unless requested.
    if(!opts$save.posterior$X)
      posterior$X <- NULL
    if(!opts$save.posterior$W)
      posterior$W <- NULL
    gc()
    
  } else {
    conv <- NA
  }
  
  # NOTE(review): the checkpoint is written only every 100 iterations, so the
  # file may not exist here when 10 <= iter.max < 100 - confirm this is intended.
  if(filename!="" & opts$iter.max>=10)
    file.remove(filename) #No need for temporary storage any more
  
  ## Return the output of the model as a list
  if (any(opts$prediction)) {
    for (m in which(opts$prediction)) { # Go through the views that will be predicted.
      prediction[[m]] <- prediction[[m]]/i.pred # Sum -> mean over the saved samples
    }
    prediction$cost <- cost
    return(prediction)
  } else {
    # Parameters that are constant within a view (or over the samples) are
    # returned in a collapsed form: one row per view for alpha and Z, and the
    # first row only for beta.
    d1 <- unlist(lapply(gr,function(x){x[1]}))
    if(opts$ARDW=="grouped")
      alpha <- alpha[d1,]
    if(opts$spikeW=="group")
      Z <- Z[d1,]
    if(opts$normalLatents)
      beta <- beta[1,]
    time <- proc.time()[1] - ptm
    
    return(list(W=W, X=X, Z=Z, r=r, rz=rz, tau=tau, alpha=alpha, beta=beta, groups=gr, D=D, K=K,
                cost=cost, aic=aic, posterior=posterior, opts=opts, conv=conv, time=time))
  }
  
}


predictGibbsGFA <- function(Y, model, opts, filename="") {
  
  ## Predict a set of data views given the GFA model (Gibbs samples) and another set of data views.
  ##
  ## Tommi Suvitaival
  ## 27.9.2013
  ##
  ## Arguments
  ##
  ## Y: Data views in a list. The views to be predicted must also be present:
  ##    their values are used only for computing the mean squared error and
  ##    they are replaced with missing data before calling the sampler.
  ## model: Object from the function 'gibbsgfa', containing the posterior
  ##    samples of W, rz and tau (and beta, if opts$normalLatents is FALSE).
  ## opts: Options passed on to 'gibbsgfa'. opts$prediction is a logical
  ##    vector with the length matching the length of Y, describing which
  ##    data views will be predicted (TRUE).
  ## filename: Optional file for temporary storage. The predictions are saved
  ##    to it after every 10th posterior sample and an existing file is used
  ##    for continuing an interrupted run.
  ##
  ## Value
  ##
  ## A list containing the mean prediction of each predicted view (at the
  ## index of the view), and the elements
  ##   mse:  predicted views x posterior samples; the error of the mean
  ##         prediction over the posterior samples processed so far.
  ##   cost: iterations x posterior samples; 'cost' returned by 'gibbsgfa'.
  ##   time: the elapsed time.
  
  
  ##
  ## Initialization
  ##
  
  targets <- which(opts$prediction)  # Views that will be predicted
  sources <- which(!opts$prediction) # Views that are observed
  n.samples <- nrow(model$posterior$W) # Number of saved Gibbs samples
  
  N <- nrow(Y[[sources[1]]]) # Number of samples in the prediction set
  
  X <- vector("list",length=length(Y)) #Store the data to be predicted
  
  Y.prediction <- vector(mode="list", length=length(model$groups))
  names(Y.prediction) <- names(Y)
  for (mi in targets) { # Go through all views that will be predicted.
    Y.prediction[[mi]] <- array(data=0, dim=c(N, length(model$groups[[mi]]))) # Initialize the prediction matrix for view 'mi'.
    rownames(Y.prediction[[mi]]) <- rownames(Y[[sources[1]]]) # Names of the prediction samples
    colnames(Y.prediction[[mi]]) <- names(model$groups[[mi]]) # Names of the variables in view 'mi', which will be predicted.
    X[[mi]] <- Y[[mi]]
    Y[[mi]] <- array(dim=dim(Y.prediction[[mi]]), dimnames=dimnames(Y.prediction[[mi]])) # Initialize view 'mi' as missing data. These will be predicted.
  }
  
  ##
  ## Prediction
  ##
  mse <- matrix(NA,length(targets),n.samples)
  cost <- matrix(NA,opts$iter.max,n.samples)
  start <- 1
  ptm <- proc.time()[1]
  if(filename!="" && file.exists(filename)) {
    # Restores Y.prediction, mse, cost and time of the interrupted run.
    load(filename)
    ptm <- proc.time()[1] - time #Starting time - time used so far
    start <- which(is.na(mse[1,]))[1] # First posterior sample without a result
    if(is.na(start)) {
      # Every posterior sample has been processed already; without this guard
      # 'NA:n.samples' would raise an error.
      start <- n.samples + 1
      print(paste0("Loaded predictions from '",filename,"', all posterior samples processed already"))
    } else {
      print(paste0("Loaded predictions from '",filename,"', continuing with posterior sample ",start))
    }
  }
  
  # The remaining posterior samples; empty if there is nothing left to do
  # ('start:n.samples' would count downwards in that case).
  remaining <- if(start <= n.samples) start:n.samples else integer(0)
  
  beta <- NA
  for (ni in remaining) { # Go through all saved Gibbs samples.
    if(opts$verbose>0)
      print(paste0("Predicting, Gibbs sample ",ni))
    W <- matrix(model$posterior$W[ni,,],nrow(model$W),ncol(model$W))
    if(!opts$normalLatents)
      beta <- matrix(model$posterior$beta[ni,,],nrow(model$beta),ncol(model$beta))
    prediction.ni <- gibbsgfa(Y=Y, K=NULL, opts=opts, W=W, rz=model$posterior$rz[ni,],
                              tau=model$posterior$tau[ni,],beta=beta)
    cost[,ni] <- prediction.ni$cost
    
    # Y.prediction accumulates prediction/n.samples, so after 'ni' samples it
    # has to be multiplied by n.samples/ni to get the mean prediction so far.
    rescale <- n.samples/ni
    for (mi in targets) { # Go through the target views.
      Y.prediction[[mi]] <- Y.prediction[[mi]] + prediction.ni[[mi]]/n.samples
      row <- match(mi, targets)
      mse[row,ni] <- mean((X[[mi]]-Y.prediction[[mi]]*rescale)^2)
      if(opts$verbose>1)
        print(paste0("MSE at iteration ",ni,": ",mse[row,ni]))
    }
    if(filename!="" && ni%%10==0) {
      # The name 'time' is kept, as it is the name read back by load() above.
      time <- proc.time()[1]-ptm
      save(Y.prediction,mse,cost,time,file=filename)
      print(paste0("Saved tmp predictions to: '",filename,"'"))
    }
  }
  Y.prediction$mse <- mse
  Y.prediction$cost <- cost
  Y.prediction$time <- proc.time()[1]-ptm
  
  return(Y.prediction)
  
}


informativeNoisePrior <- function(Y,prop.to.be.explained.by.noise,conf,opts) {
  # This function sets the parameters of an informative prior for the noise
  # precisions such that the mean of the noise prior corresponds to the
  # proportion of variance we want to be explained by noise. The confidence
  # parameter sets the width of the prior distribution.
  
  # Inputs:
  #   Y:       list of datasets (views), one matrix per view. Missing values
  #            (NA) are allowed and are left out of the computations.
  #   prop.to.be.explained.by.noise: proportion of total variance of each view to be explained by noise.
  #                                  Valid Values: 0.01 -> 0.99. Suggested -> 0.5
  #   conf:    width of the distribution. Valid values: 1e-2 -> 100. Suggested -> 1
  #   opts:    list of options; only opts$tauGrouped is used. If TRUE, the
  #            parameters are computed for each view as a whole, otherwise
  #            for each variable (column) separately.
  
  # Outputs:
  #   a list containing 2 vectors: prior.alpha_0t and prior.beta_0t, both
  #   with one element per variable (column) of the concatenated views.
  #   prior.alpha_0t is the alpha_0t parameter of the hyperprior
  #   prior.beta_0t is the beta_0t parameter of the hyperprior
  #   With opts$tauGrouped the values are identical within a view.
  
  # Derivation
  # alpha = conf*alpha_GFA (alpha_GFA in GFA derivations)
  # beta = conf*amount_of_var 
  # For gamma mu = a/b
  # In model a and b have opposite role so, amount_of_var = b/a, i.e. b = a*amount_of_var, b = conf*total_amount_of_var
  #
  # In the code: alpha = conf * (number of observed values)/2 and
  # beta = conf * proportion * (sum of squared observed values)/2.
  # The sum of squares is computed from the raw values, without centering.
  
  M <- length(Y)
  D <- unlist(lapply(Y,ncol))
  De <- cumsum(D)  # Index of the last variable of each view
  Ds <- De - D + 1 # Index of the first variable of each view
  prior.alpha_0t <- rep(NA,sum(D))
  prior.beta_0t <- rep(NA,sum(D))
  
  for (m in seq_len(M)) {
    cols <- Ds[m] + seq_len(D[m]) - 1 # Variables of view m
    observed <- !is.na(Y[[m]])
    if (opts$tauGrouped) {
      # One value for the whole view
      n.observed <- sum(observed)
      total.variance.in.view <- sum(Y[[m]]^2,na.rm=TRUE)/2
    } else {
      # One value for each variable of the view
      n.observed <- colSums(observed)
      total.variance.in.view <- colSums(Y[[m]]^2,na.rm=TRUE)/2
    }
    sigma.hat <- prop.to.be.explained.by.noise * total.variance.in.view
    prior.alpha_0t[cols] <- conf*n.observed/2
    prior.beta_0t[cols] <- conf*sigma.hat
  }
  return(list(prior.alpha_0t=prior.alpha_0t,prior.beta_0t=prior.beta_0t))
}


getDefaultOpts <- function() {
  #
  # Returns the default options of the sampler as a named list.
  #
  # Usage, when some of the values are to be changed:
  #   opts <- getDefaultOpts()
  #   opts$iter.max <- 1000
  #   model <- gibbsgfa(Y,K,opts)
  
  # The total number of Gibbs iterations; the sampler stops after these.
  # The default length of the burn-in is derived from this value.
  n.iterations <- 5000
  
  list(
    # Initial value for the noise precisions. Should be large enough so
    # that the real structure is modeled with components instead of the
    # noise parameters (see Luttinen&Ilin, 2010). Any positive number is
    # valid, but generally values well above 1 should be used.
    init.tau=10^3,
    
    # Controlling the length of the sampling chain:
    # - iter.max: the number of iterations
    # - iter.saved: the number of posterior samples to be saved
    # - iter.burnin: the number of iterations discarded as the burn-in
    iter.max=n.iterations,
    iter.saved=0,
    iter.burnin=floor(n.iterations/2),
    
    # Hyperparameters of the gamma priors:
    # - alpha_0, beta_0 for the ARD precisions
    # - alpha_0t, beta_0t for the residual noise precisions
    prior.alpha_0=1,
    prior.beta_0=1,
    prior.alpha_0t=1,
    prior.beta_0t=1,
    
    # Hyperparameters of the spike and slab indicators, for the loadings
    # (prior.beta) and for the latent variables (prior.betaX).
    # The first element is the prior count for ones, the second for zeros.
    prior.beta=c(1,1),
    prior.betaX=c(1,1),
    
    # The spike and slab sparsity prior of the loadings W:
    # - "group": group sparsity; a binary variable indicating activity for
    #   each factor and data set
    # - "grouped": structured; a separate indicator for each element, but
    #   each data set and factor share the same hyperparameter. This
    #   sharing promotes group sparsity in an indirect way.
    # - "shared": unstructured; the group structure is not used at all.
    spikeW="group",
    
    # The iteration after which the spike and slab parameters are sampled.
    sampleZ=1,
    
    # The prior of the latent variables:
    # - TRUE: standard Gaussian
    # - FALSE: spike and slab
    normalLatents=TRUE,
    
    # The prior for the precisions of non-Gaussian latent variables, either
    # "shared" or "element" (see ARDW below), and its hyperparameters.
    # Quite conservative values should be used for a sensible scaling.
    ARDLatent="shared",
    prior.alpha_0X=1,
    prior.beta_0X=1,
    
    # The ARD prior of the loadings W:
    # - "grouped": a shared precision for each factor and group
    # - "element": a separate precision for each element in W
    # - "shared": no group structure
    ARDW="grouped",
    
    # The structure of the noise precisions:
    # - TRUE: the features of a group share the same noise precision
    # - FALSE: each data feature has a separate noise precision
    tauGrouped=TRUE,
    
    # The amount of printed output (0 prints the least).
    verbose=1,
    
    # Should the convergence of the data reconstruction be checked (based on
    # the Geweke diagnostic)? If TRUE, posterior samples of X and W are
    # stored in any case, causing higher memory requirements.
    convergenceCheck=FALSE
  )
}

