
% Data matrices are column-wise: p visible units by ntrain / ntest samples.
[p, ntrain] = size(Vtrain);
[~, ntest] = size(Vtest);

% KK holds one width per layer (visible layer first), so the number of
% weight layers is one less than the number of entries in KK.
nlayer = length(KK) - 1;

% Gibbs schedule for every outer iteration: `burnin` discarded sweeps, then
% `ns` retained samples. H_s{layer}{s} will hold retained sample s of that layer.
burnin = 6;
ns = 5;
H_s = repmat({cell(ns, 1)}, nlayer, 1);

%% initialize W, c and b
% Either draw small Gaussian parameters (standard deviation 0.1) or reuse the
% caller-supplied WW / CC / B. Layer i maps KK(i+1) upper units to KK(i)
% lower units. Draws are interleaved W{1}, c{1}, W{2}, c{2}, ..., then b, so
% seeded runs reproduce exactly.
if ~initW
    [W, c] = deal(cell(1, nlayer));
    for layer = 1:nlayer
        nLow = KK(layer);
        nHigh = KK(layer + 1);
        W{layer} = 0.1 * randn(nLow, nHigh);
        c{layer} = 0.1 * randn(nLow, 1);
    end
    b = 0.1 * randn(KK(nlayer + 1), 1);
else
    W = WW;
    c = CC;
    b = B;
end

%%% Momentum
% SGHMC momentum variables: one standard-normal array per parameter, sized
% like that parameter. Per-layer draws stay interleaved (p_w{1}, p_c{1},
% p_w{2}, ...) followed by p_b, matching the original RNG consumption.
[p_w, p_c] = deal(cell(1, nlayer));
for layer = 1:nlayer
    wShape = size(W{layer});
    cShape = size(c{layer});
    p_w{layer} = randn(wShape);
    p_c{layer} = randn(cShape);
end
p_b = randn(size(b));

% Per-iteration diagnostics, one entry per outer iteration.
[TrainAcc, TestAcc, TotalTime, TrainLogProb, TestLogProb] = deal(zeros(1, maxIter));

% SGHMC settings: number of leapfrog steps per outer iteration, and the
% coefficient of the "- sigma * theta" term added to every gradient.
Leapfrog = 50;
sigma = 1;

% Per-layer containers: latent samples for train/test, auxiliary gamma
% variables for train/test, and the weight gradients.
[Htrain, Htest, gammaTrain, gammaTest, gradw] = deal(cell(nlayer, 1));

%%% For AIS
% Sampler configuration used for the final AIS log-likelihood estimate.
model_gibbs = sbn();
model_gibbs.AIS_Samples = 1000;
model_gibbs.burnin = 3;
model_gibbs.gibbsSamples = 4;
% Running averages of {W{1}, c{1}, b}; zero until averaging starts.
initialParameters = {0, 0, 0};
nais = 1;   % (number of iterates averaged so far) + 1
AIS = [];
%%%%

banner = ['SBN with SGHMC, J = ' num2str(KK(end)) ', h = ' num2str(h) ', C = ' num2str(C)];
disp(banner);

% Main training loop. Each outer iteration:
%   1. draws a mini-batch and initialises its latents by ancestral sampling,
%   2. runs `burnin + ns` Gibbs sweeps (retaining the last `ns` states),
%   3. forms stochastic gradients from the retained samples and takes
%      `Leapfrog` SGHMC steps,
%   4. records diagnostics (train/test log-probability, time) and, on the
%      last iteration, an AIS log-likelihood estimate on Vtest.
for iter = 1:maxIter

    % Test latents are refreshed only every 3rd iteration after test_burnin;
    % they are initialised from the model on the first post-burn-in iteration.
    dotest = (iter > test_burnin && mod(iter, 3) == 0);
    inittest = iter == test_burnin + 1;

    % Random mini-batch of `batch` training columns.
    pp = randperm(ntrain);
    Vtrain_s = Vtrain(:, pp(1:batch));

    % Ancestral initialisation of the mini-batch latents: top layer from
    % sigmoid(b), each lower layer conditioned on the layer above it.
    prob = 1./(1+exp(-b));
    Htrain{nlayer} = +(repmat(prob, 1, batch) > rand(KK(nlayer+1), batch))';
    for nl = nlayer-1:-1:1
        X = bsxfun(@plus, W{nl+1} * (Htrain{nl+1}'), c{nl+1});
        prob = 1 ./ (1+exp(-X));
        Htrain{nl} = +(prob >= rand(KK(nl+1), batch))';
    end

    if inittest
        prob = 1./(1+exp(-b));
        Htest{nlayer} = +(repmat(prob, 1, ntest) > rand(KK(nlayer+1), ntest))';
        for nl = nlayer-1:-1:1
            X = bsxfun(@plus, W{nl+1} * (Htest{nl+1}'), c{nl+1});
            prob = 1 ./ (1+exp(-X));
            Htest{nl} = +(prob >= rand(KK(nl+1), ntest))';
        end
    end

    % Total Gibbs sweeps this iteration: `burnin` discarded, `ns` retained.
    maxit = burnin + ns;

    % Gibbs sampling
    fprintf('Start Gibbs sampling for iteration %d...\n', iter);
    tic;

    for ii = 1:maxit
        if ii > burnin
            % Retain the state reached after ii-1 sweeps as sample j.
            j = ii - burnin;
            for nl = 1:nlayer
                H_s{nl}{j} = Htrain{nl};
            end
        end

        % 1. update gamma0: one auxiliary draw per (lower unit, sample) from
        %    PolyaGamRndTruncated(1, X, 20) -- presumably a truncated
        %    Polya-Gamma sampler; confirm against its definition.
        for nl = 1:nlayer
            Xmat = bsxfun(@plus, W{nl} * (Htrain{nl}'), c{nl});
            ndim = KK(nl);
            Xvec = reshape(Xmat, ndim * batch, 1);
            gamma0vec = PolyaGamRndTruncated(ones(ndim * batch, 1), Xvec, 20);
            gammaTrain{nl} = reshape(gamma0vec, ndim, batch);
        end

        % 2. update H, one hidden unit k at a time.
        for nl = 1:nlayer
            res = W{nl} * (Htrain{nl}');
            for k = 1:KK(nl+1)
                % Remove unit k's contribution so `res` is the input from the
                % other units of this layer; it is added back after resampling.
                res = res - W{nl}(:, k) * (Htrain{nl}(:, k)');
                mat1 = bsxfun(@plus, res, c{nl});
                if nl == 1
                    vec1 = sum(bsxfun(@times, Vtrain_s - 0.5 - gammaTrain{nl} .* mat1, W{nl}(:, k)));
                else
                    vec1 = sum(bsxfun(@times, Htrain{nl-1}' - 0.5 - gammaTrain{nl} .* mat1, W{nl}(:, k)));
                end
                vec2 = sum(bsxfun(@times, gammaTrain{nl}, W{nl}(:, k).^2)) / 2;
                % Prior term: top-layer bias b, or the top-down input from layer nl+1.
                if nl == nlayer
                    logz = vec1 - vec2 + b(k);
                else
                    logz = vec1 - vec2 + W{nl+1}(k, :) * (Htrain{nl+1}') + c{nl+1}(k);
                end
                probz = 1 ./ (1 + exp(-logz));
                Htrain{nl}(:, k) = (probz > rand(1, batch))';
                res = res + W{nl}(:, k) * (Htrain{nl}(:, k)');
            end
        end

        % Test latents get only the first two sweeps, and only on test iterations.
        if ii < 3 && dotest
            for nl = 1:nlayer
                Xmat = bsxfun(@plus, W{nl} * (Htest{nl}'), c{nl});
                ndim = KK(nl);
                Xvec = reshape(Xmat, ndim * ntest, 1);
                gamma0vec = PolyaGamRndTruncated(ones(ndim * ntest, 1), Xvec, 20);
                gammaTest{nl} = reshape(gamma0vec, ndim, ntest);
            end

            for nl = 1:nlayer
                res = W{nl} * (Htest{nl}');
                for k = 1:KK(nl+1)
                    res = res - W{nl}(:, k) * (Htest{nl}(:, k)');
                    mat1 = bsxfun(@plus, res, c{nl});
                    if nl == 1
                        vec1 = sum(bsxfun(@times, Vtest - 0.5 - gammaTest{nl} .* mat1, W{nl}(:, k)));
                    else
                        vec1 = sum(bsxfun(@times, Htest{nl-1}' - 0.5 - gammaTest{nl} .* mat1, W{nl}(:, k)));
                    end
                    vec2 = sum(bsxfun(@times, gammaTest{nl}, W{nl}(:, k).^2)) / 2;
                    if nl == nlayer
                        logz = vec1 - vec2 + b(k);
                    else
                        logz = vec1 - vec2 + W{nl+1}(k, :) * (Htest{nl+1}') + c{nl+1}(k);
                    end
                    probz = 1 ./ (1 + exp(-logz));
                    Htest{nl}(:, k) = (probz > rand(1, ntest))';
                    res = res + W{nl}(:, k) * (Htest{nl}(:, k)');
                end
            end
        end
    end

    %%% SGHMC
    fprintf('Start SGHMC for iteration %d...\n', iter);

    % Stochastic gradients from the ns retained samples. mat1 carries the
    % previous (lower) layer's stacked samples, which are this layer's targets.
    mat1 = [];
    for nl = 1:nlayer
        % Stack the ns retained samples vertically: (batch*ns) x KK(nl+1).
        mat = cell2mat(H_s{nl});
        gradw{nl} = zeros(KK(nl), KK(nl+1) + 1);
        % sigmoid([W c] * [h 1]'), using 1/(1+exp(-x)) above 10 and
        % exp(x)/(1+exp(x)) otherwise to keep both branches finite.
        tmp = [W{nl}, c{nl}] * ([mat, ones(batch * ns, 1)]');
        idx = tmp > 10;
        tmp(idx) = 1 ./ (1 + exp(-tmp(idx)));
        idx = ~idx;
        tmp(idx) = exp(tmp(idx));
        tmp(idx) = tmp(idx) ./ (1 + tmp(idx));
        xx = [mat, ones(batch*ns, 1)];
        Kn = KK(nl);
        for j = 1:Kn
            % Row j of the gradient: sum over samples of (target_j - sigmoid_j) * [h 1],
            % split into the "- sigmoid" part and the "+ rows where target == 1" part.
            prodw = sum(repmat(tmp(j, :)', 1, size(mat, 2) + 1) .* [mat, ones(batch * ns, 1)], 1);
            gradw{nl}(j, :) = gradw{nl}(j, :) - prodw;
            if nl == 1
                idx = repmat(Vtrain_s(j, :)', ns, 1) == 1;
            else
                idx = mat1(:, j) == 1;
            end
            gradw{nl}(j, :) = gradw{nl}(j, :) + sum(xx(idx, :), 1);
        end
        mat1 = mat;

        % Rescale the mini-batch/sample average to the full data set and add
        % the -sigma * theta term.
        gradw{nl} = gradw{nl} * ntrain / batch / ns - sigma * [W{nl}, c{nl}];
        assert(isfinite(sum(gradw{nl}(:))));
    end

    % Gradient for the top-layer bias b (same two-branch sigmoid as above).
    tmp = mean(cell2mat(H_s{nlayer}), 1);
    idxb = b > 10;
    tmpb = zeros(size(b));
    tmpb(idxb) = 1 ./ (1 + exp(-b(idxb)));
    idxb = ~idxb;
    tmpb(idxb) = exp(b(idxb));
    tmpb(idxb) = tmpb(idxb) ./ (1 + tmpb(idxb));
    gradb = ntrain * (tmp' - tmpb) - sigma * b;

    % Leapfrog steps with friction C and injected noise sqrt(2*C*h).
    % NOTE: gradw / gradb are computed once above and held fixed for all
    % Leapfrog steps; only parameters and momenta change inside this loop.
    for ll = 1:Leapfrog
        for nl = 1:nlayer
            W{nl} = W{nl} + p_w{nl} * h;
            c{nl} = c{nl} + p_c{nl} * h;

            p_w{nl} = (1 - C * h) .* p_w{nl} + gradw{nl}(:, 1:size(p_w{nl}, 2)) * h + ...
                sqrt(2 * C * h) * randn(size(p_w{nl}));
            assert(isfinite(sum(p_w{nl}(:))));
            p_c{nl} = (1 - C * h) .* p_c{nl} + gradw{nl}(:, end) * h + ...
                sqrt(2 * C * h) * randn(size(p_c{nl}));
        end
        b = b + p_b * h;
        p_b = (1 - C * h) .* p_b + gradb * h + sqrt(2 * C * h) * randn(size(p_b));
    end

    % Running average of {W{1}, c{1}, b} over post-burn-in iterations, used
    % as the model for the final AIS estimate.
    % NOTE(review): b is the TOP-layer bias (length KK(nlayer+1)); it pairs
    % with W{1}/c{1} as a one-layer model only when nlayer == 1.
    if iter > test_burnin
        initialParameters{1} = (initialParameters{1} * (nais-1) + W{1}) / nais;
        initialParameters{2} = (initialParameters{2} * (nais-1) + c{1}) / nais;
        initialParameters{3} = (initialParameters{3} * (nais-1) + b) / nais;
        nais = nais + 1;
    end
    if iter == maxIter
        if nais == 1
            % No averaging took place (maxIter <= test_burnin): evaluate the
            % current parameters rather than the all-zero placeholders.
            initialParameters = {W{1}, c{1}, b};
        end
        logpv_AIS=model_gibbs.functionEvaluation(Vtest,initialParameters);
        disp(['Loglikehood with AIS ', num2str(logpv_AIS)]);
        AIS = [AIS, logpv_AIS];
    end

    % Average Bernoulli log-likelihood of the visibles given the first-layer
    % latents: sum(v .* x - log(1 + exp(x))) / nsamples. log(1 + exp(x)) is
    % evaluated as max(x, 0) + log1p(exp(-|x|)) so that large |x| cannot
    % overflow to Inf (which would make the log-probability -Inf).
    mat = bsxfun(@plus,W{1}*Htrain{1}',c{1});
    log1pexp_mat = max(mat, 0) + log1p(exp(-abs(mat)));
    TrainLogProb(iter) = sum(sum(mat.*Vtrain_s - log1pexp_mat))/batch;

    if dotest
        mat = bsxfun(@plus,W{1}*Htest{1}',c{1});
        log1pexp_mat = max(mat, 0) + log1p(exp(-abs(mat)));
        TestLogProb(iter) = sum(sum(mat.*Vtest - log1pexp_mat))/ntest;
    else
        TestLogProb(iter) = 0;
    end

    % Elapsed time since the tic before Gibbs sampling in this iteration.
    TotalTime(iter) = toc;
    % TrainAcc / TestAcc are not updated in this loop, so they print as 0.
    disp(['Iteration: ' num2str(iter) ' Acc: ' num2str(TrainAcc(iter)) ' ' num2str(TestAcc(iter))...
        ' LogProb: ' num2str(TrainLogProb(iter))  ' ' num2str(TestLogProb(iter))...
         ' Totally spend ' num2str(TotalTime(iter))]);

    if dispimg == 1
        % Show up to 100 first-layer filters, largest squared norm first.
        figure(333);
        [~,index_d] = sort(sum(W{1}.^2, 1),'descend');
        dispims(W{1}(:, index_d(1:min(100, KK(2)))),28,28);
        drawnow;
    end
end

% Export the per-iteration diagnostics as fields of `result`
% (assigned one by one, so any fields `result` already has are kept).
result.TrainAcc     = TrainAcc;
result.TestAcc      = TestAcc;
result.TotalTime    = TotalTime;
result.TrainLogProb = TrainLogProb;
result.TestLogProb  = TestLogProb;


