classdef sbn
    % SBN  Two-layer binary belief network (presumably "sigmoid belief
    % network") with Monte Carlo / variational training utilities.
    %
    %   Generative model, as used by makeData:
    %       h ~ Bernoulli(sigmoid(b))          (J x 1 hidden units)
    %       v ~ Bernoulli(sigmoid(c + W*h))    (M x 1 visible units)
    %
    %   Parameters are passed around as a cell array:
    %       parameters{1} = W  (M x J) connection weights
    %       parameters{2} = c  (M x 1) visible bias
    %       parameters{3} = b  (J x 1) hidden bias
    %   Data matrices hold one observation per column (v is M x N).
    %
    %   NOTE: `sigmoid` is not defined in this file; it is presumably the
    %   elementwise logistic function 1./(1+exp(-x)) -- confirm on the path.
    %   The methods also call the local functions defined after this
    %   classdef block (variationalHiddenUnits, calcNegEnergy,
    %   HgibbsSampler, HgibbsSamplerGPU, calcPy).
    properties
        gibbsSamples=10;   % not read by any method in this file
        burnin=5;          % default number of Gibbs sweeps used by getPosterior
        AIS_Samples=1000;  % number of annealing steps used by functionEvaluation
        gpu=false;         % when true, samplers run on gpuArray data
    end
    methods
        function obj=sbn()
            % Construct an sbn object with the default property values.
            % parameter 1 is W
            % parameter 2 is vbias
            % parameter 3 is hbias
        end

        function initialParameters=initializeParameters(obj,M,J)
            % Return a fresh parameter cell array for M visible and J
            % hidden units: small Gaussian weights and zero biases.
            initialParameters{1}=.1*randn(M,J); %% connection weights
            initialParameters{2}=zeros(M,1); %% visible bias term
            initialParameters{3}=zeros(J,1); %% hidden bias term
        end

        function [v,vtest,parameters]=makeData(obj,J,M,N,Ntest,sig2)
            % Draw synthetic train/test data from a randomly generated model.
            %
            %   J, M     number of hidden / visible units
            %   N, Ntest number of training / test columns to draw
            %   sig2     multiplier applied to randn when drawing the true
            %            weights (default 2). Despite the name it scales the
            %            standard deviation, not the variance.
            %
            %   v, vtest   M x N and M x Ntest binary (double) matrices
            %   parameters {Wtrue, ctrue, btrue} used for generation
            if nargin<6
                sig2=2;
            end
            btrue=zeros(J,1);
            htrue=double(bsxfun(@lt,rand(J,N),sigmoid(btrue)));
            Wtrue=sig2*randn(M,J);
            ctrue=zeros(M,1);
            v=double(rand(M,N)<sigmoid(bsxfun(@plus,ctrue,Wtrue*htrue)));
            htest=double(bsxfun(@lt,rand(J,Ntest),sigmoid(btrue)));
            vtest=double(rand(M,Ntest)<sigmoid(bsxfun(@plus,ctrue,Wtrue*htest)));
            parameters{1}=Wtrue;
            parameters{2}=ctrue;
            parameters{3}=btrue;
        end

        function [gradient,gradVarEst,hVB]=variationalGradient(obj,v,parameters,hVB)
            % Monte Carlo estimate of the gradient of the variational bound.
            %
            %   v          M x N data matrix
            %   parameters {W, c, b}
            %   hVB        optional J x N matrix of variational Bernoulli
            %              means; computed by variationalHiddenUnits when
            %              omitted
            %
            %   gradient   cell array {dW, dc, db}, averaged over the N columns
            %   gradVarEst always [] (kept for interface compatibility)
            %   hVB        the variational means that were used
            %
            % The gradient itself is computed on the arrays as given; obj.gpu
            % is only forwarded to variationalHiddenUnits.
            gradVarEst=[];
            N=size(v,2);
            if nargin<4
                hVB=variationalHiddenUnits(v,parameters,obj.gpu);
            end
            MCI=20; % number of Monte Carlo draws of h from hVB
            W=parameters{1};
            c=parameters{2};
            b=parameters{3};
            [M,J]=size(W);
            term2=zeros(M,J);
            vb=zeros(M,1);

            for iter=1:MCI
                % sample binary hidden states from the variational means
                h=double(rand(size(hVB))<hVB);
                [~,pv]=calcPy(W,h,c);
                vb=vb+mean(pv,2);
                term2=term2+pv*h';
            end
            gradient{1}=(v*hVB'-term2./MCI)/N;
            gradient{2}=mean(v,2)-vb/MCI;
            gradient{3}=mean(hVB,2)-sigmoid(b);
        end

        function [logpv,logweights]=functionEvaluation(obj,v,parameters)
            % Annealed-importance-sampling estimate of the mean log p(v).
            %
            % One chain is run per column of v over obj.AIS_Samples
            % temperatures that scale all parameters linearly from 0 to 1.
            % At temperature 0 the joint is uniform, whose sum over h equals
            % 2^(-M); this is the constant added to the mean log weight.
            %
            %   logpv      scalar: -M*log(2) + mean(logweights)
            %   logweights 1 x N accumulated log importance weights
            %              (gathered to the CPU when obj.gpu is true)
            gpu=obj.gpu;
            N=size(v,2);
            if gpu
                W=gpuArray(parameters{1});
                c=gpuArray(parameters{2});
                b=gpuArray(parameters{3});
                v=gpuArray(v);
                logweights=gpuArray.zeros(1,N);
            else
                W=parameters{1};
                c=parameters{2};
                b=parameters{3};
                logweights=zeros(1,N);
            end
            Samples=obj.AIS_Samples;

            % Initial state: one sweep with all parameters at zero.
            % Only h is requested so the sampler skips its calcPy call.
            if gpu
                Temps=gpuArray(linspace(0,1,Samples+1));
                h=HgibbsSamplerGPU(W*0,b*0,c*0,v*0,1);
            else
                Temps=linspace(0,1,Samples+1);
                h=HgibbsSampler(W*0,b*0,c*0,v*0,1);
            end

            Wo=W*0;
            bo=b*0;
            co=c*0;
            for iter=2:Samples+1
                Wt=W*Temps(iter);
                bt=b*Temps(iter);
                ct=c*Temps(iter);
                % AIS weight: ratio of the new to the previous unnormalised
                % density, evaluated on the state drawn under the PREVIOUS
                % temperature (i.e. before the transition below).
                logweights=logweights+calcNegEnergy(v,h,Wt,bt,ct)-calcNegEnergy(v,h,Wo,bo,co);
                % Move the chain with a Gibbs sweep at the new temperature.
                % The state after the final temperature is never used.
                if iter<=Samples
                    if gpu
                        h=HgibbsSamplerGPU(Wt,bt,ct,v,1,h);
                    else
                        h=HgibbsSampler(Wt,bt,ct,v,1,h);
                    end
                end
                Wo=Wt;
                bo=bt;
                co=ct;
            end
            if gpu
                logweights=gather(logweights);
            end

            logpv=-size(v,1)*log(2)+mean(logweights);
        end

        function lowerBound=variationalFunctionEvaluation(obj,v,parameters,hVB)
            % Monte Carlo estimate of the variational lower bound, averaged
            % over the columns of v.
            %
            %   hVB  optional J x N variational Bernoulli means; computed by
            %        variationalHiddenUnits when omitted
            %
            % The bound per column is the average of calcNegEnergy over MCI
            % draws h ~ Bernoulli(hVB) plus the entropy of Bernoulli(hVB).
            gpu=obj.gpu;
            if nargin<4
                hVB=variationalHiddenUnits(v,parameters,gpu);
            end
            W=parameters{1};
            c=parameters{2};
            b=parameters{3};
            MCI=40; % number of Monte Carlo draws of h from hVB
            N=size(v,2);
            mE=zeros(1,N);
            for iter=1:MCI
                h=double(rand(size(hVB))<hVB);
                mE=mE+calcNegEnergy(v,h,W,b,c);
            end
            offset=1e-8; % keeps log() finite when hVB hits 0 or 1
            % Negative entropy per column. Summing explicitly along
            % dimension 1 keeps the result 1 x N even when J==1, where
            % dot() would treat hVB as a vector and collapse across columns.
            negEntropy=sum(hVB.*log(hVB+offset),1)+sum((1-hVB).*log(1-hVB+offset),1);
            lowerBoundAll=mE./MCI-negEntropy;
            lowerBound=mean(lowerBoundAll);
        end

        function h=getPosterior(obj,v,parameters,samples)
            % Draw hidden states for each column of v by Gibbs sampling.
            %
            %   samples  number of Gibbs sweeps (default obj.burnin)
            %
            %   h        J x N binary sample. When obj.gpu is true the
            %            result is returned as produced by HgibbsSamplerGPU
            %            (no gather is performed here).
            if nargin<4
                samples=obj.burnin;
            end
            gpu=obj.gpu;
            if gpu
                W=gpuArray(parameters{1});
                c=gpuArray(parameters{2});
                b=gpuArray(parameters{3});
                v=gpuArray(v);
                h=HgibbsSamplerGPU(W,b,c,v,samples);
            else
                W=parameters{1};
                c=parameters{2};
                b=parameters{3};
                h=HgibbsSampler(W,b,c,v,samples);
            end
        end
    end
end

function hbar=variationalHiddenUnits(v,parameters,gpu)
% Iteratively estimate per-unit Bernoulli means for the hidden layer.
%
%   v           data matrix, one observation per column
%   parameters  {W, c, b}
%   gpu         optional flag (default false); when true the work is done
%               on gpuArray copies and the result is gathered at the end
%
%   hbar        numel(b) x N matrix of means
%
% Each pass visits the hidden units in random order. For a unit the
% data term W(:,j)'*v+b(j) is combined with the average of calcTerm2 over
% a few binary draws from the current means, and the mean is moved half-way
% towards sigmoid of that sum (damped update).
% `sigmoid` is defined outside this file.
if nargin<3
    gpu=false;
end
W=parameters{1};
c=parameters{2};
b=parameters{3};
if gpu
    W=gpuArray(W);
    c=gpuArray(c);
    b=gpuArray(b);
    v=gpuArray(v);
end
[~,numObs]=size(v);
numDraws=5;    % Monte Carlo draws per unit update
numHidden=numel(b);
q=repmat(sigmoid(b),1,numObs);
numPasses=5;   % full passes over all hidden units
for pass=1:numPasses
    for unit=randperm(numHidden)
        previous=q(unit,:);
        dataTerm=W(:,unit)'*v+b(unit);
        mcTerm=zeros(size(dataTerm));
        for draw=1:numDraws
            % binary draw of every hidden unit from the current means
            if gpu
                hs=double(gpuArray.rand(size(q))<q);
            else
                hs=double(rand(size(q))<q);
            end
            mcTerm=mcTerm+calcTerm2(W,c,hs,unit);
        end
        target=sigmoid(dataTerm+mcTerm/numDraws);
        q(unit,:)=.5*(previous+target);
    end
end
if gpu
    hbar=gather(q);
else
    hbar=q;
end
end

function [nE]=calcNegEnergy(v,h,W,b,c)
% Per-column value of
%   sum_m v.*(W*h+c) - sum_m log(1+exp(W*h+c)) + b'*h - sum_j log(1+exp(b))
% i.e. the Bernoulli log-likelihood of v given logits W*h+c plus the
% Bernoulli log-probability of h given logits b.
%
%   v  M x N binary data,  h  J x N binary hidden states
%   W  M x J weights,      b  J x 1 hidden bias,  c  M x 1 visible bias
%
%   nE 1 x N row vector, one value per column of v/h
%
% log(1+exp(x)) is evaluated as max(x,0)+log1p(exp(-abs(x))) so that large
% logits do not overflow to Inf. All reductions name dimension 1 explicitly
% so the result stays 1 x N even when M==1 (where dot()/sum() without a
% dimension would collapse across columns).
A=bsxfun(@plus,W*h,c);                          % visible logits, M x N
logNormV=sum(max(A,0)+log1p(exp(-abs(A))),1);   % 1 x N
logNormH=sum(max(b(:),0)+log1p(exp(-abs(b(:))))); % scalar
nE=sum(v.*A,1)+b'*h-logNormV-logNormH;
end

function [h,pv]=HgibbsSampler(W,b,c,v,gibbsSamples,h)
% Gibbs-sample the binary hidden states given the visible data (CPU).
%
%   W, b, c       weights, hidden bias, visible bias
%   v             data matrix, one observation per column
%   gibbsSamples  number of full sweeps over the hidden units
%   h             optional starting state; when omitted each unit is drawn
%                 independently with probability sigmoid(b)
%
%   h   state after the last sweep
%   pv  second output of calcPy(W,h,c); only computed when requested
%
% `sigmoid` is defined outside this file.
[~,numObs]=size(v);
numHidden=numel(b);
if nargin<6
    h=double(bsxfun(@lt,rand(numHidden,numObs),sigmoid(b)));
end
for sweep=1:gibbsSamples
    % visit the hidden units in a fresh random order on every sweep
    for unit=randperm(numHidden)
        logOdds=W(:,unit)'*v+b(unit)+calcTerm2(W,c,h,unit);
        h(unit,:)=rand(1,numObs)<sigmoid(logOdds);
    end
end
if nargout>=2
    [~,pv]=calcPy(W,h,c);
end
end

function [h,pv]=HgibbsSamplerGPU(W,b,c,v,gibbsSamples,h)
% Gibbs-sample the binary hidden states given the visible data, drawing
% the random numbers with gpuArray.rand.
%
%   W, b, c       weights, hidden bias, visible bias
%   v             data matrix, one observation per column
%   gibbsSamples  number of full sweeps over the hidden units
%   h             optional starting state; when omitted each unit is drawn
%                 independently with probability sigmoid(b)
%
%   h   state after the last sweep
%   pv  second output of calcPy(W,h,c); only computed when requested
%
% `sigmoid` is defined outside this file.
[~,numObs]=size(v);
numHidden=numel(b);
if nargin<6
    h=double(bsxfun(@lt,gpuArray.rand(numHidden,numObs),sigmoid(b)));
end
for sweep=1:gibbsSamples
    % visit the hidden units in a fresh random order on every sweep
    for unit=randperm(numHidden)
        logOdds=W(:,unit)'*v+b(unit)+calcTerm2(W,c,h,unit);
        h(unit,:)=gpuArray.rand(1,numObs)<sigmoid(logOdds);
    end
end
if nargout>=2
    [~,pv]=calcPy(W,h,c);
end
end

function t2=calcTerm2(W,c,h,j)
% Change in the visible-layer log-normaliser when hidden unit j is
% switched from 0 to 1, with all other hidden units held at h:
%
%   t2 = sum_m log(1+exp(T)) - sum_m log(1+exp(T+W(:,j)))
%
% where T = W*h+c is evaluated with h(j,:) forced to 0.
%
%   W  M x J weights,  c  M x 1 visible bias
%   h  J x N hidden states (row j is ignored),  j  index of the unit
%
%   t2 1 x N row vector
%
% log(1+exp(x)) is evaluated as max(x,0)+log1p(exp(-abs(x))). The earlier
% exp(-T) formulation overflowed for very negative T and produced
% Inf-Inf = NaN. Sums name dimension 1 explicitly so the result stays
% 1 x N even when M==1.
h(j,:)=0;
T=bsxfun(@plus,W*h,c);        % logits with unit j off
Tup=bsxfun(@plus,T,W(:,j));   % logits with unit j on
offTerm=sum(max(T,0)+log1p(exp(-abs(T))),1);
onTerm=sum(max(Tup,0)+log1p(exp(-abs(Tup))),1);
t2=offTerm-onTerm;
end

function [lpv,pv]=calcPy(W,h,c)
% Apply sigmoid to the visible logits W*h+c and also return its log.
%
%   W  weight matrix,  h  hidden states (one column per observation)
%   c  optional visible bias; defaults to a zero column of size(W,1) rows
%
%   lpv  log(pv)
%   pv   sigmoid(W*h+c), with c added to every column
%
% `sigmoid` is defined outside this file.
if nargin<3
    c=zeros(size(W,1),1);
end
logits=bsxfun(@plus,W*h,c);
pv=sigmoid(logits);
lpv=log(pv);
end
