#
# Group Factor Analysis for two data modalities, bound together by a data view
# Authors: Eemeli Leppaaho, eemeli.leppaaho@aalto.fi
#          Seppo Virtanen, s.virtanen@warwick.ac.uk

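#
# Description:
# Model is Y[[v]] = X[[v]]W[[v]]^T + E, v=1,2, and the model supports various kinds of
# sparsity (see getDefaultOpts()). The two modalities are bound together by their first
# data view, which is the same matrix in transposed form (Y[[1]][[1]] == t(Y[[2]][[1]]));
# its reconstruction is the sum of the contributions of both modalities.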

gibbsgfa <- function(Y, K, opts, W=NULL, rz=NULL, tau=NULL, beta=NULL, filename="") {
  #
  # The main function for 2way group factor analysis (Gibbs sampler).
  #
  # Inputs:
  #   Y    : List of two data sets, paired in the first set. Y[[1]][[1]] is the first data
  #          set of the first modality. Y[[v]][[m]] is a N[v]xD[[v]][m] matrix. There are
  #          M[v] data views in modality v. Y[[1]][[1]]==t(Y[[2]][[1]]), i.e. N[v]==D[[3-v]][1]
  #          NOTE: All of these should be centered, so that the mean
  #                of each feature is zero
  #          NOTE: The algorithm is roughly invariant to the scale
  #                of the data, but extreme values should be avoided.
  #                Data with roughly unit variance or similar scale
  #                is recommended.
  #          Missing values (NA) are allowed when the projections are not fixed.
  #   K    : The number of components for each modality, a vector of length 2.
  #          Overwritten by ncol(W[[v]]) when the projections are given.
  #   opts : List of options (see function getDefaultOpts())
  #   W, rz, tau, beta : Optional lists (one element per modality). When all four are
  #          given, the projections are kept fixed and only the latent variables are
  #          sampled; the views flagged TRUE in opts$prediction[[v]] are then predicted.
  #   filename : Optional checkpoint file. The sampler state is saved there every 100
  #          iterations, and sampling continues from it if the file already exists.
  #
  # Output:
  #   - If any(unlist(opts$prediction)): the list 'prediction', where prediction[[v]][[m]]
  #     is the reconstruction of predicted view m of modality v averaged over the saved
  #     iterations, and prediction[[v]]$cost holds the cost trace.
  #   - Otherwise: list(W, X, Z, r, rz, tau, alpha, beta, groups, D, K, cost, posterior,
  #     opts, conv, time). An empty list is returned if all components of a modality
  #     are shut down.
  #
  # The shared (first) view is modeled as X[[1]]W1^T + t(X[[2]]W2^T), where Wv are the rows
  # of W[[v]] belonging to the shared view; each modality is updated on the residual left
  # after subtracting the other modality's contribution.
  #
  
  if(file.exists(filename)) {
    print("Part of the sampling done already.")
    load(filename)
    print(paste("Continuing from iteration",iter))
    ptm <- proc.time()[1] - time #Starting time - time used so far
    
  } else { #Else initialization
    ptm <- proc.time()[1]
    if(any(Y[[1]][[1]]!=t(Y[[2]][[1]]),na.rm=TRUE))
      stop("The two sets of views must have the first view shared (transposed).")
    
    # Store dimensionalities of data sets, and concatenate the views of each modality
    # column-wise. gr[[v]][[m]] holds the column indices of view m within Y[[v]].
    M <- D <- N <- rep(NA,2)
    ov <- c(2,1) #The other modality
    gr <- Yconst <- id <- list()
    for(v in 1:2) {
      M[v] <- length(Y[[v]])
      d <- sapply(Y[[v]],ncol)
      Ds <- c(1,cumsum(d)+1) ; De <- cumsum(d)
      gr[[v]] <- vector("list")
      for(m in 1:M[v]) {
        gr[[v]][[m]] <- Ds[m]:De[m]
        if (!is.null(colnames(Y[[v]][[m]]))) {
          names(gr[[v]][[m]]) <- colnames(Y[[v]][[m]])
        }
      }
      Y[[v]] <- do.call(cbind,Y[[v]])
      D[v] <- ncol(Y[[v]])
      N[v] <- nrow(Y[[v]])
    }
    
    alpha_0 <- opts$prior.alpha_0
    beta_0 <- opts$prior.beta_0
    alpha_0t <- opts$prior.alpha_0t
    beta_0t <- opts$prior.beta_0t
    #Make the tau prior vector-valued, if it is not already
    if(length(alpha_0t)==1)
      alpha_0t <- list(rep(alpha_0t, M[1]),rep(alpha_0t, M[2]))
    if(length(beta_0t)==1)
      beta_0t <- list(rep(beta_0t, M[1]),rep(beta_0t, M[2]))
    
    const <- rep(NA,2)
    Yconst <- id <- b_tau <- Z <- alpha <- covZ <- covW <- r <- na.inds <- list()
    
    if (!is.null(W) && !is.null(rz) && !is.null(tau) && !is.null(beta)) { # Projection given as an argument.
      projection.fixed <- TRUE
      for(v in 1:2)
        K[v] <- ncol(W[[v]])
    } else {
      projection.fixed <- FALSE
      W <- beta <- rz <- tau <- prediction <- list()
    }
    X <- XX <- prediction <- list()
    
    for(v in 1:2) {
      # Some constants for speeding up the computation
      Yconst[[v]] <- colSums(Y[[v]]^2)
      id[[v]] <- rep(1,K[v])              # Vector of ones for fast matrix calculations
      
      ##
      ## Initialize the model randomly
      ##
      # Other initializations could be done, but overdispersed random initialization is quite good.
      
      # Latent variables
      X[[v]] <- matrix(rnorm(N[v]*K[v],0,1),N[v],K[v])
      X[[v]] <- scale(X[[v]])
      covZ[[v]] <- diag(1,K[v]) # Covariance
      XX[[v]] <- crossprod(X[[v]]) # Second moments
      
      # NOTE(review): the given beta is reshaped only when normalLatents is TRUE (where
      # beta is unused); with spike-and-slab latents it is replaced by ones. Confirm intent.
      if(opts$normalLatents & opts$ARDLatent=="shared" & projection.fixed){
        beta[[v]] <- outer(rep(1,N[v]),beta[[v]])
      }else{
        beta[[v]] <- matrix(1,N[v],K[v])
      }
    }
    
    i.pred <- c(0,0) # Number of iterations accumulated into the predictions, per modality
    if (opts$iter.saved>0 & opts$iter.burnin<opts$iter.max) {
      mod.saved <- ceiling((opts$iter.max-opts$iter.burnin)/opts$iter.saved) # Save every 'mod.saved'th Gibbs sample.
      S <- floor((opts$iter.max-opts$iter.burnin)/mod.saved) # Number of saved Gibbs samples.
    }
    
    if (projection.fixed) { # Projection given as an argument to the sampler (fixed)
      tmp <- WtW <- WtWdiag <- list()
      for(v in 1:2) {
        obs <- NULL # variable indices of the source (observed) views
        for (mi in 1:length(opts$prediction[[v]])) { # Go through all views.
          if (!opts$prediction[[v]][mi]) {
            obs <- c(obs, gr[[v]][[mi]])
          }
        }
        # The latent covariance only depends on the fixed W and tau, so compute it once.
        covZ[[v]] <- diag(1,K[v]) + crossprod(W[[v]][obs,]*sqrt(tau[[v]][obs]))
        eS <- eigen(covZ[[v]])
        tmp[[v]] <- eS$vectors*outer(id[[v]],1/sqrt(eS$values))
        covZ[[v]] <- tcrossprod(tmp[[v]])
        if(!opts$normalLatents){
          WtW[[v]] <- crossprod(W[[v]][obs,]*sqrt(tau[[v]][obs]))
          WtWdiag[[v]] <- diag(WtW[[v]])
          diag(WtW[[v]]) <- 0
          if (is.null(dim(rz[[v]]))) {
            if (length(rz[[v]])==K[v]) {
              rz[[v]] <- matrix(data=rz[[v]], nrow=N[v], ncol=K[v], byrow=TRUE)
            } else {
              stop(paste0("rz[[",v,"]] not of required length"))
            }
          }
        }
        if (any(opts$prediction[[v]])) {
          prediction[[v]] <- vector(mode="list", length=length(gr[[v]]))
          for (m in which(opts$prediction[[v]])) {
            prediction[[v]][[m]] <- matrix(0,N[v],length(gr[[v]][[m]]))
            Y[[v]][,gr[[v]][[m]]] <- 0 # Views to be predicted do not inform the latents
          }
        }
        const[v] <- 0
        for(m in which(!opts$prediction[[v]]))
          const[v] <- const[v] - N[v]*length(gr[[v]][[m]])*log(2*pi)/2
        b_tau[[v]] <- rep(beta_0t[[v]],unlist(lapply(gr[[v]],length)))
      }
      cost <- rep(sum(const),opts$iter.max)
      
    } else { #Non-fixed projections
      if (any(unlist(opts$prediction))) {
        stop("Prediction without projections given")
      }
      for(v in 1:2) {
        WtW <- matrix(0,K[v],K[v])
        WtWdiag <- rep(0,K[v])
        
        const[v] <- - N[v]*D[v]*log(2*pi)/2
        tau[[v]] <- rep(opts$init.tau,D[v]) # The mean noise precisions
        b_tau[[v]] <- rep(beta_0t[[v]],unlist(lapply(gr[[v]],length)))
        
        W[[v]] <- matrix(0,D[v],K[v])
        Z[[v]] <- matrix(1,D[v],K[v])
        covW[[v]] <- diag(1,K[v])
        
        alpha[[v]] <- matrix(1,D[v],K[v])
        
        r[[v]] <- matrix(0.5,D[v],K[v])
        rz[[v]] <- matrix(0.5,N[v],K[v])
      }
      # For storing the lower bounds; the shared view is counted in both modalities
      # above, so its constant is added back once.
      cost <- rep(sum(const) + N[1]*N[2]*log(2*pi)/2,opts$iter.max)
    }
    
    ##Missing Values
    missingValues <- FALSE
    for(v in 1:2) {
      na.inds[[v]] <- which(is.na(Y[[v]]))
      if(length(na.inds[[v]])>0 & !projection.fixed) {
        missingValues <- TRUE
        if(opts$verbose>0)
          print("Missing Values Detected, Prediction using EM type approximation")
        for(m in 1:M[v])
          alpha_0t[[v]][m] <- alpha_0t[[v]][m] + sum(!is.na(Y[[v]][,gr[[v]][[m]]]))/2
      } else {
        for(m in 1:M[v])
          alpha_0t[[v]][m] <- alpha_0t[[v]][m] + N[v]*length(gr[[v]][[m]])/2 #View-wise noise
      }
    }
    # Linear indices of the missing entries of the shared view: [[3]] within the
    # N[1] x N[2] block Y[[1]][,gr[[1]][[1]]] and [[4]] within the N[2] x N[1] block
    # Y[[2]][,gr[[2]][[1]]]. The shared view occupies the first columns, so these are
    # also valid linear indices into the full Y[[1]] and Y[[2]].
    na.inds[[3]] <- which(is.na(Y[[1]][,gr[[1]][[1]],drop=FALSE]))
    na.inds[[4]] <- which(is.na(Y[[2]][,gr[[2]][[1]],drop=FALSE]))
    #missingValues.InitIter 
    #0 - start prediction from mean of all values, and use updates of the model allways
    #between 2 and maxiter - use updates of the model in a single imputation setting for  iterations past the set value.
    #maxiter+1 - ignore missing values. never update with the model's single imputation. 
    if(!is.null(opts$iter.missing))
      missingValues.InitIter <- opts$iter.missing
    else
      missingValues.InitIter <- round(opts$iter.burnin/2)
    
    if(missingValues & opts$verbose>0 & !projection.fixed)
      print(paste("missingValues.InitIter=",missingValues.InitIter))
    ## Missing Values end
    
    #Which X and W posterior samples to save in case of convergenceCheck?
    if(opts$convergenceCheck & opts$iter.saved>=8 & !projection.fixed) {
      ps <- floor(opts$iter.saved/4)
      start <- 1:ps
      end <- (-ps+1):0+opts$iter.saved
    } else {
      start <- end <- c()
    }
    
    posterior <- NULL
    iter <- 1
  }
  
  ##
  ## The main loop
  ##
  
  for(iter in iter:opts$iter.max) {
    ## Sample the projections W.
    if (!projection.fixed) {
      if(iter == 1){
        for(v in 1:2) {
          covW[[v]] <- diag(1,K[v]) + opts$init.tau*XX[[v]]
          eS <- eigen(covW[[v]])
          tmp <- eS$vectors*outer(id[[v]],1/sqrt(eS$values))
          covW[[v]] <- tcrossprod(tmp)
          
          if(missingValues && (missingValues.InitIter < iter)) {
            ##Could give bad init when high missing values. therefore skip from here.
            Y[[v]][na.inds[[v]]] <- mean(Y[[v]][-na.inds[[v]]])
          }
          
          estimW <- matrix(0,D[v],K[v]) #equivalent of crossprod(Y[[v]],X[[v]])
          tmpY <- Y[[v]]
          if(v==2) #Some variance explained by the first mode already
            tmpY[,gr[[v]][[1]]] <- tmpY[,gr[[v]][[1]]] - t(tcrossprod(X[[ov[v]]],W[[ov[v]]][gr[[ov[v]]][[1]],]))
          for(k in 1:K[v]){
            if(missingValues && (missingValues.InitIter >= iter)) {
              tmpY[is.na(tmpY)] <- 0
              estimW[,k] <- crossprod(tmpY,X[[v]][,k])
            } else {
              estimW[,k] <- crossprod(tmpY,X[[v]][,k])
            }
          }
          W[[v]] <- estimW%*%covW[[v]]*opts$init.tau + matrix(rnorm(D[v]*K[v]),D[v],K[v])%*%tmp
        }
        
      }else{
        for(v in 1:2) {
          XXdiag <- diag(XX[[v]])
          diag(XX[[v]]) <- 0
          # Residual of the shared view after the other modality's contribution
          tmpY <- Y[[v]]
          tmpY[,gr[[v]][[1]]] <- tmpY[,gr[[v]][[1]]] - t(tcrossprod(X[[ov[v]]],W[[ov[v]]][gr[[ov[v]]][[1]],]))
          for(k in 1:K[v]){
            lambda <- tau[[v]]*XXdiag[k] + alpha[[v]][,k]
            if(missingValues && (missingValues.InitIter >= iter)) {
              ss <- tcrossprod(X[[v]][,-k],W[[v]][,-k])
              tmp <- tmpY-ss
              tmp[na.inds[[v]]] <- 0
              mu_sub <- crossprod(tmp,X[[v]][,k])
              mu <- mu_sub*tau[[v]]/lambda
            } else {
              mu <- tau[[v]]/lambda*as.vector( crossprod(tmpY,X[[v]][,k]) - W[[v]]%*%XX[[v]][k,])
            }
            if(iter > opts$sampleZ){
              if(opts$spikeW!="group"){
                zone <- 0.5*( log(alpha[[v]][,k]) - log(lambda) + lambda*mu^2) + log(r[[v]][,k]) - log(1-r[[v]][,k])
                zone <- 1/(1+exp(-zone))
                zone <- as.double(runif(D[v]) < zone)
                Z[[v]][,k] <- zone
              } else {
                # One indicator per view and component
                zone <- 0.5*(log(alpha[[v]][,k])-log(lambda) + lambda*mu^2)
                for(m in 1:M[v]){
                  logpr <- sum(zone[gr[[v]][[m]]]) + log(r[[v]][gr[[v]][[m]][1],k]) - log( 1-r[[v]][gr[[v]][[m]][1],k] )
                  logpr <- 1/(1+exp(-logpr))
                  Z[[v]][gr[[v]][[m]],k] <- as.double(runif(1)<logpr)
                }
              }
            }
            W[[v]][,k] <- mu + rnorm(D[v])*lambda^(-0.5)
          }
        }
      }
      for(v in 1:2)
        W[[v]] <- W[[v]]*Z[[v]]
    }
    
    ## Check if components can be removed.
    # If Gibbs samples are going to be saved, components are dropped only during the burn-in.
    # If the projection has been given as an argument (e.g. for prediction), components will not be dropped.
    if ((iter<=opts$iter.burnin | opts$iter.saved==0) & !projection.fixed) {
      for(v in 1:2) {
        keep <- which(colSums(Z[[v]])!=0 & colSums(abs(X[[v]]))>0)
        if(length(keep)==0) {
          print(paste0("All components shut down (mode ",v,"), returning a NULL model."))
          return(list())
        }
        if(length(keep)!=K[v]){
          K[v] <- length(keep)
          id[[v]] <- rep(1,K[v])
          alpha[[v]] <- alpha[[v]][,keep,drop=FALSE]
          X[[v]] <- X[[v]][,keep,drop=FALSE]
          W[[v]] <- W[[v]][,keep,drop=FALSE]
          Z[[v]] <- Z[[v]][,keep,drop=FALSE]
          r[[v]] <- r[[v]][,keep,drop=FALSE]
          beta[[v]] <- beta[[v]][,keep,drop=FALSE]
          rz[[v]] <- rz[[v]][,keep,drop=FALSE]
        }
      }
    }
    
    ## sample r
    for(v in 1:2) {
      if(iter>opts$sampleZ & !projection.fixed){
        if(opts$spikeW=="shared"){
          zm <- colSums(Z[[v]])
          for(k in 1:K[v])
            r[[v]][,k] <- (zm[k]+opts$prior.beta[1])/(D[v]+sum(opts$prior.beta))
        }
        if(opts$spikeW=="grouped"){
          for(m in 1:M[v]){
            zm <- colSums(Z[[v]][gr[[v]][[m]],,drop=FALSE])
            for(k in 1:K[v])
              r[[v]][gr[[v]][[m]],k] <- (zm[k]+opts$prior.beta[1])/(length(gr[[v]][[m]])+sum(opts$prior.beta))
          }
        }
        if(opts$spikeW=="group"){
          zm <- rep(0,K[v])
          for(m in 1:M[v]){
            zm <- zm + Z[[v]][gr[[v]][[m]][1],]
          }
          for(m in 1:M[v]){
            for(k in 1:K[v]){
              r[[v]][gr[[v]][[m]],k] <- (zm[k]+opts$prior.beta[1])/(M[v]+sum(opts$prior.beta))
            }
          }
        }
      }
      
      ## 
      ## Update the latent variables
      ##
      
      if(iter>opts$sampleZ & !opts$normalLatents) {
        if (!projection.fixed) {
          WtW <- crossprod(W[[v]]*sqrt(tau[[v]]))
          WtWdiag <- diag(WtW)
          diag(WtW) <- 0
        }
        for(k in 1:K[v]){
          if(projection.fixed)
            lambda <- WtWdiag[[v]][k] + beta[[v]][,k]
          else
            lambda <- WtWdiag[k] + beta[[v]][,k]
          tmpY <- Y[[v]]
          tmpY[,gr[[v]][[1]]] <- tmpY[,gr[[v]][[1]]] - t(tcrossprod(X[[ov[v]]],W[[ov[v]]][gr[[ov[v]]][[1]],]))
          if(missingValues && (missingValues.InitIter >= iter)) {
            tmpY[is.na(tmpY)] <- 0
          }
          if(projection.fixed)
            mu <- (tmpY%*%(W[[v]][,k]*tau[[v]]) - X[[v]]%*%WtW[[v]][k,])/lambda
          else
            mu <- (tmpY%*%(W[[v]][,k]*tau[[v]]) - X[[v]]%*%WtW[k,])/lambda
          zone <- 0.5*( log(beta[[v]][,k]) - log(lambda) + lambda*mu^2) + log(rz[[v]][,k]) - log(1-rz[[v]][,k])
          zone <- 1/(1+exp(-zone))
          zone <- as.double(runif(N[v]) < zone)
          X[[v]][,k] <- mu + rnorm(N[v])*lambda^(-0.5)
          X[[v]][,k] <- X[[v]][,k]*zone
          if (!projection.fixed) {
            zm <- sum(zone)
            rz[[v]][,k] <- (zm+opts$prior.betaX[1])/(N[v]+sum(opts$prior.betaX))
          }
        }
      } else {
        if (!projection.fixed) {
          covZ[[v]] <- diag(1,K[v]) + crossprod(W[[v]]*sqrt(tau[[v]]))
          eS <- eigen(covZ[[v]])
          tmp <- eS$vectors*outer(id[[v]],1/sqrt(eS$values))
          covZ[[v]] <- tcrossprod(tmp)
        }
        tmpY <- Y[[v]]
        tmpY[,gr[[v]][[1]]] <- tmpY[,gr[[v]][[1]]] - t(tcrossprod(X[[ov[v]]],W[[ov[v]]][gr[[ov[v]]][[1]],]))
        if(missingValues && (missingValues.InitIter >= iter)) {
          tmpY[is.na(tmpY)] <- 0
          X[[v]] <- tmpY%*%(W[[v]]*tau[[v]])
        } else {
          X[[v]] <- tmpY%*%(W[[v]]*tau[[v]])
        }
        X[[v]] <- X[[v]]%*%covZ[[v]]
        if(projection.fixed)
          X[[v]] <- X[[v]] + matrix(rnorm(N[v]*K[v]),N[v],K[v])%*%tmp[[v]]
        else
          X[[v]] <- X[[v]] + matrix(rnorm(N[v]*K[v]),N[v],K[v])%*%tmp
      }
      XX[[v]] <- crossprod(X[[v]])
      
      ##
      ## Update alpha and beta, the ARD parameters
      ##
      
      if(!opts$normalLatents){
        if(opts$ARDLatent=="element"){
          beta[[v]] <- X[[v]]^2/2 + opts$prior.beta_0X
          keep <- which(beta[[v]]!=0)
          for(n in keep) beta[[v]][n] <- rgamma(1,shape=opts$prior.alpha_0X+0.5,rate=beta[[v]][n]) +1e-7
          beta[[v]][-keep] <- rgamma( length(beta[[v]][-keep]),shape=opts$prior.alpha_0X,rate=opts$prior.beta_0X) + 1e-7
        }
        if(opts$ARDLatent == "shared" & !projection.fixed){
          xx <- colSums(X[[v]]^2)/2 + opts$prior.beta_0X
          tmpxx <- colSums(X[[v]]!=0)/2 + opts$prior.alpha_0X
          for(k in 1:K[v])
            beta[[v]][,k] <- rgamma(1,shape=tmpxx[k],rate=xx[k]) + 1e-7
        }
      }
      if (!projection.fixed) {
        if(opts$ARDW=="shared"){
          ww <- colSums(W[[v]]^2)/2 + beta_0
          tmpz <- colSums(Z[[v]])/2 + alpha_0
          for(k in 1:K[v])
            alpha[[v]][,k] <- rgamma(1,shape=tmpz[k],rate=ww[k]) + 1e-7
        }
        if(opts$ARDW=="grouped"){
          for(m in 1:M[v]){
            ww <- colSums(W[[v]][gr[[v]][[m]],,drop=FALSE]^2)/2 + beta_0
            tmpz <- colSums(Z[[v]][gr[[v]][[m]],,drop=FALSE])/2 + alpha_0
            for(k in 1:K[v])
              alpha[[v]][gr[[v]][[m]],k] <- rgamma(1,shape=tmpz[k],rate=ww[k]) + 1e-7
          }
        }
        if(opts$ARDW=="element"){
          alpha[[v]] <- W[[v]]^2/2 + beta_0
          keep <- which(alpha[[v]]!=0)
          for(n in keep) alpha[[v]][n] <- rgamma(1,shape=alpha_0+0.5,rate=alpha[[v]][n]) +1e-7
          alpha[[v]][-keep] <- rgamma( length(alpha[[v]][-keep]),shape=alpha_0,rate=beta_0) + 1e-7
        }   
        
        ##
        ## Update tau, the noise precisions of the non-shared views
        ##
        
        XW <- tcrossprod(X[[v]],W[[v]])
        if(missingValues && (missingValues.InitIter > iter)) { #in the last iter of ignoring missing values, update Y[[v]][[m]] with the model updates
          XW <- Y[[v]] - XW
          XW[is.na(XW)] <- 0
          b_tau[[v]][-gr[[v]][[1]]] <- colSums(XW[,-gr[[v]][[1]],drop=FALSE]^2)/2
        } else{
          if(missingValues) {
            Y[[v]][na.inds[[v]]] <- XW[na.inds[[v]]]
          }
          b_tau[[v]][-gr[[v]][[1]]] <- colSums((Y[[v]][,-gr[[v]][[1]],drop=FALSE]-XW[,-gr[[v]][[1]],drop=FALSE])^2)/2
        }
        
        if(opts$tauGrouped){
          # seq_len(M[v])[-1] is empty when the modality only has the shared view
          for(m in seq_len(M[v])[-1]){
            tau[[v]][gr[[v]][[m]]] <- rgamma(1, shape = alpha_0t[[v]][m],
                                             rate = beta_0t[[v]][m] + sum(b_tau[[v]][gr[[v]][[m]]]))
          }
        }else{
          stop("Tau needs to be grouped: otherwise the noise prior for the shared view is ill-defined")
        }
        # calculate likelihood.
        cost[iter] <- cost[iter] + N[v]*sum(log(tau[[v]][-gr[[v]][[1]]]))/2 - crossprod(b_tau[[v]][-gr[[v]][[1]]],tau[[v]][-gr[[v]][[1]]])
      }
      
    } #Loop over v ends
    #Update tau for the paired (shared) view. Its reconstruction is the sum of the
    #contributions of both modalities: X1 W1s^T + t(X2 W2s^T).
    if(!projection.fixed) {
      XW <- tcrossprod(X[[1]],W[[1]][gr[[1]][[1]],]) + t(tcrossprod(X[[2]],W[[2]][gr[[2]][[1]],]))
      if(missingValues && (missingValues.InitIter > iter)) {
        XW <- Y[[1]][,gr[[1]][[1]]] - XW
        XW[is.na(XW)] <- 0
        b_tau[[1]][gr[[1]][[1]]] <- colSums(XW[,gr[[1]][[1]],drop=FALSE]^2)/2
      } else {
        if(missingValues) {
          Y[[1]][na.inds[[3]]] <- XW[na.inds[[3]]] #Missing values in the shared view
          Y[[2]][na.inds[[4]]] <- t(Y[[1]][,gr[[1]][[1]],drop=FALSE])[na.inds[[4]]] #Copied (transposed)
        }
        b_tau[[1]][gr[[1]][[1]]] <- colSums((Y[[1]][,gr[[1]][[1]],drop=FALSE]-XW[,gr[[1]][[1]],drop=FALSE])^2)/2
      }
      
      if(opts$tauGrouped){
        tau[[1]][gr[[1]][[1]]] <- rgamma(1, shape = alpha_0t[[1]][1],
                                         rate = beta_0t[[1]][1] + sum(b_tau[[1]][gr[[1]][[1]]]))
        tau[[2]][gr[[2]][[1]]] <- tau[[1]][gr[[1]][[1]][1]]
      }else{
        stop("Tau needs to be grouped: otherwise the noise prior for the shared view is ill-defined")
      }
      cost[iter] <- cost[iter] + N[1]*sum(log(tau[[1]][gr[[1]][[1]]]))/2 - crossprod(b_tau[[1]][gr[[1]][[1]]],tau[[1]][gr[[1]][[1]]])
    }
    
    
    if(filename != "" & iter%%100==0) { #Every 100 iterations, save the sampling
      time <- proc.time()[1]-ptm
      save(list=ls(),file=filename)
      if(opts$verbose>0)
        print(paste0("Iter ", iter, ", saved chain to '",filename,"'"))
    }
    
    # Calculate likelihood of observed views
    if(projection.fixed) {
      for(v in 1:2) {
        view.len <- unlist(lapply(gr[[v]],length))
        # Per-feature residual sum of squares, halved
        tautmp <- ( Yconst[[v]] + rowSums((W[[v]]%*%XX[[v]])*W[[v]]) - 2*rowSums(crossprod(Y[[v]],X[[v]])*W[[v]]) )/2
        # Per-feature posterior mean of the noise precision: alpha_0t[[v]] already contains
        # the view-wise term N*length/2, replaced here by the per-feature term N/2; the
        # rate adds the per-feature prior rate (beta_0t is a list over modalities).
        tautmp <- rep(alpha_0t[[v]] - N[v]*view.len/2 + N[v]/2, view.len)/(rep(beta_0t[[v]],view.len)+tautmp)
        for(m in which(!opts$prediction[[v]]))
          cost[iter] <- cost[iter] + N[v]*sum(log(tautmp[gr[[v]][[m]]]))/2 - crossprod(b_tau[[v]][gr[[v]][[m]]],tautmp[gr[[v]][[m]]])
      }
    }
    
    ##
    ## Prediction and collection of Gibbs samples
    ##
    
    for(v in 1:2) {
      if (any(opts$prediction[[v]])) { ## Prediction
        if (iter%%10==0 & opts$verbose>1) {
          print(paste0("Predicting: ",iter,"/",opts$iter.max))
        }
        if (iter>opts$iter.burnin & ((iter-opts$iter.burnin)%%mod.saved)==0) {
          i.pred[v] <- i.pred[v]+1
          for (m in which(opts$prediction[[v]])) { # Go through the views that will be predicted.
            prediction[[v]][[m]] <- prediction[[v]][[m]] + tcrossprod(X[[v]], W[[v]][gr[[v]][[m]],])
          }
        }
        
      } else if(!any(unlist(opts$prediction))) {
        if (iter%%10==0 & opts$verbose>0 & v==1) {
          print(paste0("Learning: ",iter,"/",opts$iter.max," - K=",paste0(K,collapse=",")," - ",Sys.time()))
        }
        if (opts$iter.saved>0) { ## Collection of Gibbs samples
          if (iter>opts$iter.burnin & ((iter-opts$iter.burnin)%%mod.saved)==0) { ## Save the Gibbs sample.
            if (iter-opts$iter.burnin==mod.saved) { # Initialize the list containing saved Gibbs samples.
              if(v==1)
                posterior <- list(r=list(),rz=list(),beta=list()) # List containing saved Gibbs samples
              if ((opts$save.posterior$W | opts$convergenceCheck) & !any(opts$prediction[[v]])) {
                if(v==1) {posterior$W <- list()}
                posterior$W[[v]] <- array(dim=c(S, dim(W[[v]]))) # SxDxK[v]
                colnames(posterior$W[[v]]) <- colnames(Y[[v]])
              }
              posterior$rz[[v]] <- array(dim=c(S, ncol(rz[[v]]))) # SxK[v] - Same 'rz' for all samples. Thus, save only a vector per iteration.
              posterior$r[[v]] <- array(dim=c(S, M[v],K[v])) # SxM[v]xK[v] - Same 'r' for all variables of a view. Thus, save only a views x components matrix per iteration.
              posterior$tau[[v]] <- array(dim=c(S, length(tau[[v]]))) # SxD
              colnames(posterior$tau[[v]]) <- colnames(Y[[v]])
              if ((opts$save.posterior$X | opts$convergenceCheck) & !any(opts$prediction[[v]])) {
                if(v==1) {posterior$X <- list()}
                posterior$X[[v]] <- array(dim=c(S, dim(X[[v]]))) # SxN[v]xK[v]
              }
              posterior$beta[[v]] <- array(dim=c(S,dim(beta[[v]])))
              if(v==1) {gr.start <- list()}
              gr.start[[v]] <- vector(mode="integer", length=length(gr[[v]]))
              for (m in 1:length(gr[[v]])) { # Find the first index of each view. (Used for saving 'r'.)
                gr.start[[v]][m] <- gr[[v]][[m]][1]
              }
            }
            
            s <- (iter-opts$iter.burnin)/mod.saved
            if (!projection.fixed) {
              if (opts$save.posterior$W | (opts$convergenceCheck & s%in%c(start,end)))
                posterior$W[[v]][s,,] <- W[[v]]
              posterior$tau[[v]][s,] <- tau[[v]]
              posterior$rz[[v]][s,] <- rz[[v]][1,] # In the current model, 'rz' is identical for all samples (rows are identical). Thus, save only a vector.
              posterior$r[[v]][s,,] <- r[[v]][gr.start[[v]],] # In the current model, 'r' is identical for all variables within a view. Thus, save only a views x components matrix.
              posterior$beta[[v]][s,,] <- beta[[v]]
            }
            if (opts$save.posterior$X | (opts$convergenceCheck & s%in%c(start,end)))
              posterior$X[[v]][s,,] <- X[[v]]
          }
        }
      }
    }
  } ## The main loop of the algorithm ends.
  
  if(opts$convergenceCheck & opts$iter.saved>=8 & !projection.fixed) {
    #Estimate the convergence of the data reconstruction, based on the Geweke diagnostic
    if(opts$verbose>0)
      print("Starting convergence check")
    conv <- 0
    for(v in 1:2) {
      for(i in 1:N[v]) {
        if(opts$verbose>1) {
          if(i%%10==0)
            cat(".")
          if(i%%100==0)
            print(paste0(i,"/",N[v]))
        }
        y <- matrix(NA,0,ncol(Y[[v]]))
        for(ps in c(start,end)) {
          y <- rbind(y, tcrossprod(posterior$X[[v]][ps,i,],posterior$W[[v]][ps,,]))
        }
        foo <- rep(NA,ncol(Y[[v]]))
        for(j in 1:ncol(Y[[v]])) {
          if(sd(y[start,j])>1e-10 & sd(y[-start,j])>1e-10) {
            foo[j] <- t.test(y[start,j],y[-start,j])$p.value
          } else { #Essentially constant reconstruction
            if(abs(mean(y[start,j])-mean(y[-start,j]))<1e-10)
              foo[j] <- 1 #Same constant
            else
              foo[j] <- 0 #Different constant
          }
        }
        conv <- conv + sum(foo<0.05)/length(foo)/N[v]/2 #check how many values are below 0.05
      }
      # Free samples stored only for the check. Assigning list(NULL) keeps the position of
      # the other modality's samples (assigning NULL would shift them).
      if(!opts$save.posterior$X)
        posterior$X[v] <- list(NULL)
      if(!opts$save.posterior$W)
        posterior$W[v] <- list(NULL)
      gc()
    }
    
  } else {
    conv <- NA
  }
  
  if(filename!="" && opts$iter.max>=10 && file.exists(filename))
    file.remove(filename) #No need for temporary storage any more
  
  ## Return the output of the model as a list
  if (any(unlist(opts$prediction))) {
    for(v in 1:2) {
      if(any(opts$prediction[[v]])) {
        for (m in which(opts$prediction[[v]])) # Go through the views that will be predicted.
          prediction[[v]][[m]] <- prediction[[v]][[m]]/i.pred[v]
        prediction[[v]]$cost <- cost
      }
    }
    return(prediction)
  } else {
    for(v in 1:2) {
      d1 <- unlist(lapply(gr[[v]],function(x){x[1]}))
      if(opts$ARDW=="grouped")
        alpha[[v]] <- alpha[[v]][d1,]
      if(opts$spikeW=="group")
        Z[[v]] <- Z[[v]][d1,]
    }
    time <- proc.time()[1] - ptm
    
    return(list(W=W, X=X, Z=Z, r=r, rz=rz, tau=tau, alpha=alpha, beta=beta, groups=gr, D=D, K=K,
                cost=cost, posterior=posterior, opts=opts, conv=conv, time=time))
  }
  
}


predictGibbsGFA <- function(Y, model, opts) {
  
  ## Predict a set of data views given the GFA model (Gibbs samples) and another set of data views.
  ##
  ## Tommi Suvitaival
  ## 27.9.2013
  ##
  ## Arguments
  ##
  ## Y: Data views in a list of two modalities, as for gibbsgfa(). The views to be predicted
  ##    are only used as the ground truth for the MSE.
  ## model: Object from the function 'gibbsgfa'; model$posterior must contain the saved
  ##        samples of W, beta, tau and rz.
  ## opts: Options passed on to gibbsgfa(). opts$prediction[[v]] is a logical vector with the
  ##       length matching the length of Y[[v]], describing which data views of modality v
  ##       will be predicted (TRUE). Each modality needs at least one observed view.
  ##
  ## Returns a list over the two modalities, where Y.prediction[[v]][[m]] is the prediction of
  ## view m averaged over the saved Gibbs samples. Y.prediction[[v]]$mse and $cost contain the
  ## per-sample MSE and cost traces (NOTE: the lists for both modalities are attached to each).
  
  
  ##
  ## Initialization
  ##
  N <- rep(NA,2)
  Y.true <- Y.prediction <- mse <- cost <- W <- beta <- tau <- rz <- list()
  for(v in 1:2) {
    observed <- which(!opts$prediction[[v]])[1] # First observed view of modality v
    N[v] <- nrow(Y[[v]][[observed]]) # Number of samples in the prediction set
    
    Y.true[[v]] <- vector("list",length=length(Y[[v]])) #Store the data to be predicted
    
    # One slot per view of this modality (model$groups[[v]] lists the views of modality v)
    Y.prediction[[v]] <- vector(mode="list", length=length(model$groups[[v]]))
    names(Y.prediction[[v]]) <- names(Y[[v]])
    for (mi in which(opts$prediction[[v]])) { # Go through all views that will be predicted.
      Y.prediction[[v]][[mi]] <- array(data=0, dim=c(N[v], length(model$groups[[v]][[mi]]))) # Initialize the prediction matrix for view 'mi'.
      rownames(Y.prediction[[v]][[mi]]) <- rownames(Y[[v]][[observed]]) # Names of the prediction samples
      colnames(Y.prediction[[v]][[mi]]) <- names(model$groups[[v]][[mi]]) # Names of the variables in view 'mi', which will be predicted.
      Y.true[[v]][[mi]] <- Y[[v]][[mi]]
      Y[[v]][[mi]] <- array(dim=dim(Y.prediction[[v]][[mi]]), dimnames=dimnames(Y.prediction[[v]][[mi]])) # Initialize view 'mi' as missing data. These will be predicted.
    }
    
    ##
    ## Prediction
    ##
    mse[[v]] <- matrix(NA,length(which(opts$prediction[[v]])),nrow(model$posterior$W[[v]]))
    cost[[v]] <- matrix(NA,opts$iter.max,nrow(model$posterior$W[[v]]))
  }
  
  for (ni in 1:nrow(model$posterior$W[[1]])) { # Go through all saved Gibbs samples.
    if(opts$verbose>0)
      print(paste0("Predicting, Gibbs sample ",ni))
    for(v in 1:2) {
      W[[v]] <- matrix(model$posterior$W[[v]][ni,,],nrow(model$W[[v]]),ncol(model$W[[v]]))
      beta[[v]] <- matrix(model$posterior$beta[[v]][ni,,],nrow(model$beta[[v]]),ncol(model$beta[[v]]))
      tau[[v]] <- model$posterior$tau[[v]][ni,]
      rz[[v]] <- model$posterior$rz[[v]][ni,]
    }
    prediction.ni <- gibbsgfa(Y=Y, K=NULL, opts=opts, W=W, rz=rz, tau=tau, beta=beta)
    for(v in 1:2) {
      if(any(opts$prediction[[v]])) {
        for (mi in which(opts$prediction[[v]])) { # Go through the target views.
          Y.prediction[[v]][[mi]] <- Y.prediction[[v]][[mi]] + prediction.ni[[v]][[mi]]/nrow(model$posterior$W[[v]])
          # Rescale the running sum to the mean over the first 'ni' samples
          const <- nrow(model$posterior$W[[v]])/ni
          mse[[v]][which(mi==which(opts$prediction[[v]])),ni] <- mean((Y.true[[v]][[mi]]-Y.prediction[[v]][[mi]]*const)^2)
          cost[[v]][,ni] <- prediction.ni[[v]]$cost
          if(opts$verbose>1)
            print(paste0("MSE at iteration ",ni,": ",mse[[v]][which(mi==which(opts$prediction[[v]])),ni]))
        }
      }
    }
  }
  for(v in 1:2) {
    Y.prediction[[v]]$mse <- mse
    Y.prediction[[v]]$cost <- cost
  }
  
  return(Y.prediction)
  
}


informativeNoisePrior <- function(Y,prop.to.be.explained.by.noise,conf,opts) {
  # Sets view-wise parameters for an informative noise prior, such that the mean of the
  # noise prior is equal to the proportion of variance we want to be explained by noise.
  # The confidence parameter sets the width of the prior distribution.
  
  # Inputs:
  #   Y:       list of two dataset lists (Y[[v]][[m]] is a matrix with samples on rows).
  #            Data normalization does not matter. NA entries are ignored in the variance,
  #            but every entry counts towards the shape parameter.
  #   prop.to.be.explained.by.noise: proportion of total variance of each view to be explained by noise.
  #                                  Valid Values: 0.01 -> 0.99. Suggested -> 0.5
  #   conf:    width of the distribution. Valid values: 1e-2 -> 100. Suggested -> 1
  #   opts:    options list; opts$tauGrouped must be TRUE.
  
  # Outputs:
  #   a list containing 2 lists (one numeric vector per modality, one entry per view):
  #   prior.alpha_0t is the alpha_0t parameter of the noise hyperprior for each view
  #   prior.beta_0t is the beta_0t parameter of the noise hyperprior for each view
  
  # Derivation
  # alpha = conf*alpha_GFA (alpha_GFA in GFA derivations)
  # beta = conf*amount_of_var 
  # For gamma mu = a/b
  # In model a and b have opposite role so, amount_of_var = b/a, i.e. b = a*amount_of_var, b = conf*total_amount_of_var
  
  if(!isTRUE(opts$tauGrouped))
    stop("Tau needs to be grouped.")
  
  prior.alpha_0t <- prior.beta_0t <- list()
  for(v in 1:2) {
    views <- Y[[v]]
    N <- nrow(views[[1]])
    prior.alpha_0t[[v]] <- vapply(seq_along(views), function(m) {
      conf*N*ncol(views[[m]])/2
    }, numeric(1))
    prior.beta_0t[[v]] <- vapply(seq_along(views), function(m) {
      total.variance.in.view <- sum(views[[m]]^2,na.rm=TRUE)/2
      sigma.hat <- prop.to.be.explained.by.noise * total.variance.in.view
      conf*sigma.hat
    }, numeric(1))
  }
  list(prior.alpha_0t=prior.alpha_0t,prior.beta_0t=prior.beta_0t)
}

getDefaultOpts <- function() {
  #
  # Builds the default option list for gibbsgfa().
  #
  # Typical use, with K a vector holding the number of components of each modality:
  #   opts <- getDefaultOpts()
  #   model <- gibbsgfa(Y, K, opts)
  #
  # Options read by gibbsgfa() that are NOT set here:
  #   iter.missing   : iteration from which missing values are imputed by the model
  #                    (round(iter.burnin/2) when absent)
  #   prediction     : list of logical vectors flagging the views to be predicted
  #   save.posterior : list(W=, X=) of logicals; needed when iter.saved > 0
  #
  n.iter <- 5000
  
  list(
    # Initial value for the noise precisions. Should be large enough so that the real
    # structure is modeled with components instead of the noise parameters
    # (see Luttinen&Ilin, 2010). Positive, generally well above 1.
    init.tau = 10^3,
    
    # Sampling length: iter.max iterations in total, the first iter.burnin being
    # burn-in; after it, about iter.saved posterior samples are stored (0 = none).
    iter.max = n.iter,
    iter.saved = 0,
    iter.burnin = floor(n.iter/2),
    
    # Gamma hyperparameters: alpha_0/beta_0 for the ARD precisions of W,
    # alpha_0t/beta_0t for the residual noise precisions.
    prior.alpha_0 = 1,
    prior.beta_0 = 1,
    prior.alpha_0t = 1,
    prior.beta_0t = 1,
    
    # Prior counts (ones, zeros) for the spike-and-slab indicators of W and of X.
    prior.beta = c(1,1),
    prior.betaX = c(1,1),
    
    # Spike-and-slab prior for the loadings W:
    # - "group":   one binary indicator per factor and data set (group sparsity)
    # - "grouped": an indicator per element, with a hyperparameter shared within each
    #              data set and factor (indirectly promotes group sparsity)
    # - "shared":  no group structure
    spikeW = "group",
    
    # Iteration after which the spike-and-slab indicators are sampled.
    sampleZ = 1,
    
    # TRUE: standard Gaussian latent variables; FALSE: spike-and-slab latent variables.
    normalLatents = TRUE,
    
    # ARD precisions of non-Gaussian latent variables: "shared" or "element".
    ARDLatent = "shared",
    
    # Hyperparameters of the latent-variable precisions; conservative values are
    # advisable because of the scaling.
    prior.alpha_0X = 1,
    prior.beta_0X = 1,
    
    # ARD prior for the loadings W:
    # - "grouped": shared precision for each factor and group
    # - "element": separate precision for each element of W
    # - "shared":  no group structure
    ARDW = "grouped",
    
    # TRUE: each view shares one noise precision (required by gibbsgfa());
    # FALSE: a separate precision per feature.
    tauGrouped = TRUE,
    
    # Printing: 0 = silent, >0 progress messages, >1 more detail.
    verbose = 1,
    
    # Geweke-style check of the data reconstruction after sampling. When TRUE, posterior
    # samples of X and W are stored, which raises memory requirements.
    convergenceCheck = FALSE
  )
}
