# Mehmet Gonen (mehmet.gonen@gmail.com)

# Train a kernelized matrix factorization of Y with one kernel per side by
# iterating closed-form updates of the factor means/covariances.
#
# Components 1..R are split into
#   * "1-way" components: one side is held at user-supplied values
#     (parameters$fixed.cols for Gx, parameters$fixed.rows for Gz) and the other
#     side acts as an additive row/column effect, and
#   * "2-way" components: both sides are learned.
# NOTE(review): the updates subtract the 1-way effect of the other side directly
# from Y, which presumably assumes the fixed rows/cols are all ones -- confirm.
#
# Arguments:
#   Kx          Dx x Nx matrix; column i belongs to row i of Y.
#   Kz          Dz x Nz matrix; column j belongs to column j of Y.
#   Y           Nx x Nz response matrix, NA marks unobserved entries.
#   parameters  list of settings. Fields read here: seed, R, iteration,
#               init.mode (must be 1, 2 or 3), fixed.cols, fixed.rows, f1w.X,
#               f1w.Z, update.Lambda.z, pars.to.fix, sigma_hx, sigma_hz,
#               sigma.H.1way, sigma.H.fe, alpha_/beta_lambda{x,z}_{1w,2w},
#               beta_lambdax, beta_lambdaz, var.Ax.tmp, var.Az.tmp, tau_alpha,
#               tau_beta, Hx.2way, Hz.2way, Hx.1way, Hz.1way, Ax.1way.GBLUP and
#               (optionally) shared.sigma.y. Missing boolean switches are
#               treated as FALSE.
#
# Returns (invisibly) a list with Lambdax, Ax, Lambdaz, Az, ez, Hx (= Gx),
# Hz (= Gz), parameters (augmented with sigma_gx, sigma_gz, R.eff.inds.X,
# R.eff.inds.Z), traces and Hx.mu.init.
kbmf1k1k_semisupervised_regression_variational_train <- function(Kx, Kz, Y, parameters) {
  set.seed(parameters$seed)

  # TRUE only for a present, truthy switch; a missing (NULL) switch counts as off
  is.on <- function(flag) length(flag) > 0 && isTRUE(as.logical(flag[[1]]))

  f1w.X <- is.on(parameters$f1w.X)
  f1w.Z <- is.on(parameters$f1w.Z)

  # Earlier versions read optional settings from a free variable `varargin`
  # that is not an argument of this function. Keep honouring it when it is
  # visible, but do not fail when it is absent.
  extra.args <- if (exists('varargin')) get('varargin') else list()
  shared.flag <- extra.args$shared.sigma.y
  if (is.null(shared.flag)) shared.flag <- parameters$shared.sigma.y
  shared.sigma.y <- is.on(shared.flag)

  if (any(colSums(is.na(Y))==0) || any(rowSums(is.na(Y))==0)) print('missing value handling assumes missingness everywhere, problems can arise with indices!')

  Dx <- dim(Kx)[1]
  Nx <- dim(Kx)[2]
  Dz <- dim(Kz)[1]
  Nz <- dim(Kz)[2]
  R <- parameters$R

  ##############################################
  # indexes of 1-way effects and 2-way effects #
  ##############################################

  # number of components whose Gx rows / Gz rows are supplied by the caller
  Rx.fixed <- if (is.null(parameters$fixed.cols)) 0 else ncol(parameters$fixed.cols)
  Rz.fixed <- if (is.null(parameters$fixed.rows)) 0 else nrow(parameters$fixed.rows)

  print('TODO: if only 1 direction has fixed rows/cols, make sure it is the first one. Write a test for that.')

  # X side: the first Rx.fixed components are fixed, the rest are updated
  Rx.eff <- R - Rx.fixed
  if (Rx.eff > 0) {
    R.eff.inds.X <- (Rx.fixed+1):R
    R.fixed.inds.X <- setdiff(seq_len(R), R.eff.inds.X)
  } else {
    R.eff.inds.X <- NULL
    R.fixed.inds.X <- seq_len(R)
  }

  # Z side: fixed components follow directly after the X-fixed ones
  Rz.eff <- R - Rz.fixed
  if (Rx.fixed > 0) R.eff.inds.Z <- c(1:Rx.fixed) else R.eff.inds.Z <- c()

  R.fixed.total <- Rx.fixed+Rz.fixed
  if (R < R.fixed.total) stop('not enough components for all 1-way effects')
  if (R > R.fixed.total) {
    R.eff.inds.Z <- c(R.eff.inds.Z, (R.fixed.total+1):R)
  }
  R.fixed.inds.Z <- setdiff(seq_len(R), R.eff.inds.Z)

  # 2-way components are updated on both sides; the *.X / *.Z variants are their
  # column positions inside Ax / Az (the trailing columns)
  R.inds.2way <- intersect(R.eff.inds.X, R.eff.inds.Z)
  R.inds.2way.X <- tail(seq_len(Rx.eff), length(R.inds.2way))
  R.inds.2way.Z <- tail(seq_len(Rz.eff), length(R.inds.2way))
  R.2way <- length(R.inds.2way)

  # optionally freeze the 1-way X components (only 2-way columns of Ax are updated)
  if (f1w.X) {
    # isTRUE() keeps the guard from failing when fixed.rows is NULL
    if (isTRUE(nrow(parameters$fixed.rows) > 1)) stop('do you want to fix all 1way (X) components?')

    print('fixing Hx and Ax for 1way_g components')
    R.eff.inds.X <- setdiff(R.eff.inds.X, R.fixed.inds.Z)
    Ax.inds.to.update <- R.inds.2way.X
  } else {
    Ax.inds.to.update <- seq_len(Rx.eff)
  }

  # optionally freeze the 1-way Z components
  if (f1w.Z) {
    # NOTE(review): guard and messages mirror the X branch (they test
    # fixed.rows and mention X/Hx/Ax); presumably fixed.cols and Z were meant.
    # Left as is apart from being NULL-safe -- confirm intended check.
    if (isTRUE(nrow(parameters$fixed.rows) > 1)) stop('do you want to fix all 1way (X) components?')

    print('fixing Hx and Ax for 1way_g components')
    R.eff.inds.Z <- setdiff(R.eff.inds.Z, R.fixed.inds.X)
    # seq_along() so that nothing is excluded when there are no fixed columns
    Az.eff.inds <- setdiff(seq_len(Rz.eff), seq_along(R.fixed.inds.X))
  } else {
    Az.eff.inds <- seq_len(Rz.eff)
  }

  if (is.null(parameters$sigma.H.fe)) {
    Az.inds.to.update <- Az.eff.inds
  } else {
    # in the case of fixed effects: do not update Az for the fixed effects
    # -components (keep mean fixed to 0)
    Az.inds.to.update <- setdiff(Az.eff.inds, 1)
  }

  parameters$R.eff.inds.X <- R.eff.inds.X
  parameters$R.eff.inds.Z <- R.eff.inds.Z

  print('R.eff.inds.X')
  print(R.eff.inds.X)
  print('R.eff.inds.Z')
  print(R.eff.inds.Z)

  # scale of the random initialisation; 0 makes all rnorm() draws below exactly 0
  init.coef <- 0

  ####################################################################
  # now sigma_gx and sigma_gz are different for different components #
  ####################################################################
  # Parameters are given assuming the multiple kernel implementation,
  # hence the _h... values are copied to _g...
  parameters$sigma_gx <- parameters$sigma_hx
  parameters$sigma_gz <- parameters$sigma_hz

  sigma_gxs <- rep(parameters$sigma_gx, R)
  if (length(R.fixed.inds.X)>0) {
    sigma_gxs[R.fixed.inds.X] <- parameters$sigma.H.1way
  }

  sigma_gzs <- rep(parameters$sigma_gz, R)
  if (length(R.fixed.inds.Z)>0) {
    if (!is.null(parameters$sigma.H.fe)) {
      sigma_gzs[R.fixed.inds.Z] <- parameters$sigma.H.fe
    } else {
      sigma_gzs[R.fixed.inds.Z] <- parameters$sigma.H.1way
    }
  }

  # precision hyper-parameters of the regression weights; 1-way columns come
  # first, 2-way columns last
  Lambdax <- list(alpha = matrix(NA, Dx, Rx.eff), beta = matrix(NA, Dx, Rx.eff))
  Lambdaz <- list(alpha = matrix(NA, Dz, Rz.eff), beta = matrix(NA, Dz, Rz.eff))

  if (length(R.fixed.inds.Z)>0) {
    Lambdax$alpha[, seq_along(R.fixed.inds.Z)] <- parameters$alpha_lambdax_1w + 0.5
    Lambdax$beta[, seq_along(R.fixed.inds.Z)] <- parameters$beta_lambdax_1w
  }

  if (length(R.fixed.inds.X)>0) {
    Lambdaz$alpha[, seq_along(R.fixed.inds.X)] <- parameters$alpha_lambdaz_1w + 0.5
    Lambdaz$beta[, seq_along(R.fixed.inds.X)] <- parameters$beta_lambdaz_1w
  }

  if (length(R.inds.2way) > 0) {
    Lambdax$alpha[, R.inds.2way.X] <- parameters$alpha_lambdax_2w + 0.5
    Lambdax$beta[, R.inds.2way.X] <- parameters$beta_lambdax_2w

    Lambdaz$alpha[, R.inds.2way.Z] <- parameters$alpha_lambdaz_2w + 0.5
    Lambdaz$beta[, R.inds.2way.Z] <- parameters$beta_lambdaz_2w
  }

  # X direction: regression weights Ax (Dx x Rx.eff) and factors Gx (R x Nx)
  Ax <- list(mu = matrix(rnorm(Dx * Rx.eff, sd=init.coef*sqrt(parameters$var.Ax.tmp)), Dx, Rx.eff),
             sigma = array(diag(1, Dx, Dx), c(Dx, Dx, Rx.eff)))

  diag.vals.tmp <- c(rep(parameters$sigma.H.1way, Rx.fixed), rep(1, Rx.eff))
  Gx <- list(mu = matrix(rnorm(R * Nx, sd= init.coef*parameters$sigma_gx), R, Nx),
             sigma = array(diag(diag.vals.tmp, R, R), c(R, R, Nx)))

  if (length(R.fixed.inds.X)>0) Gx$mu[R.fixed.inds.X,] <- parameters$fixed.cols

  # Z direction: regression weights Az (Dz x Rz.eff) and factors Gz (R x Nz)
  Az <- list(mu = matrix(rnorm(Dz * Rz.eff, sd=init.coef*sqrt(parameters$var.Az.tmp)), Dz, Rz.eff),
             sigma = array(diag(1, Dz, Dz), c(Dz, Dz, Rz.eff)))

  # in case of fixed effects, the regression weights are initialized
  # and fixed to 0
  if (!is.null(parameters$sigma.H.fe)) Az$mu[1,] <- 0

  diag.vals.tmp <- rep(1, R)
  diag.vals.tmp[R.fixed.inds.Z] <- parameters$sigma.H.1way

  Gz <- list(mu = matrix(rnorm(R * Nz, sd= init.coef*parameters$sigma_gz), R, Nz),
             sigma = array(diag(diag.vals.tmp, R, R), c(R, R, Nz)))
  if (!is.null(parameters$fixed.rows)) Gz$mu[R.fixed.inds.Z,] <- parameters$fixed.rows

  print('learning column-wise variance parameters')
  # one noise standard deviation per column of Y, with its tau parameters
  sigma_y <- list(mu = rep(1, ncol(Y)))
  tau <- list(alpha = rep(1, ncol(Y)), beta = rep(1, ncol(Y)))

  KxKx <- tcrossprod(Kx, Kx)
  KzKz <- tcrossprod(Kz, Kz)

  # containers for the recorded history of each quantity
  traces <- list(
    Gx = list(mu = list(), sigma = list()),
    Gz = list(mu = list(), sigma = list()),
    Ax = list(mu = list()),
    Az = list(mu = list()),
    Lambdax = list(beta = list()),
    Lambdaz = list(beta = list()),
    sigma_y = list(mu = list()),
    ll = list(),
    ss = list()
  )

  # names of the quantities that are held at their initial values
  dont.update <- c('Lambdax', 'Lambdaz', 'etax', 'etaz') #, 'Ax', 'Az', 'sigma_y'

  # convergence monitoring of Ax is switched off
  conv.monitor <- FALSE

  if (length(parameters$pars.to.fix)>0) dont.update <- c(dont.update, parameters$pars.to.fix)

  # for sparsity, update LambdaZ
  if (is.on(parameters$update.Lambda.z)) dont.update <- setdiff(dont.update, 'Lambdaz')

  # callers may use the old names Hx/Hz for Gx/Gz
  dont.update[dont.update == 'Hx'] <- 'Gx'
  dont.update[dont.update == 'Hz'] <- 'Gz'

  if (length(Az.inds.to.update)==0) dont.update <- c(dont.update, 'Az')
  if (length(Ax.inds.to.update)==0) dont.update <- c(dont.update, 'Ax')

  # initialise the learned factors from caller-supplied values
  if (any(parameters$init.mode==c(1,2,3))) {

    if (length(R.inds.2way)>0) {
      Gx$mu[R.inds.2way, ] <- parameters$Hx.2way
      Gz$mu[R.inds.2way, ] <- parameters$Hz.2way
    }

    if (!is.null(parameters$fixed.rows)) {
      Gx$mu[R.fixed.inds.Z, ] <- parameters$Hx.1way
      print('init')

      # columns of Ax that belong to the 1-way X components
      Ax.1way.ind <- setdiff(seq_len(Rx.eff), R.inds.2way.X)
      if (!is.null(parameters$Ax.1way.GBLUP)) Ax$mu[, Ax.1way.ind] <- parameters$Ax.1way.GBLUP
    }

    if (!is.null(parameters$fixed.cols)) {
      # columns of Az that belong to the 1-way Z components
      Az.1way.ind <- setdiff(seq_len(Rz.eff), R.inds.2way.Z)
      Gz$mu[R.fixed.inds.X, ] <- parameters$Hz.1way
    }

    Hx.mu.init <- Gx$mu

  } else {
    stop('specify initialisation mode!')
  }

  # thinning of the recorded traces
  if (parameters$iteration>1500) thin <- 100 else if (parameters$iteration>199) thin <- 10 else thin <- 1

  for (iter in seq_len(parameters$iteration)) {

    if (iter%%10== 0 || iter == 1) print(iter)

    # X-side: only updated when there are non-fixed components
    if (Rx.eff > 0) {

      if (any(dont.update=='Gx') && conv.monitor) {
        # monitoring: stop updating Ax after convergence; relevant when the
        # latent components themselves are not updated
        if (all(Ax$mu==0)) {
          Ax.mu.old <- Ax$mu
          Ax.mu.old[,] <- rnorm(prod(dim(Ax$mu)))
        } else Ax.mu.old <- Ax$mu
      }

      # update Lambdax
      if (!any(dont.update == 'Lambdax' )) {
        for (s in seq_len(Rx.eff)) {
          Lambdax$beta[,s] <- 1 / (1 / parameters$beta_lambdax + 0.5 * (Ax$mu[,s]^2 + diag(Ax$sigma[,,s])))
        }
      }

      # update Ax: regression of each Gx row on the kernel
      if (!any(dont.update=='Ax')) {
        for (s in Ax.inds.to.update) {
          r.tmp <- R.eff.inds.X[s]
          Ax$sigma[,,s] <- chol2inv(chol(diag(as.vector(Lambdax$alpha[,s] * Lambdax$beta[,s]), Dx, Dx) + KxKx / sigma_gxs[r.tmp]^2))
          Ax$mu[,s] <- Ax$sigma[,,s] %*% (Kx %*% matrix(Gx$mu[r.tmp,], Nx, 1) / sigma_gxs[r.tmp]^2)
        }
      }

      if (any(dont.update=='Gx') && conv.monitor) {
        norm.ok <- (mean(abs(Ax.mu.old)/abs(Ax$mu)) -1 ) < 1e-6
        cor.ok <- (1-(cor(c(Ax$mu), c(Ax.mu.old)))) < 1e-6

        print(cor.ok)
        print(norm.ok)

        if (cor.ok && norm.ok) {
          dont.update <- c(dont.update, 'Ax')
          print(paste0('Ax converged, iteration ', iter))
        }
      }

      # update Gx, one row of Y at a time
      if (!any(dont.update=='Gx')) {

        inv.sigma2.Gx.tmp <- diag(1/ rep(parameters$sigma_gx^2, R.2way), R.2way, R.2way)

        for (i in seq_len(Nx)) {

          indices <- which(!is.na(Y[i,]))
          # noise precision of every observed entry in row i
          tmp.precs <- 1/sigma_y$mu[indices]^2

          ###################################################
          # 1-way-X components correspond to R.fixed.inds.Z #
          ###################################################
          if (length(R.fixed.inds.Z)>0 && !f1w.X) {

            prec.X.1way <- sum(tmp.precs)+ 1/parameters$sigma_gx^2

            if (length(R.inds.2way)>0) {
              # when learning 1-way effects, take into account the current
              # distributions of the 2-way effects
              mu.tmp <- Y[i, indices, drop = FALSE] - crossprod(Gx$mu[R.inds.2way, i, drop=FALSE], Gz$mu[R.inds.2way, indices, drop=FALSE])
            } else mu.tmp <- Y[i, indices, drop = FALSE]

            # remove the 1-way Z effects
            if (length(R.fixed.inds.X)>0) {
              mu.tmp <- mu.tmp - Gz$mu[R.fixed.inds.X, indices]
            }
            # terms from the likelihood + prior
            mu.tmp <- sum(mu.tmp*tmp.precs) + crossprod(Ax$mu[, Ax.1way.ind, drop=FALSE], Kx[,i]) / parameters$sigma_gx^2

            # covariance
            Gx$sigma[R.fixed.inds.Z, R.fixed.inds.Z, i] <- 1/prec.X.1way

            # Cov * (prior + likelihood)
            Gx$mu[R.fixed.inds.Z,i] <- Gx$sigma[R.fixed.inds.Z, R.fixed.inds.Z, i] * mu.tmp
          }

          ########################
          # the 2-way components #
          ########################
          if (R.2way>0) {

            # NOTE(review): the covariance part is weighted by tmp.precs but the
            # outer product of the Gz means is not; with unequal column
            # variances both presumably need the weight -- confirm against the
            # derivation. Behaviour left unchanged.
            # each R.2way x R.2way slice of Gz$sigma is scaled by its precision
            Gx$sigma[R.inds.2way, R.inds.2way, i] <- chol2inv(chol(
              inv.sigma2.Gx.tmp
              + (tcrossprod(Gz$mu[R.inds.2way, indices, drop = FALSE])
                 + apply(Gz$sigma[R.inds.2way, R.inds.2way, indices, drop = FALSE]
                         * rep(tmp.precs, each=R.2way^2), 1:2, sum))
              ))

            # residual of row i after removing the 1-way effects
            Y.tmp <- Y[i, indices, drop = FALSE]
            if (length(R.fixed.inds.X)>0) {
              Y.tmp <- Y.tmp - Gz$mu[R.fixed.inds.X, indices]
            }
            if (length(R.fixed.inds.Z)>0) {
              Y.tmp <- Y.tmp - Gx$mu[R.fixed.inds.Z,i]
            }

            # Cov * (prior + likelihood)
            Gx$mu[R.inds.2way,i] <- (Gx$sigma[R.inds.2way, R.inds.2way, i] %*% (
              crossprod(Ax$mu[, R.inds.2way.X, drop=FALSE], Kx[,i]) / parameters$sigma_gx^2
              + tcrossprod(Gz$mu[R.inds.2way, indices, drop = FALSE], Y.tmp*tmp.precs)
              ))
          }
        }

        if (any(is.na(Gx$mu))) stop('NAs in Hx')
      }
    }

    # Z-side: only updated when there are non-fixed components
    if (Rz.eff > 0) {

      # update Lambdaz
      if (!any(dont.update == 'Lambdaz' )) {
        for (s in seq_len(Rz.eff)) {
          Lambdaz$beta[,s] <- 1 / (1 / parameters$beta_lambdaz + 0.5 * (Az$mu[,s]^2 + diag(Az$sigma[,,s])))
        }
      }

      # update Az: regression of each Gz row on the kernel
      if (!any(dont.update == 'Az' )) {
        for (s in Az.inds.to.update) {
          r.tmp <- R.eff.inds.Z[s]
          Az$sigma[,,s] <- chol2inv(chol(diag(as.vector(Lambdaz$alpha[,s] * Lambdaz$beta[,s]), Dz, Dz) + KzKz / sigma_gzs[r.tmp]^2))
          Az$mu[,s] <- Az$sigma[,,s] %*% (tcrossprod(Kz, Gz$mu[r.tmp,,drop = FALSE]) / sigma_gzs[r.tmp]^2)
        }

        # used to drop into browser(); report instead of halting
        if (any(Az$mu>1e3)) warning('Az$mu contains values above 1e3 at iteration ', iter, call. = FALSE)
      }

      # update Gz, one column of Y at a time
      if (!any(dont.update == 'Gz' )) {

        # prior scale of the 1-way Z effect
        if (is.null(parameters$sigma.H.fe)) sigma.to.use <- parameters$sigma_gz else sigma.to.use <- parameters$sigma.H.fe
        inv.sigma2.Gz.tmp <- diag(1/ rep(parameters$sigma_gz^2, R.2way), R.2way, R.2way)

        for (j in seq_len(Nz)) {

          indices <- which(!is.na(Y[,j]))
          tmp.precs <- rep(1/sigma_y$mu[j]^2, length(indices))

          ###################################################
          # 1-way-Z components correspond to R.fixed.inds.X #
          ###################################################
          if (length(R.fixed.inds.X)>0 && !f1w.Z) {

            prec.Z.1way <- sum(tmp.precs)+ 1/sigma.to.use^2
            Gz$sigma[R.fixed.inds.X, R.fixed.inds.X, j] <- 1/prec.Z.1way

            mu.tmp <- Y[indices, j, drop = FALSE]

            # remove the current 2-way fit
            if (length(R.inds.2way)>0) {
              mu.tmp <- mu.tmp - crossprod(Gx$mu[R.inds.2way, indices, drop=FALSE], Gz$mu[R.inds.2way, j, drop=FALSE])
            }

            # remove the 1-way X effects
            if (length(R.fixed.inds.Z)>0) {
              mu.tmp <- mu.tmp - Gx$mu[R.fixed.inds.Z, indices]
            }

            # "likelihood + prior"
            mu.tmp <- sum(mu.tmp*tmp.precs) + crossprod(Az$mu[, Az.1way.ind, drop=FALSE], Kz[,j])/(sigma.to.use)^2

            Gz$mu[R.fixed.inds.X, j] <- Gz$sigma[R.fixed.inds.X, R.fixed.inds.X, j] * mu.tmp
          }

          ########################
          # the 2-way components #
          ########################
          if (length(R.inds.2way)>0) {

            Gz$sigma[R.inds.2way, R.inds.2way,j] <- chol2inv(chol(
              inv.sigma2.Gz.tmp
              + (tcrossprod(Gx$mu[R.inds.2way, indices, drop = FALSE])
                 + apply(Gx$sigma[R.inds.2way, R.inds.2way, indices, drop = FALSE], 1:2, sum)) / sigma_y$mu[j]^2
              ))

            # residual of column j after removing the 1-way effects
            Y.tmp <- Y[indices, j, drop = FALSE]
            if (length(R.fixed.inds.Z)>0) {
              Y.tmp <- Y.tmp - Gx$mu[R.fixed.inds.Z, indices]
            }
            if (length(R.fixed.inds.X)>0) {
              Y.tmp <- Y.tmp - Gz$mu[R.fixed.inds.X, j]
            }

            # Cov * ("prior + likelihood")
            Gz$mu[R.inds.2way, j] <- (Gz$sigma[R.inds.2way, R.inds.2way, j] %*% (
              crossprod(Az$mu[, R.inds.2way.Z, drop=FALSE], Kz[,j]) / parameters$sigma_gz^2
              + Gx$mu[R.inds.2way, indices, drop = FALSE] %*% Y.tmp / sigma_y$mu[j]^2
              ))
          }
        }

        # used to drop into browser(); report instead of halting
        if (any(Gz$mu>1e5)) warning('Gz$mu contains values above 1e5 at iteration ', iter, call. = FALSE)
      }
    }

    # update tau and sigma_y (skipped in the first iteration)
    if (!any(dont.update == 'sigma_y' ) && iter > 1) {

      if (shared.sigma.y) stop('not implemented yet')

      # column-specific variance parameters
      tau$alpha <- parameters$tau_alpha + colSums(!is.na(Y))/2

      # expected squared residual of every observed entry (NA where Y is NA)
      to.sum <- Y
      to.sum[!is.na(to.sum)] <- 0

      has.2way <- length(R.inds.2way) > 0
      has.1way.X <- length(R.fixed.inds.Z) > 0   # additive row effect, stored in Gx
      has.1way.Z <- length(R.fixed.inds.X) > 0   # additive column effect, stored in Gz
      if (!has.2way && !has.1way.X && !has.1way.Z) stop('unknown combination')

      for (j in seq_len(ncol(Y))) {

        # mean and variance of the column effect (0 when there is none)
        if (has.1way.Z) {
          b <- Gz$mu[R.fixed.inds.X, j]
          vb <- Gz$sigma[R.fixed.inds.X, R.fixed.inds.X, j]
        } else {
          b <- 0
          vb <- 0
        }
        # second moment of the 2-way part of column j
        if (has.2way) {
          Mz <- tcrossprod(Gz$mu[R.inds.2way, j]) + Gz$sigma[R.inds.2way, R.inds.2way, j]
        }

        for (i in seq_len(nrow(Y))) {
          if (is.na(Y[i, j])) next

          # mean and variance of the row effect (0 when there is none)
          if (has.1way.X) {
            a <- Gx$mu[R.fixed.inds.Z, i]
            va <- Gx$sigma[R.fixed.inds.Z, R.fixed.inds.Z, i]
          } else {
            a <- 0
            va <- 0
          }
          # mean (m) and second moment (q) of the 2-way product term
          if (has.2way) {
            m <- sum(Gx$mu[R.inds.2way, i] * Gz$mu[R.inds.2way, j])
            q <- sum(diag((tcrossprod(Gx$mu[R.inds.2way, i]) + Gx$sigma[R.inds.2way, R.inds.2way, i]) %*% Mz))
          } else {
            m <- 0
            q <- 0
          }

          # same term order as the former per-combination formulas; absent
          # parts contribute exact zeros
          to.sum[i, j] <- (
            Y[i, j]^2
            - 2 * Y[i, j] * m
            + q
            + a^2
            + va
            + b^2
            + vb
            + 2 * a * b
            + 2 * (a + b) * m
            - 2 * Y[i, j] * (a + b))
        }
      }

      # used to drop into browser(); report instead of halting
      if (any(to.sum < 0, na.rm = TRUE)) warning('negative expected squared residual at iteration ', iter, call. = FALSE)

      tau$beta <- parameters$tau_beta + 0.5*colSums(to.sum, na.rm=TRUE)
      sigma_y$mu <- sqrt(1/(tau$alpha/tau$beta))

      if (any(is.na(sigma_y$mu))) stop('NAs in sigma_y')
    }

    # add to trace: the first ten iterations and then every thin-th one
    if (iter %in% 1:10 || iter%%thin==0) {

      current <- list(Gx = Gx, Gz = Gz, Ax = Ax, Az = Az,
                      Lambdax = Lambdax, Lambdaz = Lambdaz, sigma_y = sigma_y)
      residual <- Y - t(Gx$mu) %*% Gz$mu

      for (par.name in setdiff(names(traces), c('sigma_hx', 'sigma_hz'))) {

        if (par.name == 'll') {

          # Gaussian log-likelihood of the observed entries at the current means
          tmp.ll <- dnorm(residual, sd=matrix(sigma_y$mu, nrow=nrow(Y), ncol=ncol(Y), byrow=TRUE), log=TRUE)
          tmp.ll[is.na(tmp.ll)] <- 0
          traces[[par.name]][[length(traces[[par.name]]) + 1]] <- sum(tmp.ll)

        } else if (par.name == 'ss') {

          # mean squared residual per column
          tmp.ss <- colMeans(residual^2, na.rm=TRUE)
          tmp.ss[is.na(tmp.ss)] <- 0
          traces[[par.name]][[length(traces[[par.name]]) + 1]] <- tmp.ss

        } else {

          for (field in names(traces[[par.name]])) {
            value <- current[[par.name]][[field]]
            # for the weight covariances only the diagonals are kept
            if (par.name %in% c('Ax', 'Az') && field == 'sigma') {
              value <- apply(value, 3, diag)
            }
            traces[[par.name]][[field]][[length(traces[[par.name]][[field]]) + 1]] <- value
          }
        }
      }
    }

    if (iter %%100==0) print(gc())
  }

  # for prediction, the R.eff.inds.X is assumed to contain all non-fixed components (fixed.rows and fixed.cols)
  if (f1w.X) {
    parameters$R.eff.inds.X <- sort(c(R.eff.inds.X, R.fixed.inds.Z))
  }

  if (f1w.Z) {
    parameters$R.eff.inds.Z <- sort(c(R.eff.inds.Z, R.fixed.inds.X))
  }

  # For plotting, the parameters Gx/Gz correspond to old Hx/Hz
  # rename for convenience
  names(traces)[which(names(traces)=='Gz')] <- 'Hz'
  names(traces)[which(names(traces)=='Gx')] <- 'Hx'

  # remove traces for parameters not updated (Az traces are always kept)
  # NOTE(review): dont.update holds 'Gx'/'Gz' while the traces were just renamed
  # to 'Hx'/'Hz', so those two are never removed here -- confirm intent.
  if (length(dont.update)>0) traces[setdiff(dont.update, 'Az')] <- NULL

  state <- list(Lambdax = Lambdax, Ax = Ax, Lambdaz = Lambdaz, Az = Az, ez = extra.args$inits$ez, Hx = Gx, Hz = Gz, parameters = parameters, traces=traces, Hx.mu.init = Hx.mu.init)
  invisible(state)
}
