## PeakANOVA: Stronger findings from mass spectral data through multi-peak modeling
## Sampler functions for multi-way dimensionality reduction

## Copyright 2013 Tommi Suvitaival
# Email: tommi.suvitaival@aalto.fi

# This file is part of PeakANOVA.

# PeakANOVA is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# PeakANOVA is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.

# You should have received a copy of the GNU Lesser General Public License
# along with PeakANOVA.  If not, see <http://www.gnu.org/licenses/>.



sampleFromPriorDR = function(data, priors, xlatDim, clustering=TRUE, sampleScale=FALSE, initialization=NULL) {

	## Draw an initial state for the dimensionality-reduction sampler.
	##
	## Arguments:
	# data$X: observed matrix; its rows are the variables and its columns the samples.
	# priors$theta$L, priors$theta$K: cluster probabilities for the random initial
	#   assignments. The first-level draw uses L in the two-level model and K in
	#   the one-level model; the second-level draw uses K.
	# xlatDim: number of first-level clusters; an optional second element switches
	#   on the two-level model and gives the number of second-level clusters.
	# clustering: if FALSE, data$X is used directly as the latent values.
	# sampleScale: accepted for interface compatibility; not used here.
	# initialization$V: optional first-level assignment matrix used as-is.
	## Returns a list with, in this order:
	# xLat (standard-normal draws, clusters x samples; or data$X without clustering),
	# V (one-hot variables x clusters matrix), z, W and psi (two-level model only),
	# and sigma (a vector of ones, one per variable).

	nVariables = nrow(data$X)
	nSamples = ncol(data$X)
	model = list()

	if (!clustering) { # No clustering: the observed data stand in for 'xLat'.
		model$xLat = data$X
		model$sigma = rep(x=1, times=nVariables)
		return(model)
	}

	twoLevel = length(xlatDim) > 1
	nClusters = xlatDim[1]

	model$xLat = array(data=rnorm(nClusters*nSamples), dim=c(nClusters,nSamples))

	if (is.null(initialization$V)) { # Assign every variable to one random first-level cluster.
		firstLevelProbs = if (twoLevel) priors$theta$L else priors$theta$K
		labels = sample.int(n=nClusters, size=nVariables, prob=firstLevelProbs, replace=TRUE)
		assignment = array(data=0, dim=c(nVariables,nClusters))
		assignment[cbind(seq_len(nVariables),labels)] = 1
		model$V = assignment
	} else {
		model$V = initialization$V
	}

	if (twoLevel) { # Second level: latent 'z' and a one-hot first-to-second-level cluster map 'W'.
		nSecond = xlatDim[2]
		model$z = array(data=rnorm(nSecond*nSamples), dim=c(nSecond,nSamples))
		secondLabels = sample.int(n=nSecond, size=nClusters, prob=priors$theta$K, replace=TRUE) # TODO: prior
		clusterMap = array(data=0, dim=c(nClusters,nSecond))
		clusterMap[cbind(seq_len(nClusters),secondLabels)] = 1
		model$W = clusterMap
		model$psi = rep(x=1, times=ncol(model$V)) ## TEST: psi fixed at generated value
	}

	model$sigma = rep(x=1, times=nVariables)
	model

}



sampleSigma = function(x, V, z, prior.N0=0, prior.sigma0=1) {

	## Draw one noise variance per variable (row of 'x') from its scaled
	## inverse-chi-square conditional, given the residuals x - V %*% z.
	## Missing values in 'x' are skipped: the mean squared residual and the
	## degrees of freedom of each row count only its observed entries (-5.8.13).
	## With prior.N0 > 0, prior.N0 pseudo-observations at variance prior.sigma0
	## are added to both the scale and the degrees of freedom.
	## Returns a numeric vector with one variance per row of 'x'.

	meanSqResid = rowMeans((x - V %*% z)^2, na.rm=TRUE)
	nObs = rowSums(!is.na(x))

	if (!(prior.N0 > 0)) { # No prior: scale = sum of squared residuals.
		return(meanSqResid*nObs/rchisq(n=length(meanSqResid), df=nObs))
	}

	scaleSum = nObs*meanSqResid + prior.N0*prior.sigma0
	scaleSum/rchisq(n=length(meanSqResid), df=nObs+prior.N0)

}



sampleV <- function(x=NULL, z=NULL, Sigma=NULL, theta=NULL, q=NULL, q.dbeta.log=NULL, V=NULL, shapes.beta.in=c(2,1), shapes.beta.out=c(1,1), return.likelihood=FALSE, infinite=FALSE, alpha.dp=1, object.order.randomization=TRUE, mu=NULL, vars=NULL) {

	## Gibbs update of the cluster-assignment matrix (variables x clusters, one 1 per row).
	##
	## Each variable i is visited once, in random order unless
	## object.order.randomization=FALSE. It is assigned to one cluster by sampling
	## from exp(prior + data term + similarity term), normalized with safesum().
	## - Data term (when 'x' is given): sum over samples of
	##   dnorm(x[i,], mean=mu[i]+vars[i]*z[k,], sd=sqrt(Sigma[i]), log=TRUE).
	## - Similarity term (when 'q' is given): Beta log-densities of q[,i,j], with
	##   'shapes.beta.in' for variables j in cluster k and 'shapes.beta.out' for
	##   the others. With 'q.dbeta.log', the same quantities are read from the
	##   pre-computed matrices q.dbeta.log$inside / q.dbeta.log$outside, whose
	##   diagonals must be 0.
	## - Prior: log(theta) in the finite model. With infinite=TRUE, a
	##   Chinese-restaurant-process prior with concentration 'alpha.dp' is used:
	##   the last option opens a new cluster and emptied clusters are dropped.
	##
	## Arguments (additions):
	# mu, vars: per-variable offset and scale of the cluster profile in the data
	#   term; default to 0 and 1 for every variable.
	## Returns the sampled assignment matrix, or list(V=..., log.likelihood=...)
	## when return.likelihood=TRUE. log.likelihood is only computed for x=NULL with
	## 'q.dbeta.log' and is 0 otherwise.
	## Notes:
	# - In the finite model, the 'q.dbeta.log' sums are computed once from the input
	#   'V' and the 'q' loops also read the input 'V'; they are not updated while
	#   the new matrix is filled.
	# - infinite=TRUE only works with x=NULL and 'q.dbeta.log': the other paths
	#   need the finite 'K'.
	# - permute() is not defined in this file; presumably gtools::permute — confirm
	#   it is available to callers.

	if (!infinite) {
		K = length(theta)
	}
	if (!is.null(x)) {
		m = nrow(x) # number of variables
		n.samples = ncol(x) # number of samples
		tz = t(z) # samples x clusters
		if (is.null(mu)) {
			mu = rep(x=0, times=m)
		}
		if (is.null(vars)) {
			vars = rep(x=1, times=m)
		}
		# Gaussian log-likelihood of variable i under each of the K cluster profiles, summed over samples.
		# Built on explicit matrices so that K==1 or a single sample does not drop dimensions.
		dataLogLik = function(i) {
			colSums(dnorm(x=matrix(x[i,], nrow=n.samples, ncol=K), mean=mu[i]+vars[i]*tz, sd=sqrt(Sigma[i]), log=TRUE))
		}
	} else if (!is.null(q.dbeta.log)) {
		m = nrow(q.dbeta.log$inside)
	} else if (!is.null(q)) {
		m = ncol(q)
	} else {
		m = nrow(V)
	}
	if (!infinite) {
		V.new = array(data=0, dim=c(m,K))
		prior.log = log(theta)
		summ = vector(mode="numeric", length=K) # filled element-wise by the per-cluster loops below
	}

	if (!infinite && !is.null(q.dbeta.log)) { # Pre-compute, for every variable and cluster, the summed log-densities.
		q.outside.flat = all(q.dbeta.log$outside==0) # an all-zero 'outside' term contributes nothing
		summ.q.inside = array(data=0, dim=c(m,K))
		summ.q.outside = array(data=0, dim=c(m,K))
		for (k in seq_len(K)) {
			in.k = V[,k]==1 # relies on diag(q.dbeta.log$inside) and diag(q.dbeta.log$outside) being 0
			if (any(in.k)) {
				summ.q.inside[,k] = rowSums(q.dbeta.log$inside[,in.k,drop=FALSE])
			}
			if (!all(in.k) && !q.outside.flat) {
				summ.q.outside[,k] = rowSums(q.dbeta.log$outside[,!in.k,drop=FALSE])
			}
		}
	}

	if (object.order.randomization) {
		order.objects = permute(1:m)
	} else {
		order.objects = seq_len(m)
	}

	for (i in order.objects) { # Go through all variables.
		if (!is.null(x)) {
			if (!is.null(q.dbeta.log)) { # Use pre-computed density values for the correlation.
				summ = dataLogLik(i)+summ.q.inside[i,]+summ.q.outside[i,]
			} else if (!is.null(q)) {
				q.i = q[,i,-i]
				for (k in seq_len(K)) { # Go through all clusters.
					summ[k] = sum(dnorm(x=x[i,], mean=mu[i]+vars[i]*z[k,], sd=sqrt(Sigma[i]), log=TRUE)) +
						sum(dbeta(x=q.i[,which(V[-i,k]==1)], shape1=shapes.beta.in[1], shape2=shapes.beta.in[2], log=TRUE)) +
						sum(dbeta(x=q.i[,which(V[-i,k]==0)], shape1=shapes.beta.out[1], shape2=shapes.beta.out[2], log=TRUE))
				}
			} else {
				summ = dataLogLik(i)
			}
		} else {
			if (!is.null(q.dbeta.log)) {
				if (infinite) { # Per-variable likelihood, as the clustering changes after every assignment.
					summ = rep(x=NA, times=ncol(V)+1)
					for (ki in seq_len(ncol(V))) { # Go through all clusters.
						in.ki = V[,ki]==1 # logical mask: an empty cluster gets the full 'outside' sum
						summ[ki] = sum(q.dbeta.log$inside[i,in.ki])+sum(q.dbeta.log$outside[i,!in.ki])
					}
					summ[length(summ)] = sum(q.dbeta.log$outside[i,]) # likelihood of a new cluster with only one member
				} else {
					summ = summ.q.inside[i,]+summ.q.outside[i,]
				}
			} else if (!is.null(q)) {
				q.i = q[,i,-i]
				for (k in seq_len(K)) {
					summ[k] = sum(dbeta(x=q.i[,which(V[-i,k]==1)], shape1=shapes.beta.in[1], shape2=shapes.beta.in[2], log=TRUE)) +
						sum(dbeta(x=q.i[,which(V[-i,k]==0)], shape1=shapes.beta.out[1], shape2=shapes.beta.out[2], log=TRUE))
				}
			}
		}
		if (infinite) { # Remove variable i from its cluster, then compute the CRP prior.
			assignment.current.i = which(V[i,]==1)
			if (length(which(V[,assignment.current.i]==1))==1) { # i was alone: drop the emptied cluster.
				V = V[,-assignment.current.i,drop=FALSE] # keep V a matrix even when one column remains
				summ = summ[-assignment.current.i]
			} else {
				V[i,assignment.current.i] = 0
			}
			prior.log = log(c(colSums(V),alpha.dp)/(m-1+alpha.dp)) # c(existing, new)
		}
		evidence <- safesum(summ+prior.log)
		posterior <- exp(x=summ+prior.log-evidence)
		choice = sample.int(n=length(posterior), size=1, prob=posterior)
		if (infinite) {
			if (choice==length(posterior)) { # Create a new cluster.
				V = cbind(V, c(rep(x=0,times=i-1),1,rep(x=0,times=m-i)))
			} else { # Assign to an existing cluster.
				V[i,choice] = 1
			}
		} else {
			V.new[i,choice] = 1
		}
	}

	if (infinite) {
		V.new = V
	}

	if (return.likelihood) {
		log.likelihood = 0
		if (is.null(x) && !is.null(q.dbeta.log)) {
			for (k in seq_len(ncol(V.new))) { # The number of clusters may have changed in the infinite model.
				in.k = V.new[,k]==1
				inside.k = q.dbeta.log$inside[in.k,in.k,drop=FALSE]
				# NOTE(review): within-cluster pairs are counted once (upper triangle), but each
				# cross-cluster pair appears in the 'outside' block of both of its clusters — confirm intended.
				log.likelihood = log.likelihood + sum(inside.k[upper.tri(inside.k)]) + sum(q.dbeta.log$outside[in.k,!in.k])
			}
		}
		return(list(V=V.new, log.likelihood=log.likelihood))
	}
	return(V.new)

}



## Draw latent values from their Gaussian full conditional, one sample (column) at a time.
## For each sample, the conditional precision of latent k is the sum of
## V[m,k]^2/Sigma[m] over the observed variables m, plus 'prec.prior'. The
## conditional mean is the conditional variance times
## (t(V/Sigma) %*% x over the observed variables + 'mean.term').
# prec.prior: prior precision, length K or 1.
# mean.term: prior precision times prior mean, a K x N matrix.
## Returns a K x N matrix. Samples without missing values are drawn first in one
## batch, followed by each sample with missing values.
.sampleXlatsPosterior = function(x, V, Sigma, prec.prior, mean.term) {

	K = ncol(V)
	N = ncol(x)
	x.is.na = is.na(x)
	xlatnew = array(dim=c(K,N))
	complete = !apply(X=x.is.na, MARGIN=2, FUN=any)

	if (any(complete)) { # Samples without missing values share one conditional variance.
		varr = 1/(colSums(V*V/Sigma)+prec.prior)
		muu = varr*(crossprod(V/Sigma, x[,complete,drop=FALSE])+mean.term[,complete,drop=FALSE])
		xlatnew[,complete] = rnorm(n=length(muu), mean=muu, sd=sqrt(varr))
	}
	for (ni in which(!complete)) { # Samples with missing values: use only the observed variables.
		observed = !x.is.na[,ni]
		V.obs = V[observed,,drop=FALSE]
		Sigma.obs = Sigma[observed]
		varr = 1/(colSums(V.obs*V.obs/Sigma.obs)+prec.prior)
		muu = varr*(crossprod(V.obs/Sigma.obs, x[observed,ni])+mean.term[,ni])
		xlatnew[,ni] = rnorm(n=K, mean=muu, sd=sqrt(varr))
	}

	return(xlatnew)

}

sampleXlats2mat = function(x, V, Sigma, design=NULL, effects=NULL, Psi=NULL, W=NULL, z=NULL, return.like=FALSE, varr.regularization=0) {

	## Sample the latent values (clusters x samples) given data 'x' (variables x
	## samples), the assignment matrix 'V' (variables x clusters) and per-variable
	## noise variances 'Sigma'. Missing values in 'x' are skipped per sample.
	##
	## Two argument combinations select the prior of the latent values:
	# - Psi, W, z (no design/effects): prior mean W %*% z, prior variance Psi
	#   (vector); 'varr.regularization' is added to the precision only.
	# - design, effects (no W/z): prior mean t(design %*% effects), prior variance
	#   Psi if given, otherwise 1.
	## Arguments:
	# design: Matrix, N x (A+B+C+A*B+A*C+B*C+A*B*C)
	# effects: Matrix, (A+B+C+A*B+A*C+B*C+A*B*C) x K
	## Returns the K x N matrix of new values, or list(z=..., like=...) when
	## return.like=TRUE. 'like' is the summed log-density of the new values under
	## N(prior mean, 1).
	## Signals an error for matrix-form Sigma or Psi, or any other argument combination.

	if (is.matrix(Sigma)) {
		stop("ERROR: matrix-form Sigma")
	}
	if (is.matrix(Psi)) {
		stop("ERROR: matrix-form Psi")
	}

	if (!is.null(Psi) && !is.null(W) && !is.null(z) && is.null(effects) && is.null(design)) { ## covariance matrix non-identity
		invPsi = 1/Psi
		prior.mean = W%*%z
		prec.prior = invPsi+varr.regularization
		mean.term = invPsi*prior.mean # invPsi rather than Psi; the paper says otherwise but this is the way it used to be in earlier implementation -28.9.12
	} else if (is.null(W) && is.null(z) && !is.null(effects) && !is.null(design)) { ## population priors for the mean
		prior.mean = t(design%*%effects)
		if (!is.null(Psi)) {
			invPsi = 1/Psi
			prec.prior = invPsi
			mean.term = invPsi*prior.mean
		} else { # identity prior covariance
			prec.prior = 1
			mean.term = prior.mean
		}
	} else {
		stop("ERROR: wrong arguments at sampleXlats2mat")
	}

	if (any(is.na(x))) {
		print("Missing data found!")
	}
	xlatnew = .sampleXlatsPosterior(x=x, V=V, Sigma=Sigma, prec.prior=prec.prior, mean.term=mean.term)

	if (return.like) { # Compute likelihood of the new variable values -24.3.10
		# NOTE(review): uses sd=1 as before, even when Psi is given — confirm intended.
		like = sum(dnorm(xlatnew, mean=prior.mean, log=TRUE))
		return(list(z=xlatnew, like=like))
	}
	return(xlatnew)

}



safesum <- function(x) {

	## Numerically stable log-sum-exp: returns log(sum(exp(x))) without overflow.
	## The maximum is factored out before exponentiating. An empty input returns
	## -Inf (the log of an empty sum). A non-finite maximum (all -Inf, or any +Inf,
	## NA or NaN) is returned directly; otherwise -Inf - (-Inf) would give NaN.

	if (length(x) == 0) {
		return(-Inf)
	}
	xmax <- max(x)
	if (!is.finite(xmax)) {
		return(xmax)
	}
	xmax + log(sum(exp(x - xmax)))

}
