# Mehmet Gonen (mehmet.gonen@gmail.com)
# Variational training of a two-sided kernelized matrix factorization in which
# Y is approximated by t(Hx$mu) %*% Hz$mu, with Hx regressed on Kx through Ax
# and Hz regressed on Kz through Az.
#
# Arguments:
#   Kx          Dx x Nx matrix for the row side of Y (Nx has to match nrow(Y)).
#   Kz          Dz x Nz matrix for the column side of Y (Nz has to match
#               ncol(Y)).
#   Y           Nx x Nz response matrix; NA marks an unobserved entry.
#   parameters  list of settings. Fields read by this function:
#                 seed, R, iteration, init.mode (must be 1 or 2),
#                 fixed.cols, fixed.rows (optional fixed factor values),
#                 Hx.1way, Hz.1way, Hx.2way, Hz.2way (initial factor values),
#                 sigma_hx, sigma_hz, sigma.H.1way, var.Ax.tmp, var.Az.tmp,
#                 alpha_lambda{x,z}_{1w,2w}, beta_lambda{x,z}_{1w,2w},
#                 beta_lambdax, beta_lambdaz, alpha_eta, beta_eta,
#                 tau_alpha, tau_beta, fix.ez, inits$ez$mu.
#
# Returns (invisibly, because the last expression is an assignment) a list with
# Lambdax, Ax, Lambdaz, Az, etaz, ez, Hx, Hz, the (augmented) parameters, the
# collected traces and Hx.mu.init (Hx$mu as it was right after initialisation).
#
# Side effects: calls set.seed(), attaches the 'softImpute' package, prints
# progress and gc() output.
kbmf1k1k_semisupervised_regression_variational_train <- function(Kx, Kz, Y, parameters) {

  set.seed(parameters$seed)

  if (any(colSums(is.na(Y))==0) || any(rowSums(is.na(Y))==0)) print('missing value handling assumes missingness everywhere, problems can arise with indices!')

  if (is.null(parameters$fix.ez) && !is.null(parameters$inits$ez$mu)) parameters$fix.ez <- FALSE

  Dx <- dim(Kx)[1]
  Nx <- dim(Kx)[2]
  Px <- 1
  Dz <- dim(Kz)[1]
  Nz <- dim(Kz)[2]
  Pz <- 1
  R <- parameters$R

  ##############################################
  # indexes of 1-way effects and 2-way effects #
  ##############################################

  # number of components whose Hx rows / Hz rows are held fixed
  Rx.fixed <- if (is.null(parameters$fixed.cols)) 0 else ncol(parameters$fixed.cols)
  Rz.fixed <- if (is.null(parameters$fixed.rows)) 0 else nrow(parameters$fixed.rows)

  print('TODO: if only 1 direction has fixed rows/cols, make sure it is the first one. Write a test for that.')

  # components of Hx that are updated: everything after the X-fixed ones
  Rx.eff <- R - Rx.fixed
  if (Rx.eff > 0) {
    R.eff.inds.X <- (Rx.fixed+1):R
    R.fixed.inds.X <- setdiff(1:R, R.eff.inds.X)
  } else {
    R.eff.inds.X <- NULL
    R.fixed.inds.X <- 1:R
  }

  # components of Hz that are updated: the X-fixed ones plus everything after
  # the Z-fixed ones (the Z-fixed components sit right after the X-fixed ones)
  Rz.eff <- R - Rz.fixed
  if (Rx.fixed > 0) R.eff.inds.Z <- c(1:Rx.fixed) else R.eff.inds.Z <- c()

  R.fixed.total <- Rx.fixed+Rz.fixed

  if (R < R.fixed.total) stop('not enough components for all 1-way effects')

  if (R > R.fixed.total) {
    R.eff.inds.Z <- c(R.eff.inds.Z, (R.fixed.total+1):R)
  }
  R.fixed.inds.Z <- setdiff(1:R, R.eff.inds.Z)

  # 2way: components updated on both sides, and their column positions inside
  # the Rx.eff / Rz.eff wide coefficient matrices (always the trailing ones)
  R.inds.2way <- intersect(R.eff.inds.X, R.eff.inds.Z)
  R.inds.2way.X <- tail(1:Rx.eff, length(R.inds.2way))
  R.inds.2way.Z <- tail(1:Rz.eff, length(R.inds.2way))

  parameters$R.eff.inds.X <- R.eff.inds.X
  parameters$R.eff.inds.Z <- R.eff.inds.Z

  # with init.coef == 0 every rnorm() below has sd 0; the calls are kept so
  # that the random number stream is consumed exactly as before
  init.coef <- 0

  Lambdax <- list(alpha = matrix(NA, Dx, Rx.eff), beta = matrix(NA, Dx, Rx.eff))
  Lambdaz <- list(alpha = matrix(NA, Dz, Rz.eff), beta = matrix(NA, Dz, Rz.eff))

  # 1way: leading columns of the coefficient priors
  if (length(R.fixed.inds.Z)>0) {
    Lambdax$alpha[, 1:length(R.fixed.inds.Z)] <- parameters$alpha_lambdax_1w + 0.5
    Lambdax$beta[, 1:length(R.fixed.inds.Z)] <- parameters$beta_lambdax_1w
  }

  if (length(R.fixed.inds.X)>0) {
    Lambdaz$alpha[, 1:length(R.fixed.inds.X)] <- parameters$alpha_lambdaz_1w + 0.5
    Lambdaz$beta[, 1:length(R.fixed.inds.X)] <- parameters$beta_lambdaz_1w
  }

  # 2way: trailing columns of the coefficient priors
  if (length(R.inds.2way) > 0) {
    Lambdax$alpha[, R.inds.2way.X] <- parameters$alpha_lambdax_2w + 0.5
    Lambdax$beta[, R.inds.2way.X] <- parameters$beta_lambdax_2w

    Lambdaz$alpha[, R.inds.2way.Z] <- parameters$alpha_lambdaz_2w + 0.5
    Lambdaz$beta[, R.inds.2way.Z] <- parameters$beta_lambdaz_2w
  }

  Ax <- list(mu = matrix(rnorm(Dx * Rx.eff, sd=init.coef*sqrt(parameters$var.Ax.tmp)), Dx, Rx.eff), sigma = array(diag(1, Dx, Dx), c(Dx, Dx, Rx.eff)))

  Hx <- list(mu = matrix(rnorm(R * Nx, sd= init.coef*parameters$sigma_hx), R, Nx), sigma = array(diag(c(rep(parameters$sigma.H.1way, Rx.fixed), rep(1, Rx.eff)), R, R), c(R, R, Nx)))

  # NOTE(review): Rx.fixed is ncol(fixed.cols) while the target slice is
  # Rx.fixed x Nx and is filled column-major; if fixed.cols is Nx x Rx.fixed
  # with Rx.fixed > 1 a transpose may be needed here -- confirm with callers.
  if (length(R.fixed.inds.X)>0) Hx$mu[R.fixed.inds.X,] <- parameters$fixed.cols

  print('TODO: check simulations, do they still work?')

  Az <- list(mu = matrix(rnorm(Dz * Rz.eff, sd=init.coef*sqrt(parameters$var.Az.tmp)), Dz, Rz.eff), sigma = array(diag(1, Dz, Dz), c(Dz, Dz, Rz.eff)))

  etaz <- list(alpha = matrix(parameters$alpha_eta + 0.5, Pz, 1), beta = matrix(parameters$beta_eta, Pz, 1))
  ez <- list(mu = matrix(1, Pz, 1), sigma = diag(1, Pz, Pz))

  # initial covariance of Hz: the Z-fixed components get their own value
  diag.vals.tmp <- rep(1, R)
  diag.vals.tmp[R.fixed.inds.Z] <- parameters$sigma.H.1way
  Hz <- list(mu = matrix(rnorm(R * Nz, sd= init.coef*parameters$sigma_hz), R, Nz), sigma = array(diag(diag.vals.tmp, nrow=length(diag.vals.tmp), ncol=length(diag.vals.tmp)), c(R, R, Nz)))
  if (length(R.fixed.inds.Z)>0) Hz$mu[R.fixed.inds.Z,] <- parameters$fixed.rows

  print('learning column-wise variance parameters')
  sigma_y <- list(mu = rep(1, ncol(Y)))
  tau <- list(alpha = rep(1, ncol(Y)), beta = rep(1, ncol(Y)))

  KxKx <- tcrossprod(Kx, Kx)
  Kx <- matrix(Kx, Dx, Nx * Px)

  KzKz <- tcrossprod(Kz, Kz)
  Kz <- matrix(Kz, Dz, Nz * Pz)

  # initializations
  if (!is.null(parameters$inits$ez$mu)) {
    ez$mu <- parameters$inits$ez$mu
  }

  # collect traces: traces[[variable]][[field]] is a list that grows by one
  # element each time a snapshot is taken
  traces <- list()

  traces$ez <- list()
  traces$ez$mu <- list()
  traces$ez$sigma <- list()

  traces$Hx <- list()
  traces$Hz <- list()

  traces$Hx$mu <- list()
  traces$Hz$mu <- list()

  traces$Hx$sigma <- list()
  traces$Hz$sigma <- list()

  traces$Ax <- list()
  traces$Az <- list()

  traces$Ax$mu <- list()
  traces$Ax$sigma <- list()

  traces$Az$mu <- list()
  traces$Az$sigma <- list()

  traces$sigma_hz <- list()
  traces$sigma_hz$val <- list()

  traces$etaz <- list()
  traces$etaz$alpha <- list()
  traces$etaz$beta <- list()

  traces$Lambdax <- list()
  traces$Lambdax$beta <- list()

  traces$Lambdaz <- list()
  traces$Lambdaz$beta <- list()

  traces$sigma_y <- list()
  traces$sigma_y$mu <- list()

  # variables listed here are never updated inside the main loop
  dont.update <- c('Lambdax', 'Lambdaz', 'etax', 'etaz') #, 'Ax', 'Az', 'sigma_y'

  traces$ll <- list()

  # informative initialization
  # NOTE(review): nothing from softImpute is called in this function; the
  # attach is kept because callers may rely on the package being loaded.
  library('softImpute')

  if (any(parameters$init.mode==c(1,2))) {

    if (length(R.inds.2way)>0) {
      Hx$mu[R.inds.2way, ] <- parameters$Hx.2way
      Hz$mu[R.inds.2way, ] <- parameters$Hz.2way
    }

    if (!is.null(parameters$fixed.rows)) Hx$mu[R.fixed.inds.Z, ] <- parameters$Hx.1way

    if (!is.null(parameters$fixed.cols)) Hz$mu[R.fixed.inds.X, ] <- parameters$Hz.1way

    Hx.mu.init <- Hx$mu

  } else {
    stop('specify initialisation mode!')
  }

  # snapshots are always taken for the first 10 iterations, afterwards every
  # thin-th iteration
  if (parameters$iteration>1500) thin <- 150 else if (parameters$iteration>51) thin <- 50 else thin <- 1

  # prior precisions of the factor columns; they do not change between
  # iterations. The X-fixed components are the first Rx.fixed ones.
  inv.sigma2.Hx.tmp <- diag(1/ c( rep(parameters$sigma.H.1way^2, Rx.fixed), rep(parameters$sigma_hx^2, Rx.eff) ), R, R)

  # The Z-fixed components are NOT the leading ones when Rx.fixed > 0, so the
  # 1-way precision is placed by position rather than by count.
  prior.prec.Hz <- rep(1 / parameters$sigma_hz^2, R)
  prior.prec.Hz[R.fixed.inds.Z] <- 1 / parameters$sigma.H.1way^2
  inv.sigma2.Hz.tmp <- diag(prior.prec.Hz, R, R)

  for (iter in seq_len(parameters$iteration)) {

    if (iter%%10== 0 || iter == 1) print(iter)

    ############################ row side ############################

    if (Rx.eff > 0) {

      # update Lambdax
      if (!any(dont.update == 'Lambdax' )) {
        for (s in 1:Rx.eff) {
          Lambdax$beta[,s] <- 1 / (1 / parameters$beta_lambdax + 0.5 * (Ax$mu[,s]^2 + diag(Ax$sigma[,,s])))
        }
      }

      # update Ax: column s of Ax regresses component R.eff.inds.X[s] of Hx
      if (!any(dont.update == 'Ax' )) {
        for (s in 1:Rx.eff) {
          r.tmp <- R.eff.inds.X[s]
          Ax$sigma[,,s] <- chol2inv(chol(diag(as.vector(Lambdax$alpha[,s] * Lambdax$beta[,s]), Dx, Dx) + KxKx / parameters$sigma_hx^2))
          Ax$mu[,s] <- Ax$sigma[,,s] %*% (Kx %*% matrix(Hx$mu[r.tmp,], Nx * Px, 1) / parameters$sigma_hx^2)
        }
      }

      if (any(is.na(Ax$mu))) stop('NAs in Ax')

      # update Hx: column i is informed by the observed entries of row i of Y,
      # each weighted by the noise precision of its column
      for (i in seq_len(Nx)) {

        obs.cols <- which(!is.na(Y[i,]))
        obs.precs <- 1/sigma_y$mu[obs.cols]^2

        Hz.obs <- Hz$mu[, obs.cols, drop = FALSE]
        Hz.obs.weighted <- t(t(Hz.obs) * obs.precs)

        Hx$sigma[,,i] <- chol2inv(chol(inv.sigma2.Hx.tmp + (tcrossprod(Hz.obs, Hz.obs.weighted) + apply(Hz$sigma[,,obs.cols, drop = FALSE] * rep(obs.precs, each=prod(dim(Hz$sigma)[1:2])), 1:2, sum) ) ))

        # first store the linear term in the updated entries, then multiply the
        # whole column (fixed entries keep their fixed values) by the covariance
        Hx$mu[R.eff.inds.X, i] <- (tcrossprod(Hz.obs, t(t(Y[i, obs.cols, drop = FALSE])*obs.precs) ) )[R.eff.inds.X, ,drop=FALSE] + crossprod(Ax$mu, Kx[,i]) / parameters$sigma_hx^2

        Hx$mu[R.eff.inds.X, i] <- (Hx$sigma[,,i] %*% Hx$mu[, i, drop=FALSE])[R.eff.inds.X, ,drop=FALSE]
      }
    }

    if (any(is.na(Hx$mu))) stop('NAs in Hx')

    ########################### column side ##########################

    if (Rz.eff > 0) {

      # update Lambdaz
      if (!any(dont.update == 'Lambdaz' )) {
        for (s in 1:Rz.eff) {
          Lambdaz$beta[,s] <- 1 / (1 / parameters$beta_lambdaz + 0.5 * (Az$mu[,s]^2 + diag(Az$sigma[,,s])))
        }
      }

      # update Az: column s of Az regresses component R.eff.inds.Z[s] of Hz
      if (!any(dont.update == 'Az' )) {
        for (s in 1:Rz.eff) {
          r.tmp <- R.eff.inds.Z[s]
          Az$sigma[,,s] <- chol2inv(chol(diag(as.vector(Lambdaz$alpha[,s] * Lambdaz$beta[,s]), Dz, Dz) + KzKz / parameters$sigma_hz^2))
          Az$mu[,s] <- Az$sigma[,,s] %*% (Kz %*% matrix(Hz$mu[r.tmp,], Nz * Pz, 1) / parameters$sigma_hz^2)
        }
      }

      if (any(is.na(Az$mu))) stop('NAs in Az')

      # update Hz: column j is informed by the observed entries of COLUMN j of
      # Y, so the moments of Hx (not Hz) enter, all weighted by the noise
      # precision of column j
      for (j in seq_len(Nz)) {

        obs.rows <- which(!is.na(Y[,j]))
        prec.j <- 1/sigma_y$mu[j]^2

        Hx.obs <- Hx$mu[, obs.rows, drop = FALSE]

        Hz$sigma[,,j] <- chol2inv(chol(inv.sigma2.Hz.tmp + (tcrossprod(Hx.obs) + apply(Hx$sigma[,,obs.rows, drop = FALSE], 1:2, sum)) * prec.j))

        # first store the linear term in the updated entries, then multiply the
        # whole column (fixed entries keep their fixed values) by the covariance
        Hz$mu[R.eff.inds.Z, j] <- (Hx.obs %*% Y[obs.rows, j, drop = FALSE] * prec.j)[R.eff.inds.Z, ,drop=FALSE] + crossprod(Az$mu, Kz[,j]) / parameters$sigma_hz^2

        Hz$mu[R.eff.inds.Z, j] <- (Hz$sigma[,,j] %*% Hz$mu[, j, drop=FALSE])[R.eff.inds.Z, ,drop=FALSE]
      }
    }

    if (any(is.na(Hz$mu))) stop('NAs in Hz')

    ######################## column-wise noise #######################

    if (!any(dont.update == 'sigma_y' )) {
      # update tau and sigma_y
      tau$alpha <- parameters$tau_alpha + colSums(!is.na(Y))/2

      # expected squared residual of every observed entry; unobserved entries
      # stay NA and are dropped by colSums(na.rm = TRUE)
      to.sum <- Y
      to.sum[!is.na(to.sum)] <- 0
      for (j in seq_len(ncol(Y))) {
        for (i in seq_len(nrow(Y))) {
          if (!is.na(Y[i,j])) {
            to.sum[i,j] <- Y[i,j]^2 -2*Y[i,j]*sum(Hx$mu[,i]*Hz$mu[,j]) + sum( diag( (tcrossprod(Hx$mu[,i]) + Hx$sigma[,,i]) %*% (tcrossprod(Hz$mu[,j]) + Hz$sigma[,,j]) ) )
          }
        }
      }

      tau$beta <- parameters$tau_beta + 0.5*colSums(to.sum, na.rm=TRUE)
      sigma_y$mu <- sqrt(1/(tau$alpha/tau$beta))

      if (any(is.na(sigma_y$mu))) stop('NAs in sigma_y')
    }

    ############################# traces #############################

    if (iter <= 10 || iter%%thin==0) {

      for (trace.name in setdiff(names(traces), c('sigma_hx', 'sigma_hz'))) {

        if (trace.name =='ll') {

          # Gaussian log-density of the residuals over the observed entries
          tmp.ll <- dnorm((Y-t(Hx$mu)%*%Hz$mu), sd=matrix(sigma_y$mu, nrow=nrow(Y), ncol=ncol(Y), byrow=TRUE), log=TRUE)
          tmp.ll[is.na(tmp.ll)] <- 0
          traces$ll[[length(traces$ll)+1]] <- sum(tmp.ll)

        } else {

          # the trace name is the name of the local variable being recorded
          current.value <- get(trace.name)
          for (field in names(traces[[trace.name]])) {
            n.stored <- length(traces[[trace.name]][[field]])
            traces[[trace.name]][[field]][[n.stored+1]] <- current.value[[field]]
          }
        }
      }

      traces$sigma_hz$val[[length(traces$sigma_hz$val)+1]] <- parameters$sigma_hz
    }

    if (iter %%100==0) print(gc())
  }

  state <- list(Lambdax = Lambdax, Ax = Ax, Lambdaz = Lambdaz, Az = Az, etaz = etaz, ez = ez, Hx = Hx, Hz = Hz, parameters = parameters, traces=traces, Hx.mu.init = Hx.mu.init)
}
