/*
 * NBMM_sampling.cpp
 *
 *  Created on: Aug 23, 2013
 *      Author: cchen
 */

#include <string.h>
#include <math.h>
#include <gsl/gsl_rng.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_permutation.h>

#include <scythestat/rng.h>
#include <scythestat/stat.h>
#include <scythestat/smath.h>

#include "gibbs.h"

// Shared sampler state declared extern here; the definitions are not in this
// view (presumably set up by the driver before stm_e_gibbs runs — confirm).

// Scratch buffer for the per-topic weights in sample_topic_lda (entries 0..K-1).
extern double* vec;
// GSL random number generator used for token shuffles and Gaussian draws.
extern gsl_rng* glob_r;

// Workspaces passed through to rmvnorm_1 in sample_eta.
extern gsl_matrix *work;
extern gsl_vector *res1;
// *A: linear-term vector built in sample_eta.
// *B: overwritten in sample_eta (precision, then its inverse) and, per
//     document, in stm_e_gibbs with the quadratic coefficients that
//     sample_topic_lda reads.
extern Matrix<> *A;
extern Matrix<> *B;
// Label-independent accumulators rebuilt by sample_eta when recompute != 0.
extern Matrix<> *A_tot;
extern Matrix<> *B_tot;
// Per-label probability buffer for sample_y; may be realloc'ed there.
extern double *y_probs;

static void sample_topic_lda(int d, int l, int w, Assignment* ass, Cts* cts, Model* model,
		Corpus* c)
{
	int old_topic, new_topic;
	int k;
	double sum_v, r, Ak, Bk;

	old_topic = ass->topic_ass[d][l];
	(*cts->n)(old_topic, d) -= 1;
	assert((*cts->n)(old_topic, d) >= 0);
	cts->m[old_topic][w] -= 1;
	assert(cts->m[old_topic][w] >= 0);
	cts->M[old_topic] -= 1;
	assert(cts->M[old_topic] >= 0);

	for(k = 0; k < model->K; ++k){
		r = (model->alpha + (*cts->n)(k, d));
		r *= (cts->m[k][w] + model->gamma);
		r /= (cts->M[k] + model->gamma * model->v);
		r = log(r);

		Ak = model->L * model->eta[cts->yd[d]][k];
		for(int ll = 0; ll < model->L; ll++){
			Ak -= model->eta[ll][k];
		}
		Ak *= (model->c + model->c * model->c * model->ell / cts->lambda[d]) / cts->N[d];
		Bk = (*B)(k, k);
		for(int k1 = 0; k1 < model->K; k1++){
			Bk += 2 * (*B)(k, k1) * (*cts->n)(k1, d);
		}
		Bk /= (cts->N[d] * cts->N[d]);

		vec[k] = r + Ak + Bk;
	}

	double max1 = vec[0];
	for(int i = 1; i < model->K; i++){
		if(vec[i] > max1){
			max1 = vec[i];
		}
	}
	for(int i = 0; i < model->K; i++){
		vec[i] -= max1;
		vec[i] = exp(vec[i]);
	}

	sum_v = 0;
	for(k = 0; k < model->K; ++k)
		sum_v += vec[k];
	for(k = 0; k < model->K; ++k)
		vec[k] = vec[k] / sum_v;
	new_topic = next_discrete_normalised(vec, model->K);

	ass->topic_ass[d][l] = new_topic;
	(*cts->n)(new_topic, d) += 1;
	cts->m[new_topic][w] += 1;
	cts->M[new_topic] += 1;
	assert(cts->M[new_topic] >= 1);
}

/*
 * Resample the per-document variable lambda_d for every document.
 *
 * For document d with current label y = cts->yd[d], the statistic
 *   s_d = c^2 * sum_{l != y} (ell - eta_z[y][d] + eta_z[l][d])^2
 * is formed, and lambda_d is drawn with GIGRND((3 - L) / 2, L - 1, s_d).
 * NOTE(review): GIGRND is defined elsewhere; presumably a generalized inverse
 * Gaussian draw with (p, a, b) arguments — confirm its parameterisation.
 */
static void sample_lambda(Model *model, Cts *cts, Corpus *cc)
{
	const double c_sq = model->c * model->c;

	for(int d = 0; d < cc->ndocs; d++){
		const int y = cts->yd[d];
		double sq_sum = 0;
		for(int l = 0; l < model->L; l++){
			if(l != y){
				const double gap = model->ell - model->eta_z[y][d] + model->eta_z[l][d];
				sq_sum += gap * gap;
			}
		}
		cts->lambda[d] = GIGRND((3.0 - model->L) / 2.0, model->L - 1.0, sq_sum * c_sq);
	}
}

/*
 * Gibbs update of the per-label weight vectors model->eta[l] (K+1 entries each),
 * followed by a refresh of the cached projections model->eta_z[l][d].
 *
 * When recompute != 0 the label-independent accumulators are rebuilt:
 *   B_tot = I / nu^2 + sum_d c^2 / (lambda_d N_d^2) * n_d n_d^T
 *   A_tot = sum_d (eta_z[y_d][d] - (c*ell + lambda_d)/c) / (lambda_d N_d) * n_d
 * where n_d is column d of *cts->n. When recompute == 0 the previous B_tot is
 * reused as-is.
 * NOTE(review): A_tot is written here but never read in this function.
 *
 * For each label l with at least one document (Cy[l] > 0):
 *   - the precision *B starts from B_tot, and documents labelled l add
 *     (L-2) more rank-one terms when L > 2;
 *   - *A collects the linear terms;
 *   - eta[l] is then drawn by rmvnorm_1 with B1 = invpd(B) * A and invpd(B).
 * Labels with no documents get each weight drawn from gsl_ran_gaussian(nu),
 * i.e. N(0, nu^2), since GSL takes the standard deviation.
 *
 * eta_z[l][d] = eta[l] . n_d / N_d is refreshed right after label l is
 * sampled. Labels visited later in the same sweep therefore read the new
 * projections.
 */
static void sample_eta(Model *model, Cts *cts, Corpus *cc, int recompute)
{
	//rmvnorm(glob_r, model->K, mean1, var1, sample_mu);
	if(recompute){
		// Reset: A_tot = 0, B_tot = I / nu^2 over the (K+1) x (K+1) weights.
		for(int k = 0; k< model->K + 1; k++){
			(*A_tot)(k) = 0;
			for(int k1 = 0; k1 < model->K + 1; k1++){
				if(k == k1){
					(*B_tot)(k, k1) = 1 / model->nu / model->nu;
				}else{
					(*B_tot)(k, k1) = 0;
				}
			}
		}
		// One term per document, with n_d = (*cts->n)(_, d).
		// B_tot receives c^2 / (lambda_d N_d^2) * n_d n_d^T.
		for(int d = 0; d < cc->ndocs; d++){
			(*A_tot) += (model->eta_z[cts->yd[d]][d] - (model->c * model->ell + cts->lambda[d]) / model->c)
					/ cts->lambda[d] / cts->N[d] * (*cts->n)(_, d);
			(*B_tot) += model->c * model->c / cts->lambda[d] / cts->N[d] / cts->N[d] * ((*cts->n)(_, d) * t((*cts->n)(_, d)));
		}
	}
	//for(int iter = 0; iter < 2; iter++){
	for(int l = 0; l < model->L; l++){
		if(cts->Cy[l] > 0){
			// Start from the shared precision; rebuild the linear term for label l.
			(*B) = (*B_tot);
			for(int k = 0; k< model->K + 1; k++){
				(*A)(k) = 0;
			}
			for(int d = 0; d < cc->ndocs; d++){
				if(cts->yd[d] != l){
					// A document with another label involves eta[l] only through
					// the single comparison of label l against its own label yd[d].
					(*A) -= (model->c * model->ell + cts->lambda[d]) / model->c / cts->lambda[d] * (*cts->n)(_, d) / cts->N[d];
					(*A) += model->eta_z[cts->yd[d]][d] / cts->lambda[d] * (*cts->n)(_, d) / cts->N[d];
				}else{
					// A document labelled l involves eta[l] in all L-1 comparisons;
					// tmp sums the other labels' current projections.
					double tmp = 0;
					for(int ll = 0; ll < model->L; ll++){
						if(ll == l){
							continue;
						}
						tmp += model->eta_z[ll][d];
					}
					(*A) += ((model->L - 1) * (model->c * model->ell + cts->lambda[d]) / model->c + tmp) / cts->lambda[d] * (*cts->n)(_, d) / cts->N[d];
					if(model->L > 2){
						// B_tot already holds one copy of this rank-one term; add the other L-2.
						(*B) += model->c * model->c * (model->L - 2) * (*cts->n)(_, d) * t((*cts->n)(_, d)) / cts->lambda[d] / cts->N[d] / cts->N[d];
					}
				}
			}
			// The linear terms above were written without their common factor c^2.
			(*A) *= (model->c * model->c);
			// invpd (Scythe): inverse of a positive-definite matrix; *B becomes its inverse.
			(*B) = invpd(*B);
			Matrix<> B1 = (*B) * (*A);
			// NOTE(review): presumably draws eta[l] ~ N(B1, *B) into model->eta[l],
			// using res1/work as scratch — confirm against rmvnorm_1.
			rmvnorm_1(glob_r, model->K + 1, B1, *B, res1, work, model->eta[l]);
		}else{
			// No document currently carries label l: draw each weight from N(0, nu^2).
			for(int k = 0; k < model->K + 1; k++){
				model->eta[l][k] = gsl_ran_gaussian(glob_r, model->nu);
			}
		}

		// Refresh eta_z[l][d] = eta[l] . n_d / N_d for the label just sampled.
		// (A_tot is not updated here despite the historical comment wording.)
		for(int d = 0; d < cc->ndocs; d++){
			model->eta_z[l][d] = 0;
			for(int k = 0; k < model->K + 1; k++){
				model->eta_z[l][d] += model->eta[l][k] * (*cts->n)(k, d) / cts->N[d];
			}
		}
	}//}
}

/*
 * Resample the label y_d = cts->yd[d] of every document, one document at a time.
 *
 * For document d the current label is removed from the label counts Cy. Each
 * existing label l (0 <= l < model->L) then gets the log-weight
 *     log(Cy[l] + omega / L)
 *   - 2c * sum_{ll != l} maxx(0, ell - eta_z[l][d] + eta_z[ll][d])
 * (maxx is defined elsewhere — presumably max). The weights are exponentiated
 * after subtracting their maximum, normalised in the global buffer y_probs,
 * and a new label is drawn with next_discrete_normalised.
 *
 * The else-branch adds a new label. It moves an entry of the auxiliary pool
 * eta_aux/eta_z_aux into slot model->L, growing eta, eta_z, Cy and y_probs in
 * steps of 10 when needed. It then refills the pool entry with a fresh
 * N(0, nu^2) draw.
 * NOTE(review): the draw is over only model->L entries, and the scoring of
 * auxiliary labels is commented out. This branch therefore looks unreachable
 * unless next_discrete_normalised can return an index >= its size argument —
 * confirm before relying on it.
 */
static void sample_y(Model *model, Cts *cts, Corpus *cc)
{
	int L;  // capacity candidate, used only by the augment branch

	for(int d = 0; d < cc->ndocs; d++){

		//if(d < 100 || d > cc->ndocs - 100){
		//	continue;
		//}

		//remove y_d
		int y = cts->yd[d];
		cts->Cy[y]--;
		// Disabled: drop label y entirely (compacting eta/eta_z/Cy/yd) once it is empty.
		/*if(cts->Cy[y] == 0){
			int idx = (int)(next_uniform() * model->C);

			double *tmp = model->eta[y];
			for(int l = y; l < model->L-1; l++){
				model->eta[l] = model->eta[l+1];
			}
			model->eta[model->L-1] = model->eta_aux[idx];
			model->eta_aux[idx] = tmp;

			tmp = model->eta_z[y];
			for(int l = y; l < model->L-1; l++){
				model->eta_z[l] = model->eta_z[l+1];
			}
			model->eta_z[model->L-1] = model->eta_z_aux[idx];
			model->eta_z_aux[idx] = tmp;

			for(int dd = 0; dd < cc->ndocs; dd++){
				if(cts->yd[dd] > y){
					cts->yd[dd]--;
				}
			}

			for(int l = y; l < model->L-1; l++){
				cts->Cy[l] = cts->Cy[l+1];
			}
			cts->Cy[model->L-1] = 0;

			model->L--;

		}*/

		// sampling: log prior count minus the max(0, .) penalty against every other label
		for(int l = 0; l < model->L; l++){
			y_probs[l] = log(cts->Cy[l] + model->omega / model->L);
			for(int ll = 0; ll < model->L; ll++){
				if(ll == l){
					continue;
				}
				//double tmp = cts->lambda[d] + model->c * (model->ell - model->eta_z[l][d] + model->eta_z[ll][d]);
				//y_probs[l] -= tmp * tmp / cts->lambda[d] / 2;
				y_probs[l] -= 2 * model->c * maxx(0, model->ell - model->eta_z[l][d] + model->eta_z[ll][d]);
			}
		}
		// Disabled: scores for the model->C auxiliary (new) labels.
		/*for(int l = 0; l < model->C; l++){
			y_probs[model->L + l] = log(model->omega / model->C);
			for(int ll = 0; ll < model->L; ll++){
				double tmp = cts->lambda[d] + model->c * (model->ell - model->eta_z_aux[l][d] + model->eta_z[ll][d]);
				y_probs[model->L + l] -= tmp * tmp / cts->lambda[d] / 2;
			}
		}*/

		// Stable normalisation: subtract the maximum log-weight, exponentiate, normalise.
		double max1 = y_probs[0];
		//for(int i = 1; i < model->L + model->C; i++){
		for(int i = 1; i < model->L; i++){
			if(y_probs[i] > max1){
				max1 = y_probs[i];
			}
		}
		//for(int i = 0; i < model->L + model->C; i++){
		for(int i = 0; i < model->L; i++){
			y_probs[i] -= max1;
			y_probs[i] = exp(y_probs[i]);
		}
		double sum_v = 0;
		//for(int k = 0; k < model->L + model->C; ++k){
		for(int k = 0; k < model->L; ++k){
			sum_v += y_probs[k];
		}
		//for(int k = 0; k < model->L + model->C; ++k){
		for(int k = 0; k < model->L; ++k){
			y_probs[k] /= sum_v;
		}
		//y = next_discrete_normalised(y_probs, model->L + model->C);
		y = next_discrete_normalised(y_probs, model->L);
		if(y < model->L){
			// An existing label was drawn: count the document under it again.
			cts->Cy[y]++;
		}else{
			//augment
			// Grow capacity in steps of 10 until it exceeds the current label count.
			L = model->maxL;
			while(model->L >= L){
				L += 10;
			}
			if(L > model->maxL){
				// NOTE(review): realloc results are not checked for NULL.
				model->maxL = L;
				model->eta = (double**)realloc(model->eta, L * sizeof(double*));
				model->eta_z = (double**)realloc(model->eta_z, L * sizeof(double*));
				cts->Cy = (int*)realloc(cts->Cy, L * sizeof(int));
				y_probs = (double*)realloc(y_probs, (L + model->C) * sizeof(double));
				// New slots start empty (NULL vectors, zero counts).
				for(int l = model->L; l < model->maxL; l++){
					model->eta[l] = NULL;
					model->eta_z[l] = NULL;
					cts->Cy[l] = 0;
				}
			}

			// Index of the chosen entry in the auxiliary pool.
			y = (y - model->L) % model->C;

			// Swap pool entry y (weights and projections) into active slot model->L.
			double *tmp = model->eta[model->L];
			model->eta[model->L] = model->eta_aux[y];
			model->eta_aux[y] = tmp;
			tmp = model->eta_z[model->L];
			model->eta_z[model->L] = model->eta_z_aux[y];
			model->eta_z_aux[y] = tmp;
			// The slot swapped out may be a fresh NULL slot: allocate pool storage.
			if(!model->eta_aux[y]){
				model->eta_aux[y] = (double*)calloc(model->K + 1, sizeof(double));
				model->eta_z_aux[y] = (double*)calloc(cc->ndocs, sizeof(double));
			}
			// Refill the pool entry with N(0, nu^2) weights (GSL takes the std. deviation).
			for(int k = 0; k < model->K + 1; k++){
				model->eta_aux[y][k] = gsl_ran_gaussian(glob_r, model->nu);
			}
			// Recompute the pool entry's projection eta_aux[y] . n_dd / N_dd.
			for(int dd = 0; dd < cc->ndocs; dd++){
				model->eta_z_aux[y][dd] = 0;
				for(int k = 0; k < model->K + 1; k++){
					model->eta_z_aux[y][dd] += model->eta_aux[y][k] * (*cts->n)(k, dd) / cts->N[dd];
				}
			}

			// Activate the new label with this document as its only member.
			y = model->L;
			model->L++;
			cts->Cy[y] = 1;
		}
		cts->yd[d] = y;
		//printf("L = %d\n", model->L);
	}
}

void stm_e_gibbs(Model* model, Cts* cts, Assignment* ass, Corpus* c, int doy)
{
	int l, d, w;
	gsl_permutation * p;

	std::cout << "finished!!" << std::endl << "sampling eta...\n";
	sample_eta(model, cts, c, 1);
	for(int k = 0; k < model->K + 1; k++){
		printf("%lf ", model->eta[0][k]);
	}
	printf("\n");

	sample_lambda(model, cts, c);

	for(d = 0; d < c->ndocs; ++d){

		// init B
		for(int k1 = 0; k1 < model->K; k1++){
			for(int k2 = 0; k2 < model->K; k2++){
				(*B)(k1, k2) = 0;
				for(int l = 0; l < model->L; l++){
					if(l == cts->yd[d]){
						continue;
					}
					double tmp1 = model->eta[cts->yd[d]][k1] - model->eta[l][k1];
					double tmp2 = model->eta[cts->yd[d]][k2] - model->eta[l][k2];
					(*B)(k1, k2) += tmp1 * tmp2;
				}
			}
		}
		for(int k1 = 0; k1 < model->K; k1++){
			for(int k2 = 0; k2 < model->K; k2++){
				(*B)(k1, k2) *= -model->c * model->c / cts->lambda[d] / 2;
			}
		}

		//std::cout << "Document " << d << "..." << std::endl;
		p = gsl_permutation_alloc (c->docs[d].total);
		gsl_permutation_init (p);
		gsl_ran_shuffle (glob_r, p->data, c->docs[d].total, sizeof(size_t));
		for(l = 0; l < c->docs[d].total; ++l){
			w = gsl_permutation_get(p, l);
			sample_topic_lda(d, w, c->docs[d].words[w], ass, cts, model, c);
			//sample_topic_lda(d, l, c->docs[d].words[l], ass, cts, model, c);
		}
		gsl_permutation_free(p);
	}
	// should update some statistics here
	for(int l = 0; l < model->L; l++){
		for(int d = 0; d < c->ndocs; d++){
			model->eta_z[l][d] = 0;
			for(int k = 0; k < model->K+1; k++){
				model->eta_z[l][d] += model->eta[l][k] * (*cts->n)(k, d);
			}
			model->eta_z[l][d] /= cts->N[d];
		}
	}
	for(int l = 0; l < model->C; l++){
		for(int d = 0; d < c->ndocs; d++){
			model->eta_z_aux[l][d] = 0;
			for(int k = 0; k < model->K+1; k++){
				model->eta_z_aux[l][d] += model->eta_z_aux[l][k] * (*cts->n)(k, d);
			}
			model->eta_z_aux[l][d] /= cts->N[d];
		}
	}

	if(doy){
		sample_y(model, cts, c);
	}
	printf("L = %d:", model->L);
	for(int l = 0; l < model->L; l++){
		printf("%d ", cts->Cy[l]);
	}
	printf("\n");
}
