'''
    ab_ucf50_groups.py

    Jason Corso

    Train and test an SVM on the UCF 50 data.
    Performs group-wise 5-fold cross-validation so that no group ever appears in both the training and the testing set.
    
    The processed UCF 50 data is available at 
    http://www.cse.buffalo.edu/~jcorso/r/actionbank

    Make sure that ../code is on your PYTHONPATH (e.g., export PYTHONPATH=../code)
    before running this script.
'''

import argparse
import glob
import gzip
import numpy as np
import os
import os.path
import random as rnd
import scipy.io as sio
import multiprocessing as multi


from actionbank import *
import ab_svm


def _load_groups(groups, indices, vlen):
    """Load every banked vector and class label belonging to the selected groups.

    groups  -- list indexed by group number; each entry is a list of
               [path, class_index] pairs, where path is a gzipped .npy file.
    indices -- iterable of group numbers to load.
    vlen    -- length of each banked vector.

    Returns (D, Y): D is an (n, vlen) uint8 array with one row per file and Y
    is a float array holding the n class indices in the same row order.
    """
    members = [entry for i in indices for entry in groups[i]]
    D = np.zeros((len(members), vlen), np.uint8)
    # Sentinel -1000 marks rows that were never filled (should not survive).
    Y = np.ones((len(members))) * -1000
    for row, (path, label) in enumerate(members):
        # 'with' guarantees the gzip handle is closed even if np.load raises.
        # NOTE(review): vectors are stored as uint8; confirm the banked
        # features are already quantized to 0-255, otherwise they are truncated.
        with gzip.open(path) as fp:
            D[row, :] = np.load(fp)
        Y[row] = label
    return D, Y


def testgroups(groups,training,testing):
    """Train a linear SVM on the training groups and evaluate it on the testing groups.

    groups   -- list indexed by group number; each entry is a list of
                [path, class_index] pairs pointing at gzipped .npy vectors.
    training -- iterable of group numbers used for training.
    testing  -- iterable of group numbers used for testing.

    Returns the classification accuracy on the testing files, in percent.
    Raises ValueError if no group contains any file, or if the testing groups
    contain no files (accuracy would be undefined).
    """
    print("Training %s" % (training,))
    print("Testing %s" % (testing,))

    # All banked vectors share one length; read it from the first available
    # file rather than assuming group 0 is non-empty.
    first_path = next((entry[0] for g in groups for entry in g), None)
    if first_path is None:
        raise ValueError("no files found in any group")
    with gzip.open(first_path) as fp:
        vlen = len(np.load(fp))
    print("vector length is %d" % vlen)

    Dtrain, Ytrain = _load_groups(groups, training, vlen)
    print("have %d training files" % Dtrain.shape[0])

    Dtest, Ytest = _load_groups(groups, testing, vlen)
    print("have %d testing files" % Dtest.shape[0])
    if Dtest.shape[0] == 0:
        raise ValueError("no testing files in groups %s" % (testing,))

    # Leave one core free, but never request zero threads on a 1-core machine.
    threads = max(1, multi.cpu_count() - 1)
    res = ab_svm.SVMLinear(Dtrain, np.int32(Ytrain), Dtest, threads=threads,
                           useLibLinear=True, useL1R=False)
    tp = np.sum(res == Ytest)
    acc = (np.float64(tp) / Dtest.shape[0]) * 100
    print('Accuracy is %.1f%%' % acc)
    return acc







if __name__ == '__main__':
    # Script entry point: build the 25 UCF50 groups from ROOT, then run 5-fold
    # group-wise cross-validation (5 groups held out per fold) and report the
    # mean accuracy.

    parser = argparse.ArgumentParser(description="Script to perform 5-fold group-wise cross-validation on the UCF 50 data set using the included SVM code.", 
             formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("root", help="path to the directory containing the action bank processed ucf 50 files structured as in root/class/name_g00_c00_banked.npy.gz for each class")

    args = parser.parse_args()

    # os.listdir order is arbitrary; sort so that class indices are
    # reproducible across machines, and ignore stray non-directory entries
    # (e.g. .DS_Store) that would otherwise be counted as classes.
    cdir = sorted(d for d in os.listdir(args.root)
                  if os.path.isdir(os.path.join(args.root, d)))

    if len(cdir) != 50:
        print("error: found %d classes, but there should be 50" % (len(cdir)))

    # groups[k] holds [path, class_index] pairs for every file of group k+1.
    groups = []
    for g in range(1, 26):
        gset = []
        for ci, cl in enumerate(cdir):
            pattern = os.path.join(args.root, cl, '*g%02d*%s' % (g, banked_suffix))
            # Sort glob output so file order (and thus row order) is stable.
            for f in sorted(glob.glob(pattern)):
                gset.append([f, ci])
        print("group %d has %d" % (g, len(gset)))
        groups.append(gset)

    full = np.arange(25)

    # Five folds of five consecutive groups each: 0-4, 5-9, ..., 20-24.
    nfolds = 5
    sets = np.array_split(full, nfolds)

    accs = np.zeros(nfolds)
    for i, held_out in enumerate(sets):
        accs[i] = testgroups(groups, np.setdiff1d(full, held_out), held_out)

    print("mean accuracy is %f" % (accs.mean()))

















