/*
 * Each atom shares the same subsampling rate across different times.
 */

#include <stdlib.h>
#include <math.h>
#include <time.h>
#include <assert.h>
#include "DEIntegrator.h"
#include "mex.h"

#include "arms.c"
#include "slice-sampler.h"

//#define RECORD

#ifdef _MSC_VER
/* MSVC spells the C99 classification helpers with a leading underscore. */
#define finite _finite
#define isnan _isnan
#endif

#ifdef	 __USE_ISOC99
/* INFINITY and NAN are defined by the ISO C99 standard */
#else
/*
 * Pre-C99 fallbacks. Under IEEE 754, 1/0 is +inf and 0/0 is NaN; the zero is
 * kept in a variable rather than written as a literal division.
 */
double my_infinity(void)
{
	double denom = 0;
	double result = 1.0 / denom;
	return result;
}
double my_nan(void)
{
	double denom = 0;
	double result = denom / denom;
	return result;
}
#define INFINITY my_infinity()
#define NAN my_nan()
#endif

/********************
 * M:		concentration parameters for sources DP, 1 * nS
 * q:		subsampling rates from source s to group t, nDP * nS
 * n:		counts on sources, 1 * nS
 * n_t:		counts on groups, 1 * nDP
 * n_ts:	counts from group t to source s, nDP * nS
 * nj:		counts on sources for topics k, 1 * K
 * nj_t:	counts on groups for topics k, nDP * K
 * K_c:		#topics in each group, last one is #topics in total, 1 * (nDP + 1)
 * K_ct:	#topics in each source, 1 * nS
 * r:		indicators of inheritance from source s to group t, nDP * K
 * rr:		indicators of random jumps inherited from source s to group t, nDP * nS * K_e
 * K_t:		max #topics considered so far, 1 * 1
 * nDP:		#groups
 * nS:		#sources
 * JJ:		jumps in sources s, nS * K
 * K_e:		#random jumps in each sources
 * K_id:	indicator of which sources this topics belongs to, 1 * K
 */

/*
 * Global sampler state shared by the routines below; mexFunction fills it in.
 * maxExJ: max #extra (unobserved) jumps kept per source (input 21);
 * sumExJ: maxExJ * nS.
 */
int maxExJ, sumExJ;
/*
 * a: discount parameter (resampled by samplea_slice); M: per-source masses;
 * q: per-topic subsampling rates; aq, bq: Beta hyperparameters of q;
 * sumQU: running sum of u (periodically recomputed in mexFunction);
 * L: per-source slice bounds (SampleL); maxU: largest initial u.
 */
double a, *M, *q, aq, bq, sumQU, *L, maxU;
/*
 * Count arrays as listed in the table above. n_t_all: tokens per group
 * including held-out ones; tag_tr[i] == 1 marks a training group;
 * count1/count2: per-topic #groups with r == 1 / r != 1 (sampleQ_beta);
 * K_e: per-source extra-jump counts (initRM).
 */
int *n, *n_t, *n_t_all, **n_ts, *nj, **nj_t, *K_c, *K_ct, **r, *K_id, *K_e, K_t, nS, nDP, *tag_tr, *count1, *count2;
/*
 * JJ: per-topic jump sizes; u: per-group auxiliary variables;
 * gamma_a = Gamma(1 - a); a_ratio = a / (1 - a) / gamma_a;
 * integral: cached CalIntegral() values per source; inte_tmp: scratch copy.
 */
double *JJ, *u, gamma_a, a_ratio, *integral, *inte_tmp;
/* Not used by the visible code (iidx appears only in commented-out code). */
int iidx, q_i, q_j;
/* Not used by the visible code; 1.0 / 0.0 evaluates to +inf under IEEE 754. */
double qlv = -30, qrv = 0, inf = 1.0 / 0.0;
/* Log file, opened in mexFunction only when output == 1 (NULL otherwise). */
FILE* fid;

/* countQ / countU: acceptance counters; output: 1 = log to fid. */
int countQ, countU, output;

/* pool: rows of V doubles allocated by initPool; len_pool rows in total. */
double **pool;
int len_pool, cur_len_pool;

/*
 * These classes define the integrands.
*/
/*
 * Integrand e^{-x} / x^{1+a}, where a is the global discount parameter.
 * It is integrated numerically with DEIntegrator (sampleM_Gibbs, aPosterior).
 */
struct NGG
{
	double operator()(double x) const
	{
		double numer = exp(-x);
		double denom = pow(x, 1 + a);
		return numer / denom;
	}
};

class NGG1
{
public:
	double operator()(double x) const
	{
		int j;
		double q, ss = 0;
		q = aq / (aq + bq);
		for(j = 0; j < nDP; j++){
			ss += log(1 - q + q * exp(-u[j] * x));
		}
		ss = 1 - exp(ss);
		ss *= exp(-x) * pow(x, -1 - a);
		return ss;
	}
};

/*
 * data loading functions, adapted from Teh's HDP code
*/
/* Larger of two values (x1 on ties). NOTE: an argument may be evaluated twice. */
#define max1(x1,x2) ( (x2) > (x1) ? (x2) : (x1) )

#define mxReadCellVectorDef(funcname,type) \
  type **funcname(const mxArray *mcell, type shift) { \
    mxArray *mvector; \
    double *mdouble; \
    int ii, jj; \
    type **result; \
    result = (type**)malloc(sizeof(type*)*mxGetNumberOfElements(mcell));  \
    for ( jj = 0 ; jj < mxGetNumberOfElements(mcell) ; jj++ ) {  \
      mvector = mxGetCell(mcell,jj);  \
      mdouble = mxGetPr(mvector);  \
      result[jj] = (type*)malloc(sizeof(type)*mxGetNumberOfElements(mvector));  \
      for ( ii = 0 ; ii < mxGetNumberOfElements(mvector) ; ii++ )  \
        result[jj][ii] = (type)mdouble[ii] + shift;  \
    } \
    return result; \
  }
mxReadCellVectorDef(mxReadIntCellVector,int);
mxReadCellVectorDef(mxReadDoubleCellVector,double);

/*
 * Defines funcname(numcell, numentry, var, shift, del): builds a 1 x numcell
 * MATLAB cell whose c-th entry is a 1 x numentry[c] double row holding
 * var[c][:] + shift. If del is nonzero, each var[c] and var itself are freed.
 */
#define mxWriteCellVectorDef(funcname, type) \
  mxArray *funcname(int numcell,int *numentry,type **var,type shift, int del) { \
    mxArray *cellarr = mxCreateCellMatrix(1,numcell); \
    int c, e; \
    for ( c = 0 ; c < numcell ; c++ ) { \
      mxArray *row = mxCreateDoubleMatrix(1,numentry[c],mxREAL); \
      double *dst = mxGetPr(row); \
      for ( e = 0 ; e < numentry[c] ; e++ ) { \
        dst[e] = var[c][e] + shift; \
      } \
      mxSetCell(cellarr,c,row); \
      if ( del ) { \
        free(var[c]); \
      } \
    } \
    if ( del ) { \
      free(var); \
    } \
    return cellarr; \
  }
mxWriteCellVectorDef(mxWriteIntCellVector, int);
mxWriteCellVectorDef(mxWriteDoubleCellVector, double);

/*
 * Defines funcname(var, old_len, add_len, dim): grows a jagged array to
 * old_len + add_len rows. Existing row pointers move into a new spine; their
 * old slots are set to NULL and the old spine is freed. Each new row gets dim
 * zero-initialised elements. Returns the new spine.
 */
#define reallo(funcname, type) \
	type** funcname(type** var, int old_len, int add_len, int dim) { \
		int total = old_len + add_len; \
		int row, col; \
		type** grown = (type**)malloc(total * sizeof(type*)); \
		for (row = 0; row < total && row < old_len; row++) { \
			grown[row] = var[row]; \
			var[row] = NULL; \
		} \
		for (row = old_len; row < total; row++) { \
			grown[row] = (type*)malloc(dim * sizeof(type)); \
			for (col = 0; col < dim; col++) { \
				grown[row][col] = (type)0; \
			} \
		} \
		free(var); \
		return grown; \
	}
reallo(realloInt, int);
reallo(realloDouble, double);

/*
 * Defines funcname(mvector, number, shift, init): reads a MATLAB numeric
 * array into a malloc'd vector of max(number, numel) elements. The first
 * numel entries are (type)value + shift and the rest are init.
 * The str argument is unused.
 */
#define mxReadVectorDef(funcname,type,str) \
  type *funcname(const mxArray *mvector,int number,type shift,type init) { \
    size_t nel = mxGetNumberOfElements(mvector); \
    const double *src = mxGetPr(mvector); \
    type *vec; \
    int idx; \
    number = max1(number, nel); \
    vec = (type*) malloc(sizeof(type) * number); \
    for ( idx = 0 ; idx < (int)nel ; idx++ ) \
      vec[idx] = (type)src[idx] + shift; \
    for ( ; idx < number ; idx++ ) \
      vec[idx] = init; \
    return vec; \
  } 
mxReadVectorDef(mxReadIntVector,int,"%d ");
mxReadVectorDef(mxReadDoubleVector,double,"%g ");

/*
 * Defines funcname(mm, nn, var, shift, del): copies mm*nn elements of var
 * (plus shift), read linearly, into a new mm x nn MATLAB double matrix.
 * Frees var when del is nonzero. The str argument is unused.
 */
#define mxWriteVectorDef(funcname, type, str) \
  mxArray *funcname(int mm,int nn,type *var,type shift, int del) { \
    mxArray *out = mxCreateDoubleMatrix(mm,nn,mxREAL); \
    double *dst = mxGetPr(out); \
    int count = mm * nn; \
    int idx; \
    for ( idx = 0 ; idx < count ; idx++ ) { \
      dst[idx] = var[idx] + shift; \
    } \
    if ( del ) { \
      free(var); \
    } \
    return out; \
  } 
mxWriteVectorDef(mxWriteIntVector, int, "%d ");
mxWriteVectorDef(mxWriteDoubleVector, double, "%g ");

/*
 * Defines funcname(marray, mm, nn, shift, init): reads a column-major MATLAB
 * matrix into a malloc'd row-pointer array of at least mm x nn elements
 * (grown to the source size if larger). Source cells become
 * (type)value + shift; padding cells become init.
 */
#define mxReadMatrixDef(funcname,type) \
  type **funcname(const mxArray *marray,int mm,int nn,type shift,type init) { \
    const double *src = mxGetPr(marray); \
    int srows = (int)mxGetM(marray); \
    int scols = (int)mxGetN(marray); \
    type **out; \
    int row, col; \
    if ( srows > mm ) mm = srows; \
    if ( scols > nn ) nn = scols; \
    out = (type**) malloc(sizeof(type*) * mm); \
    for ( row = 0 ; row < mm ; row++ ) { \
      out[row] = (type*) malloc(sizeof(type) * nn); \
      for ( col = 0 ; col < nn ; col++ ) { \
        if ( row < srows && col < scols ) \
          out[row][col] = (type)src[col * srows + row] + shift; \
        else \
          out[row][col] = init; \
      } \
    } \
    return out; \
  }
mxReadMatrixDef(mxReadIntMatrix,int);
mxReadMatrixDef(mxReadDoubleMatrix,double);

/*
 * Defines funcname(mm, nn, maxm, var, shift, del): copies the mm x nn
 * row-pointer array var (plus shift) into a new column-major MATLAB matrix.
 * When del is nonzero it frees rows 0..mm-1, any non-NULL rows
 * mm..maxm-1, and the spine.
 */
#define mxWriteMatrixDef(funcname, type) \
  mxArray *funcname(int mm,int nn,int maxm,type **var,type shift, int del) { \
    mxArray *out = mxCreateDoubleMatrix(mm,nn,mxREAL); \
    double *dst = mxGetPr(out); \
    int row, col; \
    for ( row = 0 ; row < mm ; row++ ) { \
      for ( col = 0 ; col < nn ; col++ ) { \
        dst[row + mm * col] = var[row][col] + shift; \
      } \
    } \
    if ( del ) { \
      for ( row = 0 ; row < mm ; row++ ) \
        free(var[row]); \
      for ( row = mm ; row < maxm ; row++ ) \
        free(var[row]); \
      free(var); \
    } \
    return out; \
  }
mxWriteMatrixDef(mxWriteIntMatrix, int);
mxWriteMatrixDef(mxWriteDoubleMatrix, double);

/* Returns the first real element of a MATLAB numeric array. */
double mxReadScalar(const mxArray *mscalar)
{
	const double *data = mxGetPr(mscalar);
	return data[0];
}

/* Wraps var in a new 1 x 1 real MATLAB double matrix. */
mxArray *mxWriteScalar(double var)
{
	mxArray *scalar = mxCreateDoubleMatrix(1, 1, mxREAL);
	mxGetPr(scalar)[0] = var;
	return scalar;
}

/***** for windows ******/
/*double drand48()
{
	return (double)(rand() + 1) / (double)(RAND_MAX + 2);
}*/

/*
 * Draws from a Gamma distribution with shape rr and rate theta (every branch
 * divides the unit-rate draw by theta), using drand48() uniforms.
 *   rr <= 0 : returns 0 (not well defined).
 *   rr == 1 : exponential draw.
 *   rr <  1 : Johnk's generator.
 *   rr >  1 : Best's rejection algorithm.
 */
double randgamma(double rr, double theta) 
{
	double aa, bb, cc, dd;
  	double uu, vv, ww, xx, yy, zz;

  	if ( rr <= 0.0 ) {
    	/* Not well defined, set to zero and skip. */
    	return 0.0;
  	} else if ( rr == 1.0 ) {
    	/* Exponential */
    	return - log(drand48()) / theta;
  	} else if ( rr < 1.0 ) {
    	/* Use Johnks generator */
    	cc = 1.0 / rr;
    	dd = 1.0 / (1.0-rr);
    	while (1) {
      		xx = pow(drand48(), cc);
      		yy = xx + pow(drand48(), dd);
      		if ( yy <= 1.0 ) {
        		/* xx / yy ~ Beta(rr, 1 - rr); scaling by an Exp(1) draw gives Gamma(rr, 1). */
        		return -log(drand48()) * xx / yy / theta;
      		}
    	}
  	} else { /* rr > 1.0 */
    	/* Use bests algorithm */
    	bb = rr - 1.0;
    	cc = 3.0 * rr - 0.75;
    	while (1) {
      		uu = drand48();
      		vv = drand48();
      		ww = uu * (1.0 - uu);
      		yy = sqrt(cc / ww) * (uu - 0.5);
      		xx = bb + yy;
      		/* Negative candidates are outside the support; draw again. */
      		if (xx >= 0) {
        		zz = 64.0 * ww * ww * ww * vv * vv;
        		/* Cheap squeeze test first, then the exact log acceptance test. */
        		if ( ( zz <= (1.0 - 2.0 * yy * yy / xx) ) ||
             		( log(zz) <= 2.0 * (bb * log(xx / bb) - yy) ) ) {
          			return xx / theta;
        		}
      		}
    	}
  	}
}

/* Draws Beta(aa, bb) as X / (X + Y) with X ~ Gamma(aa, 1) and Y ~ Gamma(bb, 1), drawn in that order. */
double randbeta(double aa, double bb)
{
	double x = randgamma(aa, 1);
	double y = randgamma(bb, 1);
	return x / (x + y);
}

/*
 * Draws a Dirichlet(alpha) sample into pi. Both arrays are strided by skip:
 * component i lives at pi[i*skip] and alpha[i*skip]. Each component is drawn
 * as Gamma(alpha_i, 1), and all are then divided by their total.
 */
void randdir(double *pi, double *alpha, int veclength, int skip)
{
	double *end = pi + veclength * skip;
	double *cur = pi;
	double total = 0.0;

	while (cur < end) {
		*cur = randgamma(*alpha, 1);
		total += *cur;
		cur += skip;
		alpha += skip;
	}
	for (cur = pi; cur < end; cur += skip) {
		*cur /= total;
	}
}

/*
 * Draws an index i in [0, veclength) with probability pi[i] / sum(pi), by
 * inverse-CDF sampling on one drand48() uniform.
 *
 * pi must hold veclength > 0 non-negative weights with a positive sum
 * (checked with mxAssert in debug builds only).
 *
 * The scan uses a strict "< 0" test, so each entry owns a half-open interval
 * of the cumulative mass. A zero-weight entry can therefore never be chosen,
 * even when drand48() returns exactly 0. The scan is bounded by veclength:
 * if rounding leaves a tiny positive remainder after the last entry, the last
 * positive-weight index is returned instead of reading past the end of pi.
 */
int randmult(double* pi, int veclength)
{
	int i, last_positive = -1;
	double sum = 0.0, mass;

	for (i = 0; i < veclength; i++) {
		sum += pi[i];
		mxAssert(pi[i] >= 0, "entries less than zero!!!");
	}
	mxAssert(sum > 0, "sum equals to zero!!!");
	mass = drand48();
	if (mass <= 0 || mass >= 1) {
		/* fid is only opened when logging is enabled; fall back to stderr. */
		fprintf(fid ? fid : stderr, "rand = %f\n", mass);
		mxAssert(mass > 0 && mass < 1, "mass exceeds bounds!!!");
	}
	mass *= sum;
	for (i = 0; i < veclength; i++) {
		if (pi[i] > 0) {
			last_positive = i;
		}
		mass -= pi[i];
		if (mass < 0.0) {
			return i;
		}
	}
	/* Rounding fallback (also covers degenerate all-zero/NaN weights in release builds). */
	return (last_positive >= 0) ? last_positive : veclength - 1;
}


/*
 * Returns a normally distributed real number with mean m and standard
 * deviation s (use s > 0.0), built from a single drand48() uniform.
 *
 * Uses the rational approximation of the normal inverse CDF by Odeh & Evans,
 * J. Applied Statistics, 1974, vol 23, pp 96-97.
 */
double Normal(double m, double s)
{
	const double p0 = 0.322232431088;     const double q0 = 0.099348462606;
	const double p1 = 1.0;                const double q1 = 0.588581570495;
	const double p2 = 0.342242088547;     const double q2 = 0.531103462366;
	const double p3 = 0.204231210245e-1;  const double q3 = 0.103537752850;
	const double p4 = 0.453642210148e-4;  const double q4 = 0.385607006340e-2;

	const double u = drand48();
	const int below_half = (u < 0.5);
	/* Work on the tail probability min(u, 1 - u), then restore the sign. */
	const double t = sqrt(-2.0 * log(below_half ? u : 1.0 - u));
	const double num = p0 + t * (p1 + t * (p2 + t * (p3 + t * p4)));
	const double den = q0 + t * (q1 + t * (q2 + t * (q3 + t * q4)));
	const double ratio = num / den;

	return m + s * (below_half ? ratio - t : t - ratio);
}

/*
 * Natural log of the Gamma function for x > 0, using the 6-term Lanczos
 * series. x <= 0 trips mxAssert in debug builds.
 */
double gammaln(double x)
{
	#define M_lnSqrt2PI 0.91893853320467274178
	static const double coef[6] = {
		76.18009172947146,
		-86.50532032941677,
		24.01409824083091,
		-1.231739572450155,
		0.1208650973866179e-2,
		-0.5395239384953e-5
	};
	double shifted, acc, y;
	int k;

	mxAssert(x > 0, "argument less than zero!!!");
	y = x;
	acc = 1.000000000190015;
	for (k = 0; k < 6; k++) {
		y += 1.0;  /* y = x + k + 1 */
		acc += coef[k] / y;
	}
	shifted = x + 5.5;
	return( M_lnSqrt2PI + (x + 0.5) * log(shifted) - shifted + log(acc / x));
}

/* Gamma(x) = exp(gammaln(x)) for finite x > 0; debug builds also assert that the result is finite. */
double gamma(double x)
{
	double val;
	mxAssert(finite(x) && x > 0, "x less than zero!!!");
	val = exp(gammaln(x));
	mxAssert(finite(val), "result infinite!!!");
	return val;
}

/*
 * Initialize the pool of components used in slice sampling.
*/
/*
 * Allocates the global pool as K zero-filled rows of V doubles and records
 * its size (len_pool = K, cur_len_pool = K - 1).
 */
void initPool(int K, int V)
{
	int row;
	len_pool = K;
	cur_len_pool = K - 1;
	pool = (double**)malloc(sizeof(double*) * K);
	for (row = 0; row < len_pool; row++) {
		pool[row] = (double*)calloc(V, sizeof(double));
	}
}
/*
 * Releases every pool row and the pool spine. It then clears pool and
 * len_pool, so a repeated call is a no-op instead of a double free, and later
 * code cannot reach the freed rows through the global.
 */
void freePool()
{
	int i;
	if (pool == NULL) {
		return;
	}
	for (i = 0; i < len_pool; i++) {
		free(pool[i]);  /* free(NULL) is a no-op */
	}
	free(pool);
	pool = NULL;
	len_pool = 0;
}

/*
 * For each source s, stores in val[s]
 *     q * (sum_t u[t]) * L[s]^(1 - a) / (1 - a),   q = aq / (aq + bq).
 * This appears to be the small-x (first-order) form of integrating NGG1 over
 * (0, L[s]). Reads the globals aq, bq, u, nDP, nS, L and a.
 */
void CalIntegral(double* val)
{
	const double q = aq / (aq + bq);
	double usum = 0;
	int s, t;

	/* The sum of u does not depend on the source, so compute it once. */
	for (t = 0; t < nDP; t++) {
		usum += u[t];
	}
	for (s = 0; s < nS; s++) {
		val[s] = (usum * q) * (pow(L[s], 1 - a) / (1 - a));
	}
}

/*
 * Sequential (chain-rule) predictive log-likelihood of all tokens under
 * smoothed topic-word counts.
 *
 * It first removes every token from mu / sum_mu. It then re-adds tokens one at
 * a time, and before each re-add accumulates
 *     log((h + mu[k][v]) / (h*V + sum_mu[k])).
 * s[i][j] and data[i][j] are 1-based topic and word ids; n[i] is the token count
 * of group i. mu and sum_mu hold their input values again on return.
 */
double likelihood(double **mu, double *sum_mu, int **s, int **data, int *n, int nDP, double h, int V)
{
	const double denom_prior = h * V;
	double total = 0;
	int grp, tok;

	/* Pass 1: strip every token's contribution from the counts. */
	for (grp = 0; grp < nDP; grp++) {
		for (tok = 0; tok < n[grp]; tok++) {
			int topic = s[grp][tok] - 1;
			int word = data[grp][tok] - 1;
			mu[topic][word] -= 1;
			sum_mu[topic] -= 1;
		}
	}
	/* Pass 2: score each token against the tokens already re-added, then re-add it. */
	for (grp = 0; grp < nDP; grp++) {
		for (tok = 0; tok < n[grp]; tok++) {
			int topic = s[grp][tok] - 1;
			int word = data[grp][tok] - 1;
			total += log((h + mu[topic][word]) / (denom_prior + sum_mu[topic]));
			mu[topic][word] += 1;
			sum_mu[topic] += 1;
		}
	}
	return total;
}

/*
 * Computes perplexities: lik[0] over training groups (tag_tr[i] == 1) and
 * lik[1] over held-out groups, each as exp(-sum log p(w) / #tokens), with the
 * token count floored at 1.
 *
 * Training token l of group i: theta is uniform over topics k that the group
 * uses (nj_t[i][k] > 0 or r[i][k] == 1) and whose jump JJ[k] exceeds the
 * token's slice variable us[i][l].
 * Held-out tokens (l in [n_t[i], n_t_all[i])): theta is proportional to JJ[k]
 * over inherited topics (r[i][k] == 1) with JJ[k] > us[i][n_t[i]].
 * Word probability: (h + mu[k][v]) / (h*V + sum_mu[k]), or 1/V for topics
 * whose mu row is NULL.
 *
 * theta is caller-provided scratch of at least K_t doubles.
 * NOTE(review): if no topic qualifies, summ is 0 and theta becomes NaN (0/0);
 * only the training branch asserts on the resulting word probability.
 */
void CalTrTeLik(double* theta, int **data, double **mu, double *sum_mu, double**us, double h, int V, double *lik)
{
	int i, k, l, ll, v;
	double summ, smu, s1, s2;
	lik[0] = lik[1] = 0;
	s1 = 0; s2 = 0;
	smu = h * V;
	for(i = 0; i < nDP; i++){
		if(tag_tr[i] == 1){
			s1 += n_t[i];
			summ = 0;
			
			for(l = 0; l < n_t[i]; l++){
				summ = 0;
				v = data[i][l] - 1;
				/* Topics available to this token under the slice: used/inherited and JJ[k] > us. */
				for(k = 0; k < K_t; k++){
					if((nj_t[i][k] > 0 || r[i][k] == 1) && JJ[k] > us[i][l]){
						theta[k] = 1; //JJ[k];
					}else{
						theta[k] = 0;
					}
					summ += theta[k];
				}
				for(k = 0; k < K_t; k++){
					theta[k] /= summ;
				}
				
				summ = 0;
				for(k = 0; k < K_t; k++){
					if(mu[k]){
						summ += theta[k] * (h + mu[k][v]) / (smu + sum_mu[k]);
					}else if(theta[k] > 0){
						summ += theta[k] / V;
					}
				}
				mxAssert(summ > 0, "sum less than zero!!!");
				lik[0] += log(summ);
			}
		}else{
		
			s2 += n_t_all[i] - n_t[i];
			summ = 0;
			
			/* One topic distribution for all held-out tokens, weighted by jump size. */
			for(k = 0; k < K_t; k++){
				if(r[i][k] == 1 && JJ[k] > us[i][n_t[i]]){
					theta[k] = JJ[k];
				}else{
					theta[k] = 0;
				}
				summ += theta[k];
			}
			for(k = 0; k < K_t; k++){
				theta[k] /= summ;
			}
			for(l = n_t[i]; l < n_t_all[i]; l++){
				v = data[i][l] - 1;
				summ = 0;
				for(k = 0; k < K_t; k++){
					if(mu[k]){
						summ += theta[k] * (h + mu[k][v]) / (smu + sum_mu[k]);
					}else if(theta[k] > 0){
						summ += theta[k] / V;
					}
				}
				lik[1] += log(summ);
			}
		}
	}
	/* max1(..., 1) avoids dividing by zero when a split has no tokens. */
	lik[0] = exp(-lik[0] / max1(s1, 1));
	lik[1] = exp(-lik[1] / max1(s2, 1));
}

/*
 * Perplexity of training tokens (perp[0]) and held-out tokens (perp[1]):
 * exp(-sum log p(w) / #tokens), with the token count floored at 1 in both
 * cases. Previously a model with no training tokens divided by zero.
 *
 * Training groups (tag_tr[i] == 1), token l: theta is uniform over topics k
 * that the group uses (nj_t[i][k] > 0 or r[i][k] == 1) and whose jump JJ[k]
 * exceeds the token's slice variable us[i][l].
 *
 * Held-out groups: each occupied topic gets weight
 *     (nj[k] - a) / (1 + sum of u[t] over groups t with r[t][k] == 1).
 * The first unoccupied slot (nj[k] == 0) gets a Monte Carlo weight for a new
 * topic: for each source, the other groups are subsampled with probability
 * aq / (aq + bq) via drand48(). Later unoccupied slots get 0.
 *
 * Word probability: (h + mu[k][v]) / (h*V + sum_mu[k]), or 1/V when mu[k] is NULL.
 * theta is caller-provided scratch of at least K_t doubles.
 */
void calPerp(double* theta, int **data, double **mu, double *sum_mu, double**us, double h, int V, double *perp)
{
	int i, k, l, v, t, tag, s1, s2;
	double summ, smu, sumu;
	smu = h * V;
	perp[0] = perp[1] = 0;
	s1 = 0; s2 = 0;
	for(i = 0; i < nDP; i++){
		if(tag_tr[i] == 1){
			/* Training group: score each observed token l < n_t[i]. */
			s1 += n_t[i];
			for(l = 0; l < n_t[i]; l++){
				summ = 0;
				for(k = 0; k < K_t; k++){
					if((nj_t[i][k] > 0 || r[i][k] == 1) && JJ[k] > us[i][l]){
						theta[k] = 1;
					}else{
						theta[k] = 0;
					}
					summ += theta[k];
				}
				for(k = 0; k < K_t; k++){
					theta[k] /= summ;
				}

				v = data[i][l] - 1;
				summ = 0;
				for(k = 0; k < K_t; k++){
					if(mu[k]){
						summ += theta[k] * (h + mu[k][v]) / (smu + sum_mu[k]);
					}else{
						summ += theta[k] / V;
					}
				}
				perp[0] += log(summ);
			}
		}else{
			/* Held-out group: one topic distribution shared by all held-out tokens. */
			tag = 0;
			for(k = 0; k < K_t; k++){
				if(nj[k] == 0){
					if(tag == 0){
						/* First unoccupied slot stands in for a brand-new topic. */
						theta[k] = 0;
						for(l = 0; l < nS; l++){
							sumu = u[i];
							for(t = 0; t < nDP; t++){
								if(t == i){
									continue;
								}
								if(drand48() < (aq / (aq + bq))){
									sumu += u[t];
								}
							}
							theta[k] += a * M[l] / pow(1 + sumu, 1 - a);
						}
						tag = 1;
					}else{
						theta[k] = 0;
					}
				}else{
					sumu = 1;
					for(t = 0; t < nDP; t++){
						if(r[t][k] == 1){
							sumu += u[t];
						}
					}
					theta[k] = (nj[k] - a) / sumu;
				}
			}

			summ = 0;
			for(k = 0; k < K_t; k++){
				summ += theta[k];
			}
			for(k = 0; k < K_t; k++){
				theta[k] /= summ;
			}

			s2 += n_t_all[i] - n_t[i];
			for(l = n_t[i]; l < n_t_all[i]; l++){
				v = data[i][l] - 1;
				summ = 0;
				for(k = 0; k < K_t; k++){
					if(mu[k]){
						summ += theta[k] * (h + mu[k][v]) / (smu + sum_mu[k]);
					}else{
						summ += theta[k] / V;
					}
				}
				perp[1] += log(summ);
			}
		}
	}
	/* Floor both token counts at 1 so an empty split cannot divide by zero. */
	perp[0] = exp(-perp[0] / max1(s1, 1));
	perp[1] = exp(-perp[1] / max1(s2, 1));
}

/*
 * sample the indicators r
*/
/*
 * Samples the inheritance indicators r[j][k] for group dp, or for all groups
 * when dp < 0.
 *   - r = 1 when the group already uses topic k (nj_t[j][k] > 0) or q[k] == 1.
 *   - r = 0 when q[k] is not positive.
 *   - Otherwise r is Bernoulli with odds q[k] e^{-u[j] JJ[k]} : (1 - q[k]),
 *     drawn with drand48().
 */
void sampleR(int dp, double *JJ, int **r, double *u)
{
	const int first = (dp >= 0) ? dp : 0;
	const int stop = (dp >= 0) ? dp + 1 : nDP;
	int grp, k;

	for (grp = first; grp < stop; grp++) {
		for (k = 0; k < K_t; k++) {
			double keep;
			mxAssert(K_id[k] >= 0, "empty jumps!!!");
			if (nj_t[grp][k] > 0 || q[k] == 1) {
				r[grp][k] = 1;
				continue;
			}
			if (!(q[k] > 0)) {
				r[grp][k] = 0;
				continue;
			}
			keep = q[k] * exp(-u[grp] * JJ[k]);
			keep = keep / (1 - q[k] + keep);
			r[grp][k] = (drand48() < keep) ? 1 : 0;
		}
	}
}

/*
 * An alternative way to sample the indicators r; it seems to work better.
*/
/*
 * Alternative sampler for the indicators r of group dp (all groups when dp < 0).
 * Its call sites in this file are commented out.
 *
 * Initialisation: r = 1 where the group uses topic k (nj_t > 0) or q[k] == 1,
 * else 0. ss accumulates JJ[k] over the topics set to 1.
 * Then two sweeps resample every free indicator (nj_t[j][k] == 0, 0 < q[k] < 1).
 * With ss excluding topic k, each is set to 1 with probability
 *     1 / (1 + (1 - q[k]) / q[k] * ((ss + JJ[k]) / ss)^n_t[j]).
 *
 * NOTE(review): ss can be 0 (no active topic), which makes pp 0/0.
 * NOTE(review): a topic with nj_t > 0 and q[k] == 0 is set to 1 (adding JJ[k]
 * to ss) and later reset to 0 without subtracting JJ[k] from ss. Confirm
 * this is intended.
 */
void sampleR_1(int dp, double *JJ, int **r, double *u)
{
	int i, k, j, jj, b, e;
	double pp, ss;

	if(dp >= 0){
		b = dp;
		e = dp + 1;
	}else{
		b = 0;
		e = nDP;
	}
	for(j = b; j < e; j++){
		ss = 0;
		for(k = 0; k < K_t; k++){
			if(nj_t[j][k] > 0 || q[k] == 1.0){
				r[j][k] = 1;
				ss += JJ[k];
			}else{
				r[j][k] = 0;
			}
		}
		for(i = 0; i < 2; i++){
			for(k = 0; k < K_t; k++){
				if(q[k] == 0.0){
					r[j][k] = 0;
				}else if(nj_t[j][k] == 0 && q[k] < 1.0){
					/* Take topic k out of ss before computing its conditional. */
					if(r[j][k] == 1){
						ss -= JJ[k];
					}
					pp = 1.0 / (1 + (1 - q[k]) / q[k] * pow((ss + JJ[k]) / ss, n_t[j]));
					if(drand48() < pp){
						r[j][k] = 1;
						ss += JJ[k];
					}else{
						r[j][k] = 0;
					}
				}
			}
		}
	}
}

/*
 * sample jumps with observations
*/
void sampleJ(double *J_ex, double a)
{
	int j, k, i;
	double tmp;
	for(i = 0; i < 2; i++){
		sampleR(-1, J_ex, r, u);
		//sampleR_1(-1, J_ex, r, u);
		for(k = 0; k < K_t; k++){
			if(nj[k] > 0){
				tmp = 1;
				for(j = 0; j < nDP; j++){		
					if(r[j][k] == 1){
						tmp += u[j];
					}
				}
				J_ex[k] = randgamma(nj[k] - a, tmp);
				mxAssert(J_ex[k] > 0, "Jumps less than zero!!!");
			}
		}
	}
}

/*
 * Initialize memory for storing extra jumps.
*/
/* Allocates the per-source extra-jump counts K_e (nS ints, all zero). nDP is unused. */
void initRM(int nS, int nDP)
{
	int src;
	K_e = (int*)malloc(sizeof(int) * nS);
	for (src = nS - 1; src >= 0; src--) {
		K_e[src] = 0;
	}
}
/*
 * Frees the per-source extra-jump counts K_e and clears the global pointer,
 * so a repeated call is harmless. Both arguments are unused.
 */
void freeRM(int ns, int nDP)
{
	free(K_e);
	K_e = NULL;
}

/*
 * simulate jumps without observations, using different bounded mean measures
*/
/* w_t(s) = ... e^(-s) t^(-1-a) */
/*
 * Simulates jumps above L for source i by adaptive thinning. Accepted jumps go
 * into JJ_ex, and their count, capped at maxExJ, goes into K_ex[i].
 *
 * With q = aq / (aq + bq) and sumu = sum of u, the weight
 * pp(t) = max(0, 1 - q * sumu * t) scales the intensity.
 * NOTE(review): pp appears to be a linear lower bound of
 * prod_j (1 - q + q e^{-u_j t}); SimuPoiP_1 uses the exact product.
 *
 * Each round, at the current level t:
 *   - e ~ Exp(1) (randgamma(1, 1)).
 *   - p = pp * a * M[i] * e^{-t} / gamma_a / t^{1+a} is the total mass above t
 *     of the dominating intensity c * e^{-s}, with c = pp*a*M[i]/gamma_a/t^{1+a}.
 *   - If e >= p, no further point exists and the loop stops.
 *   - Otherwise the candidate t1 = t - log(1 - e/p) solves c (e^{-t} - e^{-t1}) = e.
 *   - t1 is accepted with probability pp(t1)/pp(t) * (t/t1)^{1+a}.
 *
 * JJ_ex is written as a ring buffer (index k % maxExJ): beyond maxExJ accepted
 * jumps, earlier ones are overwritten. maxExJ must be > 0.
 * NOTE(review): max() is not defined in this view; presumably it comes from an
 * included header or macro. Confirm that it handles max(int, double).
 */
void SimuPoiP(int i, double L, double a, int* K_ex, double* JJ_ex)
{
	int jj, k = 0;
	double q, e, p, pp, p1, sumu, t1, t = L;

	q = (aq / (aq + bq));
	pp = 0;
	sumu = 0;
	for(jj = 0; jj < nDP; jj++){
		sumu += u[jj];
	}
	pp = max(0, 1 - q * sumu * t);
	while(1){
		e = randgamma(1, 1);
		p1 = pp;
		p = p1 * a * M[i] * exp(-t) / gamma_a / pow(t, 1 + a);
		mxAssert(!isnan(p) && finite(p), "p equals to NAN or INFINITE!!!");
		if(e >= p){
			break;
		}else{
			t1 = t - log(1 - e / p);
			if(!finite(t1)){
				break;
			}
			mxAssert(t1 >= t, "illegal jumps!!!");
		}

		/* Thinning step: ratio of the target to the dominating intensity at t1. */
		pp = max(0, 1 - q * sumu * t1);
		p = pp / p1 * pow(t / t1, 1 + a);
		mxAssert(finite(p) && p <= 1, "accept rate larger than 1!!!");
		if(p > drand48()){
			JJ_ex[k % maxExJ] = t1;
			k++;
		}
		t = t1;
		/******* too many jumps *******/
		/*if(k == maxExJ - 1){
			break;
		}*/
	}
	if(k > maxExJ){
		k = maxExJ;
	}
	K_ex[i] = k;
	/*
	 if(output == 1){
	 	 fprintf(fid, "#extra jumps in the %d-th source: %d\n", i, K_e[i]);
	 }else{
	 	 printf("#extra jumps in the %d-th source: %d\n", i, K_e[i]);
	 }
	 */
}

/* w_t(s) = ... e^(-t) s^(-1-a) */
/*
 * Simulates jumps above L for source i by adaptive thinning. This variant uses
 * the log-weight
 *     pp(t) = sum_j log(1 - q + q e^{-u[j] t}),   q = aq / (aq + bq),
 * which becomes -sum_j u[j] t when q == 1 and 0 when q <= 0.
 *
 * Each round, at the current level t:
 *   - e ~ Exp(1).
 *   - The dominating intensity is exp(pp(t)) e^{-t} (a M[i] / gamma_a) s^{-1-a}
 *     for s > t, and its mass above t is
 *     p = exp(pp(t)) * M[i] * e^{-t} / gamma_a / t^a.
 *   - If e >= p the loop stops.
 *   - Otherwise the candidate t1 = t * (1 - e/p)^(-1/a).
 *   - t1 is accepted with probability exp(pp(t1) - pp(t) + t - t1).
 *
 * Accepted jumps go into JJ_ex as a ring buffer (index k % maxExJ). K_ex[i]
 * receives the count, capped at maxExJ.
 */
void SimuPoiP_1(int i, double L, double a, int* K_ex, double* JJ_ex)
{
	int jj, k = 0;
	double q, e, p, pp, p1, t1, t = L;

	q = (aq / (aq + bq));
	pp = 0;
	for(jj = 0; jj < nDP; jj++){
		if(q == 1.0){
			pp -= u[jj] * t;
		}else if(q > 0){
			pp += log(1 - q + q * exp(-u[jj] * t));
		}
	}
	while(1){
		e = randgamma(1, 1);
		p1 = pp;
		p = exp(p1) * M[i] * exp(-t) / gamma_a / pow(t, a);
		mxAssert(!isnan(p) && finite(p), "p equals to NAN or INFINITE!!!");
		if(e >= p){
			break;
		}else{
			t1 = pow((1 - e / p), -1.0 / a) * t;
			if(!finite(t1)){
				break;
			}
			mxAssert(t1 >= t, "illegal jumps!!!");
		}

		/* Recompute the log-weight at the candidate for the thinning ratio. */
		pp = 0;
		for(jj = 0; jj < nDP; jj++){
			if(q == 1){
				pp -= u[jj] * t1;
			}else if(q > 0){
				pp += log(1 - q + q * exp(-u[jj] * t1));
			}
		}
		p = exp(t + pp - p1 - t1);
		mxAssert(finite(p) && p <= 1, "accept rate larger than 1!!!");
		if(p > drand48()){
			JJ_ex[k % maxExJ] = t1;
			k++;
		}
		t = t1;
		/******* too many jumps *******/
		/*if(k == maxExJ - 1){
			break;
		}*/
	}
	if(k > maxExJ){
		k = maxExJ;
	}
	K_ex[i] = k;
	/*
	if(output == 1){
		fprintf(fid, "#extra jumps in the %d-th source: %d\n", i, K_e[i]);
	}else{
		printf("#extra jumps in the %d-th source: %d\n", i, K_e[i]);
	}
	*/
}

/* w_t(s) = ... e^(-s) t^(-1-a) */
/*
 * Simulates jumps above L for source i by thinning, with an explicit random
 * subset of groups drawn each round. Each group is included independently
 * with probability q = aq / (aq + bq), and usum sums the included u[j].
 *
 * Each round, at the current level t:
 *   - e ~ Exp(1).
 *   - The dominating intensity (a M[i] / gamma_a) t^{-1-a} e^{-(1+usum) s}
 *     for s > t has mass p above t.
 *   - If e >= p the loop stops.
 *   - Otherwise the candidate t1 = t - log(1 - e/p) / (1 + usum).
 *   - t1 is accepted with probability (t/t1)^{1+a}.
 *
 * Accepted jumps go into JJ_ex as a ring buffer (index k % maxExJ). K_ex[i]
 * receives the count, capped at maxExJ.
 * NOTE(review): usum is redrawn every round, so successive proposals use
 * different dominating intensities. Confirm this is the intended scheme.
 */
void SimuPoiP_sub(int i, double L, double a, int* K_ex, double* JJ_ex)
{
	int jj, k = 0;
	double q, e, p, usum, t1, t = L;

	q = (aq / (aq + bq));
	while(1){
		e = randgamma(1, 1);
		usum = 0;
		for(jj = 0; jj < nDP; jj++){
			if(q > drand48()){
				usum += u[jj];
			}
		}
		p = a * M[i] * exp(-(1 + usum) * t) / gamma_a / pow(t, 1 + a) / (1 + usum);
		mxAssert(!isnan(p) && finite(p), "p equals to NAN or INFINITE!!!");
		if(e >= p){
			break;
		}else{
			t1 = t - log(1 - e / p) / (1 + usum);
			if(!finite(t1)){
				break;
			}
			mxAssert(t1 >= t, "illegal jumps!!!");
		}

		p = pow(t / t1, 1 + a);
		mxAssert(finite(p) && p <= 1, "accept rate larger than 1!!!");
		if(p > drand48()){
			JJ_ex[k % maxExJ] = t1;
			k++;
		}
		t = t1;
	}
	if(k > maxExJ){
		k = maxExJ;
	}
	K_ex[i] = k;
}

/* w_t(s) = ... e^(-t) s^(-1-a) */
/*
 * Simulates jumps above L for source i by thinning with a random group subset
 * each round (each group included with probability q = aq / (aq + bq); usum
 * sums the included u[j]).
 *
 * Each round, at the current level t:
 *   - e ~ Exp(1).
 *   - The dominating intensity (a M[i] / gamma_a) s^{-1-a} e^{-(1+usum) t}
 *     for s > t has mass p = M[i] e^{-(1+usum) t} / gamma_a / t^a.
 *   - If e >= p the loop stops.
 *   - Otherwise the candidate t1 = t * (1 - e/p)^(-1/a).
 *   - t1 is accepted with probability exp((1 + usum)(t - t1)).
 *
 * Accepted jumps go into JJ_ex as a ring buffer (index k % maxExJ). K_ex[i]
 * receives the count, capped at maxExJ.
 * NOTE(review): usum is redrawn every round; confirm this is intended.
 */
void SimuPoiP_sub_1(int i, double L, double a, int* K_ex, double* JJ_ex)
{
	int jj, k = 0;
	double q, e, p, usum, t1, t = L;

	q = (aq / (aq + bq));
	while(1){
		e = randgamma(1, 1);
		usum = 0;
		for(jj = 0; jj < nDP; jj++){
			if(q > drand48()){
				usum += u[jj];
			}
		}
		p = M[i] * exp(-(1 + usum) * t) / gamma_a / pow(t, a);
		mxAssert(!isnan(p) && finite(p), "p equals to NAN or INFINITE!!!");
		if(e >= p){
			break;
		}else{
			t1 = pow((1 - e / p), -1.0 / a) * t;
			if(!finite(t1)){
				break;
			}
			mxAssert(t1 >= t, "illegal jumps!!!");
		}

		p = exp((1 + usum) * (t - t1));
		mxAssert(finite(p) && p <= 1, "accept rate larger than 1!!!");
		if(p > drand48()){
			JJ_ex[k % maxExJ] = t1;
			k++;
		}
		t = t1;
	}
	if(k > maxExJ){
		k = maxExJ;
	}
	K_ex[i] = k;
}

/*
 * Sample the slice auxiliary variables.
*/
/*
 * Resamples the slice variables for group ndp, or for all groups when ndp < 0:
 *     us[i][j] ~ Uniform(0, JJ[s[i][j] - 1])
 * for every training token j < n_t[i]. Held-out groups (tag_tr[i] == 0) also
 * get one extra slice variable at index n_t[i], tied to topic s[i][n_t[i]].
 */
void SampleUS(int ndp, double **us, int **s)
{
	int grp, tok;
	const int first = (ndp < 0) ? 0 : ndp;
	const int stop = (ndp < 0) ? nDP : ndp + 1;

	for (grp = first; grp < stop; grp++) {
		const int ntok = n_t[grp];
		for (tok = 0; tok < ntok; tok++) {
			us[grp][tok] = drand48() * JJ[s[grp][tok] - 1];
			mxAssert(finite(us[grp][tok]), "us infinite!!!");
		}
		if (tag_tr[grp] == 0) {
			us[grp][ntok] = drand48() * JJ[s[grp][ntok] - 1];
		}
	}
}
/*
 * determine the bound L
*/
/*
 * Sets the per-source slice bounds. Every L[src] starts at 0.0001 / maxU. It
 * is then lowered to the smallest slice variable us[i][j] among tokens whose
 * topic belongs to that source (K_id[topic] == src). Held-out groups
 * (tag_tr[i] == 0) also contribute their extra slice variable at index n_t[i].
 */
void SampleL(double *L, double **us, int **s)
{
	int grp, tok, src;

	L[0] = 0.0001 / maxU;
	for (src = 1; src < nS; src++) {
		L[src] = L[0];
	}
	for (grp = 0; grp < nDP; grp++) {
		const int last = (tag_tr[grp] == 0) ? n_t[grp] : n_t[grp] - 1;
		for (tok = 0; tok <= last; tok++) {
			int topic = s[grp][tok] - 1;
			double *bound;
			mxAssert(topic < K_t, "k larger than maxium K!!!");
			bound = &L[K_id[topic]];
			if (us[grp][tok] < *bound) {
				*bound = us[grp][tok];
			}
		}
	}
}

void sampleM_Gibbs(int i, double L, double a, double a0, double b0)
{	
	int k, nn = 0;
	NGG f;
	double inte = DEIntegrator<NGG>::Integrate(f, L, 100, 1e-7);
	mxAssert(finite(inte), "integral infinite!!!");
	for(k = 0; k < K_t; k++){
		if(K_id[k] == i){
			nn++;
		}
	}
	/*M[i] = randgamma(nn + a0, pow(L, 1 - a) * a_ratio * sumQU[i] + inte * a / gamma_a + b0);*/
	M[i] = randgamma(nn + a0, (inte + integral[i]) * a / gamma_a + b0);
}
/*
 * Independence Metropolis-Hastings update of the group auxiliary variable u[j].
 *
 * Proposal: u' ~ Gamma(shape, rate = sumrj + bu), where
 *   - shape = n_t[j] + au - 1 for training groups, or n_t[j] + au for held-out
 *     groups (one extra token);
 *   - sumrj sums JJ[k] over the topics the group uses or inherits.
 *
 * The target's gamma-shaped factor u^(shape-1) e^{-(sumrj+bu) u} is exactly
 * the proposal density. In the MH ratio pi(u') q(u) / (pi(u) q(u')) it
 * therefore cancels, and only the integral term survives:
 *     log alpha = -sum_i a * M[i] / gamma_a * (I_i(u') - I_i(u)),
 * with I computed by CalIntegral.
 * The previous version added the gamma-shaped log ratio again, on top of a
 * proposal that already carries it. That double-counts the proposal and skews
 * the chain's stationary distribution.
 *
 * On acceptance: integral[] and sumQU are updated and countU is incremented.
 * On rejection: u[j] is restored. The L parameter is unused (CalIntegral reads
 * the global L).
 */
void sampleU_Gibbs(int j, double* L, double au, double bu, double a)
{
	int i, k;
	double uu, sumrj, acc, shape;

	uu = u[j];

	sumrj = 0;
	for(k = 0; k < K_t; k++){
		mxAssert(K_id[k] >= 0, "no source jumps!!!");
		if(nj_t[j][k] > 0 || r[j][k] == 1){
			sumrj += JJ[k];
		}
	}

	shape = (tag_tr[j] == 1) ? n_t[j] + au - 1 : n_t[j] + au;
	u[j] = randgamma(shape, sumrj + bu);

	/* Only the non-gamma part of the target enters the acceptance ratio. */
	CalIntegral(inte_tmp);
	acc = 0;
	for(i = 0; i < nS; i++){
		acc -= a * M[i] / gamma_a * (inte_tmp[i] - integral[i]);
	}
	if(acc < log(drand48())){
		u[j] = uu;
	}else{
		countU++;
		for(i = 0; i < nS; i++){
			integral[i] = inte_tmp[i];
		}
		sumQU += (u[j] - uu);
	}
}

/*
 * Unnormalised log-density in x (the Beta hyperparameter aq), for slice
 * sampling. Each occupied topic (nj[k] > 0) contributes
 *     x log q[k] + gammaln(count1[k] + count2[k] + bq + x) - gammaln(count1[k] + x).
 * The final log(x) - x term matches a Gamma(2, 1) prior.
 * NOTE(review): this is not the plain Beta(aq, bq) density of q[k]; confirm the intended model.
 */
struct aqPost{
	aqPost()
	{}
	double operator()(double x) const
	{
		double logp = 0;
		for (int topic = 0; topic < K_t; topic++) {
			if (nj[topic] <= 0) {
				continue;
			}
			mxAssert(q[topic] > 0, "q less than zero!!!");
			logp += x * log(q[topic]);
			logp += gammaln(count1[topic] + count2[topic] + bq + x) - gammaln(count1[topic] + x);
		}
		logp += log(x) - x;
		return logp;
	}
};
/*
 * Unnormalised log-density in x (the Beta hyperparameter bq), for slice
 * sampling. Each occupied topic (nj[k] > 0) contributes
 *     x log(1 - q[k]) + gammaln(count1[k] + count2[k] + aq + x) - gammaln(count2[k] + x).
 * The final log(x) - x term matches a Gamma(2, 1) prior.
 */
struct bqPost{
	bqPost()
	{}
	double operator()(double x) const
	{
		double logp = 0;
		for (int topic = 0; topic < K_t; topic++) {
			if (nj[topic] <= 0) {
				continue;
			}
			mxAssert(q[topic] > 0, "q less than zero!!!");
			logp += x * log(1 - q[topic]);
			logp += gammaln(count1[topic] + count2[topic] + aq + x) - gammaln(count2[topic] + x);
		}
		logp += log(x) - x;
		return logp;
	}
};
/*
 * Unnormalised log-posterior of the discount a, evaluated at x, for slice
 * sampling. sumk is the topic count and sumJ the sum of jumps.
 *
 * NGG reads the global a, so operator() temporarily sets a = x and restores
 * it before returning. The function is therefore not reentrant.
 *
 * Returned value:
 *   sumk (log x - gammaln(1 - x)) - x log(sumJ)
 *   - sum_i M[i] (inte1_i + inte2_i) x / Gamma(1 - x)
 * where:
 *   - inte1_i is the numerical integral of NGG over [L[i], 100];
 *   - inte2_i = q * sumQU * L[i]^(1-x) / (1-x), with q = aq / (aq + bq).
 * NOTE(review): the jump term uses log(sum of jumps) rather than a sum of
 * log-jumps; confirm this is intended. The locals k and f1 are unused.
 */
struct aPosterior{
	int sumk;
	double sumJ;
	aPosterior(int sumk, double sumJ)
	:sumk(sumk), sumJ(sumJ)
	{}
	double operator()(double x) const
	{
		int i, k;
		double acc, a_tmp, inte1, inte2, q = aq / (aq + bq);
		NGG f;
		NGG1 f1;
		
		/* NGG reads the global a: swap in x for the integrals below. */
		a_tmp = a;
		a = x;
		acc = sumk * (log(x) - gammaln(1 - x));
		acc -= x * log(sumJ);
		for(i = 0; i < nS; i++){
			inte1 = DEIntegrator<NGG>::Integrate(f, L[i], 100, 1e-7);
			inte2 = q * sumQU * pow(L[i], 1 - x) / (1 - x);
			acc -= M[i] * (inte1 + inte2) * x / gamma(1 - x);
		}
		a = a_tmp;
		return acc;
	}
};

/*
 * sample Q
*/
double sampleaq(double aq)
{
	aqPost post;
	return slice_sampler1d(post, aq, drand48, 1.0e-10, 1e10, 0.0, 10, 1000);
}

double samplebq(double bq)
{
	bqPost post;
	return slice_sampler1d(post, bq, drand48, 1.0e-10, 1e10, 0.0, 10, 1000);
}
// without rejection
void sampleQ_beta()
{
	int j, n1, n2, t, k;
	double qq, acc, tmp;
	for(k = 0; k < K_t; k++){
		count1[k] = count2[k] = 0;
		if(nj[k] > 0){
			for(j = 0; j < nDP; j++){
				if(r[j][k] == 1){
					count1[k]++;
				}else{
					count2[k]++;
				}
			}
		}
	}
//	for(t = 0; t < 2; t++){
//		for(k = 0; k < 2; k++){
//			aq = sampleaq(aq);
//			bq = samplebq(bq);
//		}
		for(k = 0; k < K_t; k++){
			q[k] = randbeta(count1[k] + aq, count2[k] + bq);
		}
//	}
}

/*
 * sample \sigma
*/
/*
 * Slice-samples the discount a under aPosterior, then refreshes
 * gamma_a = Gamma(1 - a). sumJ is the sum of the first K jumps; the slice
 * bounds are presumably 1e-10 and 0.9999.
 * NOTE(review): aPosterior receives K_t rather than K as its topic count; confirm.
 */
void samplea_slice(int K)
{
	double sumJ = 0;
	int k;

	for (k = 0; k < K; k++) {
		sumJ += JJ[k];
	}
	mxAssert(sumJ > 0, "sum of jumps less than zero!!!");
	aPosterior apos(K_t, sumJ);
	a = slice_sampler1d(apos, a, drand48, 1.0e-10, 0.9999, 0.0, 10, 1000);
	gamma_a = gamma(1 - a);
}

/******************
input: 	0: mu,	K * V						output: 0: mu
		1: sum_mu,	K								1: sum_mu
		2: s,	cell								2: nj
		3: M										3: nj_t
		4: q										4: n_ts
		5: n										5: Kid
		6: n_t										6: s
		7: n_ts										7: n
		8: nj										8: M
		9: nj_t										9: u
		10: K_c										10: Kct	
		11: K_ct									11: us
		12: K_id									12: L
		13: gamma									13: r
		14: burnin,	1 * 1							14: Kc
		15:	nsample, 1 * 1
		16: lag, 1 * 1
		17: calLik, 1 * 1
		18: data
		19:	a
		20:	gamma_a
		
******************/

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
	const int npara = 15;
	char* log_file;
	int i, j, l, num_DP, v, k, K_tmp, V, K, kk, iter, maxIter, burnin, nsample, lag, calLik, nLik;
	int **s, **data, *holds, *counts, ns, counter;
	double **mu, *sum_mu, *prob, **us, mu0, v_sum, lik;
	double *lik_pr, *lik_te_pr, *comp_tmp, *JJ_ex;
	double llik[2];
	int lik_iter = 0;
	mxArray *struc;
	const char* fieldnames[npara] = {"a", "mu", "sum_mu", "nj", "nj_t", "n_ts", "Kid", "s", "n", "M", "u", "us", "L", "r", "q"};
	// read data and initialize
	V = mxGetN(prhs[0]);
	K = mxGetM(prhs[0]); /* current maximal topics */
	nDP = mxGetNumberOfElements(prhs[6]);
	tag_tr = mxReadIntVector(prhs[24], nDP, 0, 0);
	nS = mxGetNumberOfElements(prhs[5]);
	mu0 = mxGetScalar(prhs[13]);
	burnin = mxGetScalar(prhs[14]);
	nsample = mxGetScalar(prhs[15]);
	lag = mxGetScalar(prhs[16]);
	calLik = (int)mxGetScalar(prhs[17]);
	maxIter = burnin + nsample * lag;

	maxExJ = mxGetScalar(prhs[21]);
	sumExJ = maxExJ * nS;
	JJ_ex = (double*)malloc(sizeof(double) * maxExJ);
	mu = mxReadDoubleMatrix(prhs[0], K, V, 0, 0);
	sum_mu = mxReadDoubleVector(prhs[1], K, 0, 0);
	s = mxReadIntCellVector(prhs[2], 0);
	M = mxReadDoubleVector(prhs[3], nS, 0, 0);
	q = mxReadDoubleVector(prhs[4], K, 0, 0);
	n = mxReadIntVector(prhs[5], nS, 0, 0);
	n_t = mxReadIntVector(prhs[6], nDP, 0, 0);
	n_ts = mxReadIntMatrix(prhs[7], nDP, nS, 0, 0);
	nj = mxReadIntVector(prhs[8], K, 0, 0);
	nj_t = mxReadIntMatrix(prhs[9], nDP, K, 0, 0);
	K_c = mxReadIntVector(prhs[10], nDP + 1, 0, 0);
	K_ct = mxReadIntVector(prhs[11], nS, 0, 0);
	K_id = mxReadIntVector(prhs[12], K, 0, 0);
	data = mxReadIntCellVector(prhs[18], 0);
	n_t_all = mxReadIntVector(prhs[23], nDP, 0, 0);
	a = mxGetScalar(prhs[19]);
	gamma_a = mxGetScalar(prhs[20]);
	
	output = (int)mxGetScalar(prhs[25]);
		
	//srand ( time(NULL) );
	if(output == 1){
		i = (mxGetM(prhs[26]) * mxGetN(prhs[26])) + 1;
		log_file = (char*)mxCalloc(i, sizeof(char)); 
		int status = mxGetString(prhs[26], log_file, i);
		mxAssert(status == 0, "read log file fail!!!");
		fid = fopen(log_file, "w");
	}
	aq = randgamma(2, 1);
	bq = randgamma(2, 1);
	count1 = (int*)malloc(sizeof(int) * K);
	count2 = (int*)malloc(sizeof(int) * K);
	/*
	 * file to store statistics
	 */
#ifdef RECORD
	FILE *fidstat = fopen("stat_TNGG_ftm.txt", "w");
	/******************/
#endif
	
	initPool(maxExJ, V);
	a_ratio = a / (1 - a) / gamma_a;
	K_t = K;
	v_sum = mu0 * V;
	prob = (double*)malloc(sizeof(double) * (K + sumExJ));
	holds = (int*)malloc(sizeof(int) * (K));
	counts = (int*)malloc(sizeof(int) * (K));
	//u = (double*)malloc(sizeof(double) * nDP);
	u = mxReadDoubleVector(prhs[22], nDP, 0, 0);
	L = (double*)malloc(sizeof(double) * nS);
	JJ = (double*)malloc(sizeof(double) * K);
	us = (double**)malloc(sizeof(double*) * nDP);
	integral = (double*)malloc(sizeof(double) * nS);
	inte_tmp = (double*)malloc(sizeof(double) * nS);
	lik = 0;
	for(i = 0; i < nDP; i++){
		if(tag_tr[i] == 1){
			us[i] = (double*)malloc(sizeof(double) * (n_t[i]));
		}else{
			us[i] = (double*)malloc(sizeof(double) * (n_t[i] + 1));
		}
		//u[i] = 10;
		lik += u[i];
	}
	maxU = u[0];
	for(i = 1; i < nDP; i++){
		if(u[i] > maxU){
			maxU = u[i];
		}
	}

	for(k = 0; k < K_t; k++){
		if(nj[k] > 0){
			JJ[k] = randgamma(nj[k] - a, 1 + lik);
		}else{
			JJ[k] = 0;
		}
	}
	SampleUS(-1, us, s);
	SampleL(L, us, s);
	
	initRM(nS, nDP);

	r = (int**)malloc(sizeof(int*) * nDP);
	for(i = 0; i < nDP; i++){
		r[i] = (int*)malloc(sizeof(int) * K);
		for(j = 0; j < K; j++){
			r[i][j] = 0;
		}
	}
	
	sampleR(-1, JJ, r, u);
	//sampleR_1(-1, JJ, r, u);

	sumQU = 0;
	for(j = 0; j < nDP; j++){
		sumQU += u[j];
	}
	
	ns = 0;
	struc = mxCreateStructMatrix(1, nsample, npara, fieldnames);
	plhs[0] = struc;
	if(calLik == 0){
		nLik = 0;
	}else{
		nLik = (maxIter - 1) / calLik + 1;
	}
	plhs[1] = mxCreateDoubleMatrix(1, nLik, mxREAL);
	lik_pr = mxGetPr(plhs[1]);
	plhs[2] = mxCreateDoubleMatrix(1, nLik, mxREAL);
	lik_te_pr = mxGetPr(plhs[2]);

	K_tmp = K_t;
	CalIntegral(integral);
	for(iter = 0; iter < maxIter; iter++){
		
		// correct errors in sumQU
		if((iter + 1) % 300 == 0){
			sumQU = 0;
			for(j = 0; j < nDP; j++){
				sumQU += u[j];
			}
		}

		//if(iter > 1000 && iter < 2000){
		countQ = 0;
		sampleQ_beta();
		if(output == 1){
			fprintf(fid, "acceptance rate for Q: %f\n", ((double)countQ) / (nS * nDP));
		}else{
			printf("acceptance rate for Q: %f\n", ((double)countQ) / (nS * nDP));
		}
		//}
		
		sampleR(-1, JJ, r, u);
		//sampleR_1(-1, JJ, r, u);

		for(num_DP = 0; num_DP < nDP; num_DP++){
			int tag;
			
		    /* Update allocation variables, s */
			for(j = 0; j < n_t[num_DP]; j++){
				v = data[num_DP][j] - 1;
				k = s[num_DP][j] - 1;
				mxAssert(k >= 0, "topic assignment less than zero!!!");
				mu[k][v]--;
				mxAssert(mu[k][v] >= 0, "counts on topic less than zero!!!");
				sum_mu[k]--;
				mxAssert(sum_mu[k] >= 0, "total counts on topic less than zero!!!");
				nj_t[num_DP][k]--;
				mxAssert(nj_t[num_DP][k] >= 0, "counts on one doc less than zero!!!");
				nj[k]--;
				mxAssert(nj[k] >= 0, "counts on topic less than zero!!!");
				n[K_id[k]]--;
				n_ts[num_DP][K_id[k]]--;
				
				for(kk = 0; kk < K_t; kk++){
					mxAssert(K_id[kk] >= 0, "topic without sources!!!");
					if((nj_t[num_DP][kk] > 0 || r[num_DP][kk] == 1) && JJ[kk] >= us[num_DP][j]){
						if(mu[kk]){
							prob[kk] = (mu[kk][v] + mu0) / (sum_mu[kk] + v_sum);
						}else{
							mxAssert(sum_mu[kk] == 0, "#data larget than zero!!!");
							prob[kk] = 1.0 / V;
						}
					}else{
						prob[kk] = 0;
					}
					mxAssert(prob[kk] >= 0, "prob less than zero!!!");
				}
				
		        kk = randmult(prob, K_t);
				s[num_DP][j] = kk + 1;
				
				if(mu[kk] == NULL){
					int ii;
					if(cur_len_pool < 0){
						for(ii = 0; ii < len_pool; ii++){
							pool[ii] = (double*)calloc(V, sizeof(double));
						}
						cur_len_pool = len_pool - 1;
					}
					mu[kk] = pool[cur_len_pool];
					pool[cur_len_pool] = NULL;
					cur_len_pool--;
				}
				mu[kk][v]++;
				mxAssert(mu[kk][v] > 0, "counts on topic less than zero!!!");
				sum_mu[kk]++;
				mxAssert(sum_mu[kk] > 0, "total counts on topic less than zero!!!");
				nj_t[num_DP][kk]++;
				mxAssert(nj_t[num_DP][kk] > 0, "counts on one doc less than zero!!!");
				nj[kk]++;
				n[K_id[kk]]++;
				n_ts[num_DP][K_id[kk]]++;
				//mxAssert(q[num_DP][KK_id[kk]] > 0, "rate less than zero!!!");
			}
			if(tag_tr[num_DP] == 0){ //testing data
				k = s[num_DP][n_t[num_DP]] - 1;
				mxAssert(k >= 0, "topic assignment less than zero!!!");
				nj_t[num_DP][k]--;
				mxAssert(nj_t[num_DP][k] >= 0, "counts on one doc less than zero!!!");
				nj[k]--;
				mxAssert(nj[k] >= 0, "counts on topic less than zero!!!");
				n[K_id[k]]--;
				n_ts[num_DP][K_id[k]]--;
				
				for(kk = 0; kk < K_t; kk++){
					mxAssert(K_id[kk] >= 0, "topic without sources!!!");
					if((nj_t[num_DP][kk] > 0 || r[num_DP][kk] == 1) && JJ[kk] >= us[num_DP][n_t[num_DP]]){
						prob[kk] = JJ[kk];
					}else{
						prob[kk] = 0;
					}
					mxAssert(prob[kk] >= 0, "prob less than zero!!!");
				}
				
		        kk = randmult(prob, K_t);
				s[num_DP][n_t[num_DP]] = kk + 1;
				
				nj_t[num_DP][kk]++;
				mxAssert(nj_t[num_DP][kk] > 0, "counts on one doc less than zero!!!");
				nj[kk]++;
				n[K_id[kk]]++;
				n_ts[num_DP][K_id[kk]]++;
			}
		}
		
		if(calLik && iter % calLik == 0){
			//llik[0] = likelihood(mu, sum_mu, s, data, n_t, nDP, mu0, V); llik[1] = 0;
			//CalTrTeLik(prob, data, mu, sum_mu, us, mu0, V, llik);
			calPerp(prob, data, mu, sum_mu, us, mu0, V, llik);
			lik_pr[lik_iter] = llik[0];
			lik_te_pr[lik_iter] = llik[1];
			if(output == 1){
				fprintf(fid, "In iteration %d, K = %d, a = %f, lik_tr = %f, lik_te = %f...\n", iter, K_tmp, a, lik_pr[lik_iter], lik_te_pr[lik_iter]);
			}else{
				printf("In iteration %d, K = %d, a = %f, lik_tr = %f, lik_te = %f...\n", iter, K_tmp, a, lik_pr[lik_iter], lik_te_pr[lik_iter]);
			}
			lik_iter++;
		}else{
			if(output == 1){
				fprintf(fid, "In iteration %d, K = %d, a = %f...\n", iter, K_tmp, a);
			}else{
				printf("In iteration %d, K = %d, a = %f...\n", iter, K_tmp, a);
			}
		}

		countU = 0;
		for(i = 0; i < 2; i++){
			for(j = 0; j < nDP; j++){
				sampleU_Gibbs(j, L, 1, 0, a);
			}
			sampleJ(JJ, a);
		}
		if(output == 1){
			fprintf(fid, "acceptance rate for U: %f\n", ((double)countU) / (2 * nDP));
		}else{
			printf("acceptance rate for U: %f\n", ((double)countU) / (2 * nDP));
		}

		i = 0;
		counter = 1;
		for(k = 0; k< K_t; k++){
			if(nj[k] > 0){
				holds[i] = k;
				counts[k] = counter;
				counter++;
				i++;
			}else{
				counts[k] = 0;
			}
		}
		for(k = 0; k < i; k++){
			JJ[k] = JJ[holds[k]];
			nj[k] = nj[holds[k]];
			K_id[k] = K_id[holds[k]];
			sum_mu[k] = sum_mu[holds[k]];
			q[k] = q[holds[k]];
			if(mu[k] && k != holds[k]){
				int ii;
				if(cur_len_pool == len_pool - 1){
					len_pool *= 2;
					pool = (double**)realloc(pool, len_pool * sizeof(double*));
					for(ii = cur_len_pool + 1; ii < len_pool; ii++){
						pool[ii] = NULL;
					}
				}
				cur_len_pool++;
				pool[cur_len_pool] = mu[k];
				mu[k] = NULL;
			}
			mu[k] = mu[holds[k]];
			mxAssert(mu[k] != NULL, "empty component!!!");
			if(k != holds[k]){
				mu[holds[k]] = NULL;
			}
			for(j = 0; j < nDP; j++){
				nj_t[j][k] = nj_t[j][holds[k]];
			}
		}
		for(k = i; k < K_t; k++){
			JJ[k] = 0;
			nj[k] = 0;
			K_id[k] = -1;
			sum_mu[k] = 0;
			q[k] = (aq / (aq + bq));
			if(mu[k]){
				int ii;
				if(cur_len_pool == len_pool - 1){
					len_pool *= 2;
					pool = (double**)realloc(pool, len_pool * sizeof(double*));
					for(ii = cur_len_pool + 1; ii < len_pool; ii++){
						pool[ii] = NULL;
					}
				}
				cur_len_pool++;
				pool[cur_len_pool] = mu[k];
				mu[k] = NULL;
			}
			for(j = 0; j < nDP; j++){
				nj_t[j][k] = 0;
			}
		}
		K_t = i;
		for(j = 0; j < nDP; j++){
			for(l = 0; l < n_t[j]; l++){
				s[j][l] = counts[s[j][l] - 1];
				mxAssert(s[j][l] > 0, "assignment topic not exist!!!");
			}
			if(tag_tr[j] == 0){
				s[j][n_t[j]] = counts[s[j][n_t[j]] - 1];
				mxAssert(s[j][n_t[j]] > 0, "assignment topic not exist!!!");
			}
		}

		maxU = u[0];
		for(i = 1; i < nDP; i++){
			if(u[i] > maxU){
				maxU = u[i];
			}
		}
		SampleUS(-1, us, s);
		SampleL(L, us, s);
		
		CalIntegral(integral);
		
		K_tmp = K_t;
		for(i = 0; i < nS; i++){
			SimuPoiP(i, L[i], a, K_e, JJ_ex);
			//SimuPoiP_1(i, L[i], a, K_e, JJ_ex); //preferable
			//SimuPoiP_sub(i, L[i], a, K_e, JJ_ex);
			//SimuPoiP_sub_1(i, L[i], a, K_e, JJ_ex);
			if(output == 1){
				fprintf(fid, "#extra jumps: %d\n", K_e[i]);
			}else{
				printf("#extra jumps: %d\n", K_e[i]);
			}
			if(K_t + K_e[i] > K){
				while(K_t + K_e[i] > K){
					K *= 2;
				}
				mu = (double**)realloc(mu, K * sizeof(double*));
				sum_mu = (double*)realloc(sum_mu, K * sizeof(double));
				q = (double*)realloc(q, K * sizeof(double));
				count1 = (int*)realloc(count1, K * sizeof(int));
				count2 = (int*)realloc(count2, K * sizeof(int));
				holds = (int*)realloc(holds, K * sizeof(int));
				counts = (int*)realloc(counts, K * sizeof(int));
				nj = (int*)realloc(nj, K * sizeof(int));
				K_id = (int*)realloc(K_id, K * sizeof(int));
				for(k = K_t; k < K; k++){
					mu[k] = NULL;
					sum_mu[k] = 0;
					nj[k] = 0;
					K_id[k] = -1;
				}
				JJ = (double*)realloc(JJ, K * sizeof(double));
				for(k = K_t; k < K; k++){
					JJ[k] = 0;
				}
				prob = (double*)realloc(prob, K * sizeof(double));
				for(j = 0; j < nDP; j++){
					nj_t[j] = (int*)realloc(nj_t[j], K * sizeof(int));
					r[j] = (int*)realloc(r[j], K * sizeof(int));
					for(k = K_t; k < K; k++){
						nj_t[j][k] = 0;
						r[j][k] = 0;
					}
				}
			}
			for(k = K_t; k < K_t + K_e[i]; k++){
				JJ[k] = JJ_ex[k - K_t];
				K_id[k] = i;
			}
			K_t += K_e[i];
		}
		
		for(i = 0; i < nS; i++){
			sampleM_Gibbs(i, L[i], a, 0.1, 0.1);
		}

		//samplea_slice(K_t);
		
		/*
		 * store statistics to calculate ESS
		 */
#ifdef RECORD
		if(iter >= maxIter - 1000){
			llik[0] = likelihood(mu, sum_mu, s, data, n_t, nDP, mu0, V);
			fprintf(fidstat, "%f ", llik[0]);
			for(i = 0; i < nS; i++){
				fprintf(fidstat, "%f ", M[i]);
			}
			fprintf(fidstat, "%d\n", K_tmp);
		}
#endif
		
		/**** record output **********/
		if((iter + 1) > burnin && (iter + 1 - burnin)%lag == 0){
			int del;
			mxArray *MM;
			if(iter != maxIter - 1){
				del = 0;
			}else{
				del = 1;
			}

			mxSetField(struc, ns, "a", mxWriteScalar(a));
			mxSetField(struc, ns, "mu", mxWriteDoubleMatrix(K_tmp, V, K, mu, 0, del));
			mxSetField(struc, ns, "sum_mu", mxWriteDoubleVector(1, K_tmp, sum_mu, 0, del));
			mxSetField(struc, ns, "nj", mxWriteIntVector(1, K_tmp, nj, 0, del));
			mxSetField(struc, ns, "nj_t", mxWriteIntMatrix(nDP, K_tmp, nDP, nj_t, 0, del));
			mxSetField(struc, ns, "n_ts", mxWriteIntMatrix(nDP, nS, nDP, n_ts, 0, del));
			mxSetField(struc, ns, "Kid", mxWriteIntVector(1, K_tmp, K_id, 0, del));
			mxSetField(struc, ns, "s", mxWriteIntCellVector(nDP, n_t, s, 0, del));
			mxSetField(struc, ns, "n", mxWriteIntVector(1, nS, n, 0, del));
			mxSetField(struc, ns, "M", mxWriteDoubleVector(1, nS, M, 0, del));
			mxSetField(struc, ns, "u", mxWriteDoubleVector(1, nDP, u, 0, del));
			mxSetField(struc, ns, "us", mxWriteDoubleCellVector(nDP, n_t, us, 0, del));
			mxSetField(struc, ns, "L", mxWriteDoubleVector(1, nS, L, 0, del));
			mxSetField(struc, ns, "r", mxWriteIntMatrix(nDP, K_tmp, nDP, r, 0, del));
			mxSetField(struc, ns, "q", mxWriteDoubleVector(1, K_tmp, q, 0, del));
			ns++;
		}

	}

	free(prob);
	for(j = 0; j < nDP; j++){
		free(data[j]);
	}
	free(count1);
	free(count2);
	free(data);
	free(JJ);
	free(JJ_ex);
	free(n_t);
	free(n_t_all);
	freeRM(nS, nDP);
	free(holds);
	free(counts);
	free(K_c);
	free(K_ct);
	free(tag_tr);
	freePool();
	if(output == 1){
		mxFree(log_file);
		fclose(fid);
	}
	free(integral);
	free(inte_tmp);
#ifdef RECORD
	fclose(fidstat);
#endif
}

