function Samp = SGHMC_lda_e(params, x, data_heldout, data_test, trainfile)
%SGHMC_LDA_E  Stochastic-gradient HMC sampler for the topic-word matrix of LDA.
%
%   Samp = SGHMC_lda_e(params, x, data_heldout, data_test, trainfile)
%
%   Each outer iteration (1) resamples per-word topic assignments with a
%   Gibbs sweep, either on a mini-batch read through GetBatch (stochastic
%   mode) or on the full corpus stored in params, (2) forms a gradient from
%   the resulting word/topic counts, (3) runs a few momentum updates with
%   friction and injected noise on the unconstrained parameter x_, and
%   (4) maps x_ back to column-normalised probabilities x via a softmax.
%
%   Inputs
%     params        struct; fields read by this function:
%                     V, K           sizes of the V-by-K parameter matrix
%                     NSamp          number of outer iterations
%                     NBurnIn        iteration after which perplexity is reported
%                     alpha, a, B    scalars/arrays used in the Gibbs weights,
%                                    the gradient and the friction/noise terms
%                     test           if true, held-out evaluation is run
%                     alpha_s, nw_test   used only when test is true
%                     n_run, burnin, batch   read for the start-up message;
%                                    also drive the mini-batch Gibbs sweeps
%                     sto            (optional) 1 selects stochastic mode
%                     N              number of training documents
%                     X, z, n, NN    corpus, assignments and counts; required
%                                    in non-stochastic mode (X and NN are
%                                    overwritten locally in stochastic mode)
%                     Stepsz         (optional) step size; default 1/20/1000
%     x             initial V-by-K matrix, or [] / omitted for random init
%     data_heldout  cell array; data_heldout{d} holds word indices of doc d
%     data_test     cell array; data_test{d} holds word indices of doc d.
%                   Must not have more entries than data_heldout.
%     trainfile     path of the training file handed to GetBatch
%
%   Output
%     Samp          10-by-V-by-K array holding x from the last (up to) 10
%                   iterations in REVERSE order: Samp(1,:,:) is the final
%                   iteration, Samp(2,:,:) the one before it, and so on.
%
%   Depends on GetBatch and randix, which are defined outside this file.

%%% stochastic (mini-batch) mode flag, evaluated once
is_sto = isfield(params, 'sto') && params.sto == 1;

% In stochastic mode the averaged counts are divided by (n_run - burnin);
% a non-positive value would silently turn the gradient into NaN/garbage.
if is_sto && params.n_run <= params.burnin
    error('SGHMC_lda_e:invalidBurnin', ...
        'params.n_run (%g) must be greater than params.burnin (%g) in stochastic mode.', ...
        params.n_run, params.burnin);
end

%%% build the sparse word-by-document count matrix of the test data
% Pieces are collected per document and concatenated once, instead of
% growing the index vectors inside the loop.
n_test_docs = length(data_test);
word_cells = cell(n_test_docs, 1);
doc_cells = cell(n_test_docs, 1);
for i = 1:n_test_docs
    word_cells{i} = data_test{i}(:);
    doc_cells{i} = i * ones(numel(data_test{i}), 1);
end
word_idx = vertcat(word_cells{:});
doc_idx = vertcat(doc_cells{:});
% sparse() sums duplicate (word, doc) pairs, so entries are word counts.
data_test = sparse(word_idx, doc_idx, ones(length(doc_idx), 1), params.V, length(data_heldout));

YflagTest = data_test > 0;
PhiThetaTest = zeros(params.V, length(data_heldout));

D = params.V;

%%% HMC setting
TrjLength = 1;
NLeap = 20;%5;
if isfield(params, 'Stepsz')
    Stepsz = params.Stepsz;
else
    % NLeap only enters this default step size; the number of inner
    % updates actually performed is Leapfrog (below).
    Stepsz = TrjLength/NLeap/1000;%/500;
end

%%% training file handle
% The handle is passed to GetBatch. Register a cleanup object so it is
% closed when this function returns or errors (it used to be leaked).
fid_train = fopen(trainfile, 'r');
if fid_train ~= -1
    close_train = onCleanup(@() fclose(fid_train)); %#ok<NASGU>
end

%%% storage
NSamp = params.NSamp;
NBurnIn = params.NBurnIn;
Samp = zeros(10, D, params.K);

%%% Initialization
if nargin < 2 || isempty(x)
    x_ = rand(D, params.K);
    x = exp(x_);
    x = x ./ repmat(sum(x, 1), D, 1);
else
    % NOTE(review): the unconstrained parameter is set to 10*exp(x) here,
    % whereas the random branch treats x as softmax(x_) - confirm intent.
    x_ = 10 * exp(x);
end

p_w = randn(size(x_));
Leapfrog = 5;

if is_sto
    zz = cell(1, params.batch);
end

%%% held-out evaluation state
do_test = false;
if params.test == true
    do_test = true;
    ntest = length(data_heldout);
    nk_t = zeros(params.K, ntest);
    zz_t = cell(1, ntest);
    % random initial topic assignment for every held-out word
    for j = 1:ntest
        zz_t{j} = randi(params.K, 1, length(data_heldout{j}));
        for w = 1:length(zz_t{j})
            nk_t(zz_t{j}(w), j) = nk_t(zz_t{j}(w), j) + 1;
        end
    end
    nsample = 0;
end

% NOTE(review): the banner says "SGLD" although the update below carries a
% momentum variable; the text is left unchanged so existing logs still match.
disp(['SGLD with n_run = ' num2str(params.n_run) ', burnin = ' num2str(params.burnin) ', batch = ' num2str(params.batch)]);

for Iter = 1:NSamp

    %%% progress message every 100 iterations
    if mod(Iter, 100)==0
        disp(['Iteration ' num2str(Iter) ' completed!']);
    end

    if is_sto
        %%% mini-batch Gibbs sampling of topic assignments
        params.X = GetBatch(fid_train, params.batch, trainfile);

        nk = zeros(params.K, params.batch);   % topic counts per batch document
        NN = zeros(params.V, params.K);       % word/topic counts in the batch
        tmp = NN;                             % accumulator over post-burn-in sweeps
        for j = 1:params.batch
            zz{j} = randi(params.K, 1, length(params.X{j}));
            for w = 1:length(zz{j})
                nk(zz{j}(w), j) = nk(zz{j}(w), j) + 1;
                NN(params.X{j}(w), zz{j}(w)) = NN(params.X{j}(w), zz{j}(w)) + 1;
            end
        end
        for ir = 1:params.n_run
            for i = 1:params.batch
                for w = 1:length(params.X{i})
                    wd = params.X{i}(w);
                    z = zz{i}(w);
                    % remove the current assignment, resample, add it back
                    nk(z, i) = nk(z, i) - 1;
                    NN(wd, z) = NN(wd, z) - 1;
                    prob = (nk(:, i) + params.alpha) .* x(wd, :)';
                    z = randix(prob);
                    zz{i}(w) = z;
                    nk(z, i) = nk(z, i) + 1;
                    NN(wd, z) = NN(wd, z) + 1;
                end
            end
            if ir > params.burnin
                tmp = tmp + NN;
            end
        end
        % average over the kept sweeps and rescale batch counts by N / batch
        params.NN = tmp / (params.n_run - params.burnin) * params.N / params.batch;
    else
        %%% full-corpus Gibbs sweep, updating the counts stored in params
        for i = 1:params.N
            for w = 1:length(params.X{i})
                wd = params.X{i}(w);
                z = params.z{i}(w);
                params.n(z, i) = params.n(z, i) - 1;
                params.NN(wd, z) = params.NN(wd, z) - 1;
                prob = (params.n(:, i) + params.alpha) .* x(wd, :)';
                z = randix(prob);
                params.z{i}(w) = z;
                params.n(z, i) = params.n(z, i) + 1;
                params.NN(wd, z) = params.NN(wd, z) + 1;
            end
        end
    end

    if Iter==NBurnIn
        disp('Burn in completed!');
    end

    %%% momentum update with friction params.B and injected Gaussian noise
    % NOTE(review): grad is evaluated once per outer iteration and reused
    % for all Leapfrog inner steps even though x_ changes - confirm intent.
    grad = (params.NN - repmat(sum(params.NN, 1), params.V, 1) .* x) + params.a - exp(x_);
    for ll = 1:Leapfrog
        x_ = x_ + p_w * Stepsz;
        p_w = (1 - params.B * Stepsz) * p_w + grad * Stepsz + sqrt(2 * params.B * Stepsz) * randn(size(p_w));
    end

    %%% column-wise softmax; the column max is subtracted before exp
    x = x_ - repmat(max(x_), params.V, 1);
    x = exp(x);
    x = x ./ repmat(sum(x, 1), params.V, 1);

    %%% save sample beta (last 10 iterations, newest stored in row 1)
    if Iter > NSamp - 10
        for k = 1:params.K
            Samp(NSamp-Iter+1, :, k) = x(:, k)';
        end
    end

    %%% held-out Gibbs sweep; starts 50 iterations before NBurnIn
    if do_test && Iter > NBurnIn - 50
        for i = 1:ntest
            for w = 1:length(data_heldout{i})
                wd = data_heldout{i}(w);
                z = zz_t{i}(w);
                nk_t(z, i) = nk_t(z, i) - 1;
                prob = (nk_t(:, i) + params.alpha) .* x(wd, :)';
                z = randix(prob);
                zz_t{i}(w) = z;
                nk_t(z, i) = nk_t(z, i) + 1;
            end
        end

        %%% calculate perplexity on every second post-burn-in iteration
        if Iter > NBurnIn && mod(Iter, 2) == 0
            nsample = nsample + 1;
            theta_t = (nk_t + params.alpha) ./ repmat(sum(nk_t, 1) ...
                + params.alpha_s, params.K, 1);
            % running average of x * theta_t over the evaluated iterations
            PhiThetaTest = PhiThetaTest + x * theta_t;
            theta_tn = PhiThetaTest / nsample;
            perp = sum(data_test(YflagTest).*log(theta_tn(YflagTest)));

            perp = exp(-perp / params.nw_test);
            disp(['Test perplexity: ' num2str(perp)]);
        end
    end

end
