%%%%%% C implementation for z, t %%%%%%%%%%%%%

function [samples, lik_tr, lik_te, elapse, elapse_tot] = sampler_SNGG_slice(data, M, q, burnin, every, ...
    numbofits, initK, v, h, dolik, a, maxExJ, tag_tr, pred, model, output, init, samp_init, log_file, q0)
% SAMPLER_SNGG_SLICE  Build the initial sampler state and run a z/t sampler.
%
%   Initialises the atom assignments s, the per-atom/per-word counts mu and
%   the derived count statistics. It then dispatches to the sampler selected
%   by MODEL: 'MNGG_slice', 'TNGG_slice', 'TNGG_q', 'TNGG_ftm' or
%   'MNGG_tilt'.
%
%   Inputs. Only what this function itself does with each input is
%   described; everything is eventually forwarded to the sampler.
%     data      cell array; data{i} (i = 1..nDP) holds the tokens of group i,
%               which are used as column indices into mu.
%     M         scalar or 1 x nS vector (a scalar is replicated to 1 x nS).
%               Replaced by samp_init.M when init == 1.
%     q         nDP x nS matrix; its size defines nDP and nS. If any entry
%               is zero, group i is initialised only from atoms of sources
%               j with q(i, j) > 0. Replaced by row-max-normalised
%               samp_init.q when init == 1. Overwritten with
%               0.9 * ones(1, numAtoms) for 'TNGG_ftm' and 'MNGG_tilt'.
%     initK     initial number of atoms per source (scalar or 1 x nS).
%     v         number of columns of mu; also stored as samples(i).V.
%     h         forwarded as mu0; also stored as samples(i).gamma.
%     a         forwarded together with gamma(1 - a). Replaced by
%               samp_init.a when init == 1.
%     tag_tr    per-group flag. Groups with tag_tr(i) == 1 use all their
%               tokens. Other groups use only the first ceil(n/2) tokens.
%               Groups with tag_tr(i) == 0 additionally get one extra
%               assignment appended to s{i}.
%     pred      when 1, that extra assignment is also counted in nj.
%     init      (default 0) when 1, warm-start from samp_init, using the
%               fields M, a, q, s, mu, sum_mu and, optionally, Kid.
%     q0        (default rand(nDP, nS)) forwarded to 'TNGG_slice' only.
%     burnin, every, numbofits, dolik, maxExJ, output and log_file are
%     forwarded unchanged.
%
%   Outputs:
%     samples     struct array from the sampler. Fields gamma (= h),
%                 V (= v) and model are added to every element.
%     lik_tr, lik_te  as returned by the sampler.
%     elapse      4th output of sample_zt_TNGG_slice; 1 for other models.
%     elapse_tot  tic/toc seconds around the sampler dispatch.
%
%   Errors:
%     sampler_SNGG_slice:unknownModel    MODEL is not one of the above.
%     sampler_SNGG_slice:noActiveSource  a group with tokens to assign has
%                                        no atom whose source has q > 0.

% Validate the model before doing any setup work. Otherwise an unknown
% name only surfaces later as an undefined 'samples' variable.
validModels = {'MNGG_slice', 'TNGG_slice', 'TNGG_q', 'TNGG_ftm', 'MNGG_tilt'};
if ~any(strcmp(model, validModels))
    error('sampler_SNGG_slice:unknownModel', ...
        'Unknown model; expected one of: %s', sprintf('%s ', validModels{:}));
end

elapse = 1;   % only the 'TNGG_slice' sampler overwrites this

if nargin < 17
    init = 0;
end

mu0 = h;
[nDP, nS] = size(q);

if nargin < 20
    q0 = rand(nDP, nS);
end

if init == 1
    M = samp_init.M;
elseif length(M) ~= nS
    M = M(1) * ones(1, nS);
end

% n_t_all: total tokens per group.
% n_t: tokens used for fitting (all of them for groups with tag_tr == 1,
% the first half rounded up otherwise).
n_t_all = zeros(1, nDP);
n_t = zeros(1, nDP);
for i = 1:nDP
    n_t_all(i) = length(data{i});
    if tag_tr(i) == 1
        n_t(i) = n_t_all(i);
    else
        n_t(i) = ceil(n_t_all(i) / 2);
    end
end

if init == 1
    % ---- Warm start from a previous sampler state ----
    a = samp_init.a;
    q = samp_init.q;
    for i = 1:size(q, 1)
        q(i, :) = q(i, :) / max(q(i, :));
    end
    % Keep only atoms with a positive total count.
    % NOTE(review): mu/sum_mu are compacted to these atoms below, but
    % samp_init.s is not re-indexed. This assumes no empty atom precedes a
    % used one in samp_init -- confirm.
    empidx = find(samp_init.sum_mu > 0);
    s = samp_init.s;
    if isfield(samp_init, 'Kid')
        K_id = samp_init.Kid(empidx) + 1;   % stored 0-based; use 1-based here
    else
        K_id = ceil(rand(1, length(empidx)) * nS);
    end
    for i = 1:nDP
        if tag_tr(i) == 0
            % Extra (predictive) assignment for non-training groups.
            s{i} = [s{i} ceil(rand * length(empidx))];
        end
    end
    mu = samp_init.mu(empidx, :);
    sum_mu = samp_init.sum_mu(empidx);
else
    % ---- Fresh start: initK(j) atoms per source j ----
    if length(initK) ~= nS
        initK = initK(1) * ones(1, nS);
    end
    nK = sum(initK);
    cumK = zeros(1, nS + 1);
    cumK(2:nS+1) = cumsum(initK);
    % K_id(k) = source owning atom k (atoms are laid out source by source).
    K_id = zeros(1, nK);
    for src = 1:nS
        K_id(cumK(src)+1:cumK(src+1)) = src;
    end

    % Draw random initial assignments. Groups with tag_tr == 0 get one
    % extra draw, which is stored at position n_t(i) + 1.
    noZeroQ = all(q(:) ~= 0);
    s = cell(1, nDP);
    for i = 1:nDP
        m = n_t(i) + (tag_tr(i) == 0);
        if noZeroQ
            s{i} = ceil(rand(1, m) * nK);
        else
            % Atoms whose source is active for group i, in increasing order.
            allowed = find(q(i, K_id) > 0);
            if isempty(allowed) && m > 0
                error('sampler_SNGG_slice:noActiveSource', ...
                    'Group %d has no atom whose source has q > 0.', i);
            end
            s{i} = allowed(ceil(rand(1, m) * numel(allowed)));
            assert(all(s{i} > 0 & s{i} <= nK));
        end
    end

    % mu(k, w): number of fitted tokens of word w assigned to atom k.
    mu = zeros(nK, v);
    for j = 1:nDP
        for t = 1:n_t(j)
            mu(s{j}(t), data{j}(t)) = mu(s{j}(t), data{j}(t)) + 1;
        end
    end
    sum_mu = sum(mu, 2)';
end

% nj: per-atom counts, optionally including the extra assignment.
nj = sum_mu;
if pred == 1
    for i = 1:nDP
        if tag_tr(i) == 0
            k = s{i}(n_t(i) + 1);
            nj(k) = nj(k) + 1;
        end
    end
end
K = length(nj);

% n(src): total count over the atoms of each source.
n = zeros(1, nS);
for k = 1:K
    if K_id(k) >= 1
        n(K_id(k)) = n(K_id(k)) + nj(k);
    end
end

% Per-group statistics. These count every entry of s{i}, including the
% appended extra assignment, regardless of pred:
%   nj_t(i, k)  occurrences of atom k in group i
%   n_ts(i, j)  the same, summed over the atoms of source j
%   K_c(i)      number of distinct atoms used by group i (K_c(end) = K)
n_ts = zeros(nDP, nS);
nj_t = zeros(nDP, K);
K_c = zeros(1, nDP + 1);
for i = 1:nDP
    for k = 1:K
        if K_id(k) > 0
            c = nnz(s{i} == k);
            nj_t(i, k) = c;
            n_ts(i, K_id(k)) = n_ts(i, K_id(k)) + c;
            K_c(i) = K_c(i) + (c > 0);
        end
    end
end
K_c(nDP + 1) = K;

% K_ct(j): number of atoms of source j with a nonzero count in nj.
K_ct = zeros(1, nS);
for k = 1:K
    if K_id(k) > 0 && nj(k) > 0
        K_ct(K_id(k)) = K_ct(K_id(k)) + 1;
    end
end

% Shift source ids to 0-based before handing them to the samplers
% (presumably for the C implementation -- confirm).
K_id = K_id - 1;

% Initial per-group value of the auxiliary 'u' passed to the samplers
% (hard-coded).
u = 10 * ones(1, nDP);

tic

if any(strcmp(model, {'TNGG_ftm', 'MNGG_tilt'}))
    % TNGG_ftm: each atom shares a subsampling rate across times.
    % MNGG_tilt: q_rt ~ Gamma(aq, bq), r_{trk} ~ Gamma(q_rk, 1),
    %            mu_t \prop r_{trk} w_{rk} \delta_rk
    q = 0.9 * ones(1, length(K_id));
end

% Argument list shared by every sampler ('TNGG_slice' additionally gets q0).
args = {mu, sum_mu, s, M, q, n, n_t, n_ts, nj, nj_t, K_c, K_ct, K_id, mu0, ...
    burnin, numbofits, every, dolik, data, a, gamma(1 - a), maxExJ, u, ...
    n_t_all, tag_tr, output, log_file};

switch model
    case 'MNGG_slice'
        [samples, lik_tr, lik_te] = sample_zt_MNGG_slice(args{:});
    case 'TNGG_slice'
        [samples, lik_tr, lik_te, elapse] = sample_zt_TNGG_slice(args{:}, q0);
    case 'TNGG_q'
        [samples, lik_tr, lik_te] = sample_zt_TNGG_q(args{:});
    case 'TNGG_ftm'
        [samples, lik_tr, lik_te] = sample_zt_TNGG_ftm(args{:});
    case 'MNGG_tilt'
        [samples, lik_tr, lik_te] = sample_zt_MNGG_tilt(args{:});
end

elapse_tot = toc;

% Record the hyperparameters and model name on every sample.
for i = 1:length(samples)
    samples(i).gamma = h;
    samples(i).V = v;
    samples(i).model = model;
end
