#
# Relaxed multi-tensor factorization (rMTF).
# Author: Eemeli Leppäaho, eemeli.leppaaho@aalto.fi
# Based on the GFA scripts by Seppo Virtanen and Tommi Suvitaival
#

rMTF <- function(Y, K, opts, sharing, W=NULL, V=NULL, U=NULL, rz=NULL, tau=NULL, beta=NULL, filename="",init=NULL) {
  
  #
  # The main function for relaxed multi-tensor factorization (Gibbs sampler,
  # built on Bayesian group factor analysis)
  #
  # Inputs:
  #   Y    : List of M data matrices. Y[[m]] is a matrix with
  #          N rows (samples) and D_m columns (features). The
  #          samples need to be co-occurring.
  #          NOTE: All of these should be centered, so that the mean
  #                of each feature is zero
  #          NOTE: The algorithm is roughly invariant to the scale
  #                of the data, but extreme values should be avoided.
  #                Data with roughly unit variance or similar scale
  #                is recommended.
  #          NOTE: Tensor slabs should be separate elements of Y
  #   K    : The number of components (replaced by ncol(W) when a fixed
  #          projection is given)
  #   opts : List of options (see function getDefaultOpts())
  # sharing: Defines the data structure of the matrices and tensors.
  #          Should be a vector of length(Y); 0s denote matrices and
  #          other numbers tensors (views with the same number are the
  #          slabs of one tensor).
  #          Example: 1 matrix and 2 tensors with depth 3
  #                   sharing <- c(0, 1, 1, 1, 2, 2, 2)
  #   W, rz, tau, beta: Optional. When all four are given, the projection is
  #          kept fixed and only the latent variables are sampled; this is
  #          the prediction mode used by predictrMTF() for the views
  #          flagged in opts$prediction.
  #   V, U : Accepted for interface compatibility; not read by the current
  #          implementation (both are re-initialized when learning).
  # filename: Optional checkpoint file. If it exists, the saved sampler
  #          state is loaded and sampling continues from it. The state is
  #          saved every 100 iterations and the file is removed at the end
  #          (when opts$iter.max >= 10).
  #   init : Optional list; init$X initializes the first ncol(init$X)
  #          columns of the latent variables.
  #
  # Output:
  # The trained model, which is a list that contains the final Gibbs samples of:
  #   W    : Projections for each data view
  #   V    : The mean projections of a tensor
  #   U    : The third factorization matrix
  #   X    : The first factorization matrix (latent variables)
  #   Z    : The spike and slab structure of W
  #   alpha: The scale (ARD) variable for W
  #   beta : The scale (ARD) variable for X
  #alpha_sh: Precision lambda controlling the level of bilinearity/trilinearity
  #   K    : The inferred number of latent components (<= starting K)
  #posterior: posterior samples of saved parameters (see function getDefaultOpts())
  # In prediction mode (any(opts$prediction)) the output is instead a list whose
  # elements for the predicted views hold the averaged predictions, plus $cost.
  # An empty list is returned if all components are shut down.
  
  if(file.exists(filename)) { #Part of the sampling done already
    load(filename)
    
  } else { #Else initialization
    # Store dimensionalities of data sets.
    M <- length(Y)
    
    #Sharing of W's
    if(length(min(sharing):max(sharing))!=length(unique(sharing)))
      stop("Sharing ill-defined")
    independent <- which(sharing==0)
    foo <- table(sharing)
    MM <- length(foo) #Different tensors
    if(any(sharing==0)) {
      MM <- MM-1
      foo <- foo[-which(names(foo)==0)]
    }
    shared <- vector("list",length=MM)
    Dsh <- rep(NA,MM)
    for(mm in 1:MM) {
      shared[[mm]] <- which(sharing==names(foo)[mm])
      Dsh[mm] <- ncol(Y[[shared[[mm]][1]]])
    }
    
    D <- sapply(Y,ncol)
    Ds <- c(1,cumsum(D)+1) ; De <- cumsum(D)
    gr <- vector("list")
    for(m in 1:M) {
      gr[[m]] <- Ds[m]:De[m]
      if (!is.null(colnames(Y[[m]]))) {
        names(gr[[m]]) <- colnames(Y[[m]])
      }
    }
    # Column indices (in the concatenated data) of all slabs of each tensor
    Dshared <- vector("list",length=MM)
    for(mm in 1:MM) {
      Dshared[[mm]] <- vector()
      for(m in shared[[mm]])
        Dshared[[mm]] <- c(Dshared[[mm]],gr[[m]])
    }
    
    Y <- do.call(cbind,Y)
    D <- ncol(Y)
    N <- nrow(Y)
    
    
    
    alpha_0 <- opts$prior.alpha_0
    beta_0 <- opts$prior.beta_0
    alpha_0t <- opts$prior.alpha_0t
    beta_0t <- opts$prior.beta_0t
    alpha_0g <- opts$prior.alpha_0g
    beta_0g <- opts$prior.beta_0g
    
    if (!is.null(W) & !is.null(rz) & !is.null(tau) & !is.null(beta)) { # Projection given as an argument.
      projection.fixed = TRUE
      K = ncol(W)
    } else {
      projection.fixed = FALSE
    }
    
    # Some constants for speeding up the computation
    #const <- - N*Ds/2*log(2*pi) # Constant factors for the lower bound
    Yconst <- colSums(Y^2) #unlist(lapply(Y,function(x){sum(x^2)}))
    id <- rep(1,K)              # Vector of ones for fast matrix calculations
    
    ##
    ## Initialize the model randomly
    ##
    # Other initializations could be done, but overdispersed random initialization is quite good.
    
    # Latent variables
    X <- matrix(rnorm(N*K,0,1),N,K)
    X <- scale(X)
    if(!is.null(init))
      X[,1:ncol(init$X)] <- init$X
    covZ <- diag(1,K) # Covariance
    XX <- crossprod(X) # Second moments
    
    if(!projection.fixed) {
      # NOTE(review): the inner condition requires projection.fixed, which is
      # FALSE in this branch, so beta is always reset to ones here.
      if(opts$normalLatents & opts$ARDLatent=="shared" & projection.fixed){
        beta <- outer(rep(1,N),beta)
      }else{
        beta <- matrix(1,N,K)
      }
    }
    
    tmp <- matrix(0,K,K)
    
    i.pred <- 0
    if (opts$iter.saved>0 & opts$iter.burnin<opts$iter.max) {
      mod.saved <- ceiling((opts$iter.max-opts$iter.burnin)/opts$iter.saved) # Save every 'mod.saved'th Gibbs sample.
      S <- floor((opts$iter.max-opts$iter.burnin)/mod.saved) # Number of saved Gibbs samples.
    }
    
    if (projection.fixed) {
      obs = NULL # variable indices of the source views
      for (mi in 1:length(opts$prediction)) { # Go through all views.
        if (!opts$prediction[mi]) {
          obs = c(obs, gr[[mi]])
        }
      }
      if(!opts$normalLatents){
        WtW <- crossprod(W[obs,]*sqrt(tau[obs]))
        WtWdiag <- diag(WtW)
        diag(WtW) <- 0
        if (is.null(dim(rz))) {
          if (length(rz)==K) {
            rz = matrix(data=rz, nrow=N, ncol=K, byrow=TRUE)
          } else {
            # Fail early: a non-matrix rz would break the rz[,k] indexing below.
            stop("rz not of required length", call.=FALSE)
          }
        }
      }else{
        covZ <- diag(1,K) + crossprod(W[obs,]*sqrt(tau[obs]))
        eS <- eigen(covZ)
        tmp <- eS$vectors*outer(id,1/sqrt(eS$values))
        covZ <- tcrossprod(tmp)
      }
      if (any(opts$prediction)) {
        prediction = vector(mode="list", length=length(gr))
        Y.true = Y
        for (m in which(opts$prediction)) {
          prediction[[m]] = matrix(0,N,length(gr[[m]]))
          Y[,gr[[m]]] <- 0
        }
      }
      
      cost <- rep(NA,opts$iter.max)
      const <- 0
      for(m in which(opts$prediction==F))
        const <- const - N*length(gr[[m]])*log(2*pi)/2
      a_tau <- alpha_0t + N/2 # The parameters of the Gamma distribution
      b_tau <- rep(beta_0t,D) # for the noise precisions
      
    } else {
      if (any(opts$prediction)) {
        # Prediction needs W, rz, tau and beta; without them 'prediction' is
        # never allocated, so fail here instead of inside the sampling loop.
        stop("Prediction without projections given", call.=FALSE)
      }
      WtW <- matrix(0,K,K)
      WtWdiag <- rep(0,K)
      
      const <- - N*D*log(2*pi)/2
      tau <- rep(opts$init.tau,D) # The mean noise precisions
      a_tau <- alpha_0t + N/2 # The parameters of the Gamma distribution
      b_tau <- rep(beta_0t,D) # for the noise precisions
      
      W <- matrix(0,D,K)
      V <- gamma <- vector("list",length=MM)
      for(mm in 1:MM) {
        V[[mm]] <- matrix(0,Dsh[mm],K)
        gamma[[mm]] <- matrix(1,Dsh[mm],K)
      }
      Z <- matrix(1,D,K)
      U <- matrix(NA,M,K) # Rows of matrix (non-tensor) views stay NA
      covW <- diag(1,K)
      
      alpha <- matrix(1,D,K)
      
      cost <- vector()  # For storing the lower bounds
      lambda <- mu <- zone <- rep(1,D)
      r <- matrix(0.5,D,K)
      rz <- matrix(0.5,N,K)
    }
    
    ##Missing Values
    missingValues <- FALSE
    na.inds <- which(is.na(Y))
    if(length(na.inds)>0 & !projection.fixed) {
      missingValues <- TRUE
      if(opts$verbose>0)
        print("Missing Values Detected, Prediction using EM type approximation")
    }
    #missingValues.InitIter 
    #0 - start prediction from mean of all values, and use updates of the model always
    #between 2 and maxiter - use updates of the model in a single imputation setting for iterations past the set value.
    #maxiter+1 - ignore missing values. never update with the model's single imputation. 
    if(!is.null(opts$iter.missing))
      missingValues.InitIter <- opts$iter.missing
    else
      missingValues.InitIter <- round(opts$iter.burnin/2)
    
    if(missingValues & opts$verbose>0 & !projection.fixed)
      print(paste("missingValues.InitIter=",missingValues.InitIter))
    ## Missing Values end
    
    posterior <- NULL
    iter <- 1
  }
  
  ##
  ## The main loop
  ##
  for(iter in iter:opts$iter.max) {    
    if (!projection.fixed) {
      if(iter == 1){
        covW <- diag(1,K) + opts$init.tau*XX
        eS <- eigen(covW)
        tmp <- eS$vectors*outer(id,1/sqrt(eS$values))
        covW <- tcrossprod(tmp)
        
        if(missingValues && (missingValues.InitIter < iter)) {
          ##Could give bad init when high missing values. therefore skip from here.
          Y[na.inds] <- mean(Y[-na.inds])
        }
        estimW = matrix(0,D,K) #equivalent of crossprod(Y,X)
        for(k in 1:K){
          if(missingValues && (missingValues.InitIter >= iter)) {
            tmpY <- Y
            tmpY[is.na(tmpY)] <- 0
            estimW[,k] <- crossprod(tmpY,X[,k])
          } else {
            estimW[,k] <- crossprod(Y,X[,k])
          }
        }
        W <- estimW%*%covW*opts$init.tau + matrix(rnorm(D*K),D,K)%*%tmp
        
        #Initialize V
        for(mm in 1:MM) {
          for(m in shared[[mm]])
            V[[mm]] <- V[[mm]] + W[gr[[m]],]
          V[[mm]] <- V[[mm]]/length(shared[[mm]])
        }
        for(mm in 1:MM) {
          for(m in shared[[mm]])
            W[gr[[m]],] <- V[[mm]] + matrix(rnorm(Dsh[mm]*K,0,1),Dsh[mm],K) #Init with fixed 0.1 lambda
        }
        
        for(mm in 1:MM)
          U[shared[[mm]],] <- 1 #Random initialization could be tried as well
        
      }else{
        #Update U
        # Conditional Gaussian for U[i,]: prior precision 1 plus the precision
        # contributed by every active loading of the corresponding slab of each
        # tensor. alpha is indexed by the slab's feature rows gr[[m]] (not by
        # the view index m), matching the V update below.
        for(i in shared[[1]]) {
          if(!all(is.na(U[i,]))) {
            mu <- 0
            lambda <- 1
            for(mm in 1:MM) {
              m <- shared[[mm]][which(shared[[1]]==i)]
              mu <- mu + colSums(alpha[gr[[m]],]*Z[gr[[m]],]*V[[mm]]*W[gr[[m]],])
              lambda <- lambda + colSums(alpha[gr[[m]],]*Z[gr[[m]],]*V[[mm]]^2)
            }
            U[i,] <- mu/lambda + rnorm(K)*lambda^(-0.5)
          }
        }
        # Slabs at the same depth of the other tensors share U with tensor 1
        if(MM>1) {
          for(mm in 2:MM) {
            for(m in shared[[mm]])
              U[m,] <- U[shared[[1]][which(shared[[mm]]==m)],]
          }
        }
        
        #Sampling V
        for(mm in 1:MM) {
          V[[mm]] <- V[[mm]]*0
          sigma <- gamma[[mm]]
          for(m in shared[[mm]]) {
            V[[mm]] <- V[[mm]] + sweep(alpha[gr[[m]],]*Z[gr[[m]],]*W[gr[[m]],],MARGIN=2,U[m,],'*')
            sigma <- sigma + sweep(Z[gr[[m]],]*alpha[gr[[m]],],MARGIN=2,U[m,]^2,'*')
          }
          V[[mm]] <- V[[mm]]/sigma
          V[[mm]] <- V[[mm]] + matrix(rnorm(Dsh[mm]*K),Dsh[mm],K)*sigma^(-0.5)
        }
        
        XXdiag <- diag(XX)
        diag(XX) <- 0
        for(k in 1:K){
          lambda <- tau*XXdiag[k] + alpha[,k]
          if(missingValues && (missingValues.InitIter >= iter)) {
            ss <- tcrossprod(X[,-k],W[,-k])
            tmp <- Y-ss
            tmp[na.inds] <- 0
            mu_sub <- crossprod(tmp,X[,k])
            mu <- mu_sub*tau
          } else {
            mu <- tau*as.vector( crossprod(Y,X[,k]) - W%*%XX[k,])
          }
          # Tensor loadings are shrunk towards U[m,k]*V[,k]
          for(mm in 1:MM) {
            mu[Dshared[[mm]]] <- mu[Dshared[[mm]]] + rep(U[shared[[mm]],k],rep(Dsh[mm],length(shared[[mm]])))*
              rep(V[[mm]][,k],length(shared[[mm]]))*alpha[Dshared[[mm]],k]
          }
          
          if(iter > opts$sampleZ){
            if(opts$spikeW!="group"){
              zone <- 0.5*( log(alpha[,k]) - log(lambda) + mu^2/lambda) + log(r[,k]) - log(1-r[,k])
              for(mm in 1:MM) {
                zone[Dshared[[mm]]] <- zone[Dshared[[mm]]] - 0.5*rep(U[shared[[mm]],k],rep(Dsh[mm],length(shared[[mm]])))^2*
                  rep(V[[mm]][,k]^2,length(shared[[mm]]))*alpha[Dshared[[mm]],k]
              }
              zone <- 1/(1+exp(-zone))
              zone <- as.double(runif(D) < zone)
              Z[,k] <- zone
            }else{
              zone <- 0.5*(log(alpha[,k])-log(lambda) + mu^2/lambda)
              for(m in 1:M){
                logpr <- sum(zone[gr[[m]]]) + log(r[gr[[m]][1],k]) - log( 1-r[gr[[m]][1],k] )
                for(mm in 1:MM) {
                  if(m%in%shared[[mm]])
                    logpr <- logpr  - 0.5*sum(U[m,k]^2*V[[mm]][,k]^2*alpha[gr[[m]],k])
                }
                logpr <- 1/(1+exp(-logpr))
                Z[gr[[m]],k] <- as.double(runif(1)<logpr)
              }
            }
          }
          W[,k] <- mu/lambda + rnorm(D)*lambda^(-0.5)
        }
      }
      W <- W*Z
    } # Rest for fixed projections (new sample prediction) too
    
    ## Check if components can be removed.
    # If Gibbs samples are going to be saved, components are dropped only during the burn-in.
    # If the projection has been given as an argument (e.g. for prediction), components will not be dropped.
    if ((iter<=opts$iter.burnin | opts$iter.saved==0) & !projection.fixed) {
      zm <- colSums(Z)
      keep <- which(zm!=0)
      if(length(keep)==0) {
        print("All components shut down, returning a NULL model.")
        return(list())
      }
      if(length(keep)!=K){
        K <- length(keep)
        id <- rep(1,K)
        alpha <- alpha[,keep,drop=F]
        X <- X[,keep,drop=F]
        W <- W[,keep,drop=F]
        for(mm in 1:MM) {
          V[[mm]] <- V[[mm]][,keep,drop=F]
          gamma[[mm]] <- gamma[[mm]][,keep,drop=F]
        }
        U <- U[,keep,drop=F]
        Z <- Z[,keep,drop=F]
        r <- r[,keep,drop=F]
        beta <- beta[,keep,drop=F]
        rz <- rz[,keep,drop=F]
      }
    }
    
    ## sample r
    
    if(iter>opts$sampleZ & !projection.fixed){
      if(opts$spikeW=="shared"){
        zm <- colSums(Z)
        for(k in 1:K)
          r[,k] <- (zm[k]+opts$prior.beta[1])/(D+sum(opts$prior.beta))
      }
      if(opts$spikeW=="grouped"){
        for(m in 1:M){
          zm <- colSums(Z[gr[[m]],,drop=FALSE])
          for(k in 1:K)
            r[gr[[m]],k] <- (zm[k]+opts$prior.beta[1])/(length(gr[[m]])+sum(opts$prior.beta)) #rbeta(1,shape1=zm[k]+1,shape2=length(gr[[m]])-zm[k]+1)
        }
      }
      if(opts$spikeW=="group"){
        zm <- rep(0,K)
        for(m in 1:M){
          zm <- zm + Z[gr[[m]][1],]
        }
        for(m in 1:M){
          for(k in 1:K){
            r[gr[[m]],k] <- (zm[k]+opts$prior.beta[1])/(M+sum(opts$prior.beta))
          }
        }
      }
    }
    
    ## 
    ## Update the latent variables
    ##
    if(iter>opts$sampleZ & !opts$normalLatents){
      if (!projection.fixed) {
        WtW <- crossprod(W*sqrt(tau))
        WtWdiag <- diag(WtW)
        diag(WtW) <- 0
      }
      for(k in 1:K){
        lambda <- WtWdiag[k] + beta[,k]
        if(missingValues && (missingValues.InitIter >= iter)) {
          tmpY <- Y
          tmpY[is.na(tmpY)] <- 0
          mu <- (tmpY%*%(W[,k]*tau) - X%*%WtW[k,])/lambda
        } else {
          mu <- (Y%*%(W[,k]*tau) - X%*%WtW[k,])/lambda
        }
        zone <- 0.5*( log(beta[,k]) - log(lambda) + lambda*mu^2) + log(rz[,k]) - log(1-rz[,k])
        zone <- 1/(1+exp(-zone))
        zone <- as.double(runif(N) < zone)
        X[,k] <- mu + rnorm(N)*lambda^(-0.5)
        X[,k] <- X[,k]*zone
        if (!projection.fixed) {
          zm <- sum(zone)
          rz[,k] <- (zm+opts$prior.betaX[1])/(N+sum(opts$prior.betaX))
        }
      }
    } else {
      if (!projection.fixed) {
        covZ <- diag(1,K) + crossprod(W*sqrt(tau))
        # Fail with a clear message instead of dropping into the debugger.
        eS <- tryCatch(eigen(covZ), error=function(e) {
          stop("Eigendecomposition of the latent covariance failed at iteration ",
               iter, ": ", conditionMessage(e), call.=FALSE)
        })
        tmp <- eS$vectors*outer(id,1/sqrt(eS$values))
        covZ <- tcrossprod(tmp)
      }
      if(missingValues && (missingValues.InitIter >= iter)) {
        tmpY <- Y
        tmpY[is.na(tmpY)] <- 0
        X <- tmpY%*%(W*tau)
      } else {
        X <- Y%*%(W*tau)
      }
      X <- X%*%covZ + matrix(rnorm(N*K),N,K)%*%tmp
    }
    XX <- crossprod(X)
    
    ##
    ## Update alpha, beta and gamma, the ARD parameters
    ##
    
    if(!opts$normalLatents){
      if(opts$ARDLatent=="element"){
        beta <- X^2/2 + opts$prior.beta_0X
        keep <- which(beta!=0)
        for(n in keep) beta[n] <- rgamma(1,shape=opts$prior.alpha_0X+0.5,rate=beta[n]) +1e-7
        beta[-keep] <- rgamma( length(beta[-keep]),shape=opts$prior.alpha_0X,rate=opts$prior.beta_0X) + 1e-7
      }
      if(opts$ARDLatent == "shared" & !projection.fixed){
        xx <- colSums(X^2)/2 + opts$prior.beta_0X
        tmpxx <- colSums(X!=0)/2 + opts$prior.alpha_0X
        for(k in 1:K)
          beta[,k] <- rgamma(1,shape=tmpxx[k],rate=xx[k]) + 1e-7
      }
    }
    if (!projection.fixed) {
      if(opts$ARDW=="shared"){
        ww <- colSums(W^2)/2 + beta_0
        tmpz <- colSums(Z)/2 + alpha_0
        for(k in 1:K)
          alpha[,k] <- rgamma(1,shape=tmpz[k],rate=ww[k]) + 1e-7
      } else if(opts$ARDW=="grouped"){
        for(m in 1:M){
          ww <- colSums(W[gr[[m]],,drop=FALSE]^2)/2 + beta_0
          tmpz <- colSums(Z[gr[[m]],,drop=FALSE])/2 + alpha_0
          for(k in 1:K)
            alpha[gr[[m]],k] <- rgamma(1,shape=tmpz[k],rate=ww[k]) + 1e-7
        }
      } else if(opts$ARDW=="element"){
        alpha <- W^2/2 + beta_0
        keep <- which(alpha!=0)
        for(n in keep) alpha[n] <- rgamma(1,shape=alpha_0+0.5,rate=alpha[n]) +1e-7
        alpha[-keep] <- rgamma( length(alpha[-keep]),shape=alpha_0,rate=beta_0) + 1e-7
      }
      
      # Shared views: the tensor rows of alpha are overwritten with the
      # precision of the deviation W - U*V
      beta_sh <- opts$prior.beta_sh
      if(opts$ARDshared=="shared") {
        for(mm in 1:MM) {
          beta_sh <- opts$prior.beta_sh
          for(m in shared[[mm]])
            beta_sh <- beta_sh + sum(Z[gr[[m]],]*(W[gr[[m]],]-sweep(V[[mm]],MARGIN=2,U[m,],'*'))^2)/2
          alpha[Dshared[[mm]],] <- rgamma(1, shape=opts$prior.alpha_sh+sum(Z[Dshared[[mm]],])/2, rate=beta_sh)
        }
      } else if(opts$ARDshared=="grouped") {
        for(mm in 1:MM) {
          for(m in shared[[mm]]) {
            alpha[gr[[m]],] <- rgamma(1, shape=opts$prior.alpha_sh+sum(Z[gr[[m]],])/2, rate=beta_sh +
                                        sum(Z[gr[[m]],]*(W[gr[[m]],]-sweep(V[[mm]],MARGIN=2,U[m,],'*'))^2)/2)
          }
        }
      }
      
      #Gamma
      for(mm in 1:MM) {
        if(opts$ARDV=="grouped") {
          gamma[[mm]] <- t(t(gamma[[mm]]*0) + rgamma(K, shape=alpha_0g+0.5*Dsh[mm], rate=beta_0g+0.5*colSums(V[[mm]]^2)))
        } else if(opts$ARDV=="element") {
          for(k in 1:K)
            gamma[[mm]][,k] <- rgamma(Dsh[mm], shape=alpha_0g+0.5, rate=beta_0g+0.5*V[[mm]][,k]^2)
        }
      }
      
      ##
      ## Update tau, the noise precisions
      ##
      
      XW <- tcrossprod(X,W)
      if(missingValues && (missingValues.InitIter > iter)) {
        #in the last iter of ignoring missing values, update Y[[m]] with the model updates
        XW <- Y - XW
        XW[is.na(XW)] <- 0
        b_tau <- colSums(XW^2)/2
      } else{
        if(missingValues) {
          Y[na.inds] <- XW[na.inds] #TODO: adding noise should be tested in here
        }
        b_tau <- colSums((Y-XW)^2)/2
      }
      if(opts$tauGrouped){
        for(m in 1:M){
          tau[gr[[m]]] <- rgamma(1,shape=alpha_0t+N*length(gr[[m]])/2, rate=beta_0t+sum(b_tau[gr[[m]]]))
        }
      }else{
        for(d in 1:D)
          tau[d] <- rgamma(1,shape=a_tau, rate= beta_0t+b_tau[d])
      }
      
      # calculate likelihood
      cost[iter] <- const + N*sum(log(tau))/2 - crossprod(b_tau,tau)
    }
    
    if(filename != "" & iter%%100==0) { #Every 100 iterations, save the sampling
      save(list=ls(),file=filename)
      if(opts$verbose>0)
        print(paste0("Iter ", iter, ", saved chain to '",filename,"'"))
    }
    
    # Calculate likelihood of observed views
    if(projection.fixed) {
      tautmp <- ( Yconst + rowSums((W%*%XX)*W) - 2*rowSums(crossprod(Y,X)*W) )/2
      tautmp <- a_tau/(beta_0t+tautmp)
      cost[iter] <- const
      for(m in which(opts$prediction==F))
        cost[iter] <- cost[iter] + N*sum(log(tautmp[gr[[m]]]))/2 - crossprod(b_tau[gr[[m]]],tautmp[gr[[m]]])
    }
    
    ##
    ## Prediction and collection of Gibbs samples
    ##
    
    if (any(opts$prediction)) { ## Prediction
      if (iter%%10==0 & opts$verbose>1) {
        print(paste0("Predicting: ",iter,"/",opts$iter.max))
      }
      for (m in which(opts$prediction)) { # Go through the views that will be predicted.
        tmpPred <- tcrossprod(x=X, y=W[gr[[m]],]) # Update the prediction.
        if (iter>opts$iter.burnin & ((iter-opts$iter.burnin)%%mod.saved)==0) {
          i.pred <- i.pred+1/sum(opts$prediction)
          prediction[[m]] = prediction[[m]] + tmpPred
        }
      }
      
    } else {
      if (iter%%10==0 & opts$verbose>0) {
        print(paste0("Learning: ",iter,"/",opts$iter.max," - K=",sum(colSums(Z)!=0)," - ",Sys.time()))
      }
      if (opts$iter.saved>0) { ## Collection of Gibbs samples
        
        if (iter>opts$iter.burnin & ((iter-opts$iter.burnin)%%mod.saved)==0) { ## Save the Gibbs sample.
          if (is.null(posterior)) { # Initialize the list containing saved Gibbs samples.
            posterior <- list() # List containing saved Gibbs samples
            if (opts$save.posterior$W & !any(opts$prediction)) {
              posterior$W <- array(dim=c(S, dim(W))) # SxDxK
              if(length(V)==1)
                posterior$V <- array(dim=c(S, dim(V[[1]]))) # SxDxK
              posterior$U <- array(dim=c(S, dim(U))) # SxMxK
              colnames(posterior$W) = colnames(Y)
            }
            posterior$rz <- array(dim=c(S, ncol(rz))) # SxK - Same 'rz' for all samples. Thus, save only a vector per iteration.
            posterior$r <- array(dim=c(S, M,K)) # SxMxK - Same 'r' for all variables of a view. Thus, save only a views x components matrix per iteration.
            if (opts$save.posterior$tau & !any(opts$prediction)) {
              posterior$tau <- array(dim=c(S, length(tau))) # SxD
              colnames(posterior$tau) = colnames(Y)
            }
            if (opts$save.posterior$X & !any(opts$prediction)) {
              posterior$X = array(dim=c(S, dim(X))) # SxNxK
            }
            posterior$beta = array(dim=c(S,dim(beta)))
            gr.start = vector(mode="integer", length=length(gr))
            for (m in 1:length(gr)) { # Find the first index of each view. (Used for saving 'r'.)
              gr.start[m] = gr[[m]][1]
            }
          }
          s <- (iter-opts$iter.burnin)/mod.saved
          if (!projection.fixed) {
            if (opts$save.posterior$W) {
              posterior$W[s,,] <- W
              if(length(V)==1)
                posterior$V[s,,] <- V[[1]]
              posterior$U[s,,] <- U
            }
            # tau samples are stored only if their array was allocated above
            # (opts$save.posterior$tau); assigning into NULL would error.
            if (!is.null(posterior$tau))
              posterior$tau[s,] <- tau
            posterior$rz[s,] <- rz[1,] # In the current model, 'rz' is identical for all samples (rows are identical). Thus, save only a vector.
            posterior$r[s,,] <- r[gr.start,] # In the current model, 'r' is identical for all variables within a view. Thus, save only a views x components matrix.
            posterior$beta[s,,] <- beta
          }
          if (opts$save.posterior$X) {
            posterior$X[s,,] <- X
          }
        }
      }
    }
    
  } ## The main loop of the algorithm ends.
  
  # No need for temporary storage any more. The file exists only if a
  # checkpoint was written (every 100 iterations) or loaded.
  if(filename!="" && opts$iter.max >=10 && file.exists(filename))
    file.remove(filename)
  ## Return the output of the model as a list
  if (any(opts$prediction)) {
    for (m in which(opts$prediction)) { # Go through the views that will be predicted.
      prediction[[m]] <- prediction[[m]]/i.pred
    }
    prediction$cost <- cost
    return(prediction)
  } else {
    d1 <- unlist(lapply(gr,function(x){x[1]})) # First feature index of each view
    if(MM==1 & opts$ARDV=="grouped")
      gamma <- gamma[[1]][1,]
    if(MM==1 & opts$ARDshared=="shared")
      alpha_sh <- alpha[Dshared[[1]][1],1]
    else if(MM==1 & opts$ARDshared=="grouped")
      alpha_sh <- alpha[unlist(lapply(gr,function(x){x[1]}))[shared[[1]]],1]
    else
      alpha_sh <- NA
    if(opts$ARDW=="grouped")
      alpha <- alpha[d1,]
    if(opts$spikeW=="group")
      Z <- Z[d1,]
    
    return(list(W=W, V=V, U=U, X=X, Z=Z, r=r, rz=rz, tau=tau, alpha=alpha, beta=beta, gamma=gamma,
                alpha_sh=alpha_sh, groups=gr, D=D, K=K,cost=cost, posterior=posterior, opts=opts))
  }
  
}


getDefaultOpts <- function() {
  #
  # A function for generating a default set of parameters for rMTF().
  #
  # To run the algorithm with the default values:
  #   opts <- getDefaultOpts()
  #   model <- rMTF(Y,K,opts,sharing)
  #
  # Options read by rMTF() but not set here:
  # - opts$prediction: logical vector (one entry per view), needed for
  #   prediction with predictrMTF()
  # - opts$iter.missing: iteration from which missing values are imputed
  #   (defaults to round(iter.burnin/2) when NULL)
  
  #
  # Initial value for the noise precisions. Should be large enough
  # so that the real structure is modeled with components
  # instead of the noise parameters (see Luttinen&Ilin, 2010)
  #  Values: Positive numbers, but generally should use values well
  #          above 1
  #
  init.tau <- 10^3
  
  #
  # Parameters for controlling when the algorithm stops.
  # It stops when the number of samples is iter.max.
  # iter.saved posterior samples are collected after iter.burnin
  # (0 = do not collect posterior samples).
  #
  iter.max <- 5000
  iter.saved <- 0
  iter.burnin <- floor(iter.max/2)
  
  #
  # Which posterior samples to store when iter.saved > 0.
  # rMTF() reads these flags; predictrMTF() needs the W (which also stores
  # U and, for a single tensor, V) and tau samples.
  #
  save.posterior <- list(W=TRUE, X=TRUE, tau=TRUE)
  
  #
  # Hyperparameters
  # - alpha_0, beta_0 for the ARD precisions
  # - alpha_0t, beta_0t for the residual noise precisions
  # - alpha_0g, beta_0g for the precisions of the tensor means V
  # - alpha_sh, beta_sh for the precision of the tensor slabs around U*V
  #
  prior.alpha_0 <- prior.beta_0 <- 1
  prior.alpha_0t <- prior.beta_0t <- 1
  prior.alpha_0g <- prior.beta_0g <- 1
  
  prior.alpha_sh <- prior.beta_sh <- 1
  
  # Different spike-and-slab sparsity priors for the loadings W.
  # - group sparsity ("group"): binary variable indicating activity for each factor and data set 
  # - structured ("grouped"): for each element a separate indicator, but the prior is structured
  #   such that each data set and factor share the same hyperparameter. This sharing promotes
  #   group sparsity in an indirect way.
  # - unstructured ("shared"): This prior does not use group structure at all.
  
  spikeW <- "group"
  
  # A parameter telling when to start sampling spike and slab parameters.
  sampleZ <- 1
  
  # Hyperparameter for indicator variables.
  # The first element is prior count for ones and the second for zeros.
  prior.beta <- c(1,1)
  
  # Different ARD priors for the loadings W (ARDW) and tensor means V (ARDV).
  # - "grouped": Shared precision for each factor and group.
  # - "element": Separate precision for each element in W.
  # - "shared": No group structure (ARDW only).
  ARDW <- "grouped"
  ARDV <- "grouped"
  
  # ARD for the shared projections with mean V
  # - "shared": same precision for all the (m,d,k) combinations
  # - "grouped": precision for each view inferred independently
  ARDshared <- "shared"
  
  # Gaussian prior for the latent variables?
  # - if TRUE then corresponds to standard Gaussian
  # - if FALSE then spike and slab prior is used
  normalLatents <- TRUE
  prior.betaX <- c(1,1) # the hyperparameters of the spike and slab
  # Set prior for the precisions of non-Gaussian latent variables.
  # options: ("shared", "element") (for help see prior for W)
  ARDLatent <- "shared"
  # The hyperparameter values for precisions corresponding to the latent variables.
  # Should use quite conservative values due to sensible scaling.
  prior.alpha_0X <- 1
  prior.beta_0X <- 1
  
  # Prior for the noise precision.
  # TRUE/FALSE:
  # - TRUE: a group shares the same noise residual
  # - FALSE: each data feature has separate noise precision.
  tauGrouped <- TRUE
  
  # Verbosity: 0 = silent, 1 = progress messages, >1 = also prediction details.
  verbose <- 1
  
  return(list(init.tau=init.tau,
              iter.max=iter.max, iter.saved=iter.saved, iter.burnin=iter.burnin,
              prior.alpha_0=prior.alpha_0,prior.beta_0=prior.beta_0,
              prior.alpha_sh=prior.alpha_sh, prior.beta_sh=prior.beta_sh,
              prior.alpha_0t=prior.alpha_0t,prior.beta_0t=prior.beta_0t,
              prior.alpha_0g=prior.alpha_0g,prior.beta_0g=prior.beta_0g,
              prior.beta=prior.beta,prior.betaX=prior.betaX,
              spikeW=spikeW,sampleZ=sampleZ,normalLatents=normalLatents,
              ARDLatent=ARDLatent,prior.alpha_0X=prior.alpha_0X,ARDshared=ARDshared,
              prior.beta_0X=prior.beta_0X,ARDW=ARDW,ARDV=ARDV,tauGrouped=tauGrouped,verbose=verbose,
              save.posterior=save.posterior))
}


predictrMTF <- function(Y, model, opts, sharing=NULL, filename="") {
  ## Predict a set of data views given the rMTF model (Gibbs samples) and
  ## another set of data views.
  ##
  ##
  ## Arguments
  ##
  ## Y: Data views in a list. The views flagged in opts$prediction hold the
  ##    true values; they are used only for computing the MSE.
  ## model: Object from the function 'rMTF', trained with opts$iter.saved > 0
  ##    so that model$posterior contains the W, U, rz, tau and beta samples.
  ## opts: Options as for rMTF(). opts$prediction is a logical vector with the
  ##    length matching the length of Y, describing which data views will be
  ##    predicted (TRUE) from the others (FALSE).
  ## sharing: defines which views belong to a tensor (see rMTF()). It must be
  ##    supplied in practice: rMTF() cannot handle a NULL value.
  ## filename: Optional checkpoint file. Partial results are saved every 10
  ##    posterior samples, and loaded to resume if the file already exists.
  ##
  ## Returns
  ##
  ## A list with one element per view: the posterior-mean predictions for the
  ## predicted views (NULL for the others), plus $mse (predicted views x
  ## posterior samples; MSE of the running-mean prediction) and $cost
  ## (iterations x posterior samples).
  
  
  ##
  ## Initialization
  ##
  
  pred.views <- which(opts$prediction) # Views to be predicted
  obs.view <- which(!opts$prediction)[1] # Any observed view, for N and row names
  N <- nrow(Y[[obs.view]]) # Number of samples in the prediction set
  n.samples <- nrow(model$posterior$W) # Number of saved Gibbs samples
  
  Y.true <- vector("list",length=length(Y)) #Store the data to be predicted
  
  Y.prediction <- vector(mode="list", length=length(model$groups))
  names(Y.prediction) <- names(Y) # names(De)
  for (mi in pred.views) { # Go through all views that will be predicted.
    Y.prediction[[mi]] <- array(data=0, dim=c(N, length(model$groups[[mi]]))) # Initialize the prediction matrix for view 'mi'.
    rownames(Y.prediction[[mi]]) <- rownames(Y[[obs.view]]) # Names of the prediction samples
    colnames(Y.prediction[[mi]]) <- names(model$groups[[mi]]) # Names of the variables in view 'mi', which will be predicted.
    Y.true[[mi]] <- Y[[mi]]
    Y[[mi]] <- array(dim=dim(Y.prediction[[mi]]), dimnames=dimnames(Y.prediction[[mi]])) # Initialize view 'mi' as missing data. These will be predicted.
  }
  
  ##
  ## Prediction
  ##
  mse <- matrix(NA,length(pred.views),n.samples)
  cost <- matrix(NA,opts$iter.max,n.samples)
  start <- 1
  if(filename!="" && file.exists(filename)) {
    load(filename) # Restores Y.prediction, mse and cost
    start <- which(is.na(mse[1,]))[1]
    if (is.na(start)) start <- n.samples + 1 # Every posterior sample already processed
    print(paste0("Loaded predictions from '",filename,"', continuing with posterior sample ",start))
  }
  
  sample.ids <- if (start <= n.samples) start:n.samples else integer(0)
  for (ni in sample.ids) { # Go through all remaining saved Gibbs samples.
    if(opts$verbose>0)
      print(paste0("Predicting, Gibbs sample ",ni))
    W <- matrix(model$posterior$W[ni,,],nrow(model$W),ncol(model$W))
    U <- matrix(model$posterior$U[ni,,],nrow(model$U),ncol(model$U))
    prediction.ni <- rMTF(Y=Y, K=NULL, opts=opts, sharing=sharing, W=W, V=NA,
                          U=U,rz=model$posterior$rz[ni,], tau=model$posterior$tau[ni,],
                          beta=model$posterior$beta[ni,,])
    cost[,ni] <- prediction.ni$cost
    
    # Y.prediction accumulates prediction/n.samples; multiplying by const
    # turns it into the mean over the samples 1..ni processed so far.
    const <- n.samples/ni
    for (mi in pred.views) { # Go through the target views.
      vi <- which(pred.views==mi)
      Y.prediction[[mi]] = Y.prediction[[mi]] + prediction.ni[[mi]]/n.samples
      mse[vi,ni] <- mean((Y.true[[mi]]-Y.prediction[[mi]]*const)^2)
      if(opts$verbose>1)
        print(paste0("MSE at iteration ",ni,": ",mse[vi,ni]))
    }
    if(filename!="" && ni%%10==0) {
      save(Y.prediction,mse,cost,file=filename)
      print(paste0("Saved tmp predictions to: '",filename,"'"))
    }
  }
  Y.prediction$mse <- mse
  Y.prediction$cost <- cost
  
  return(Y.prediction)
  
}
