/*
 * Decompiled with CFR 0.152.
 */
package control_injector;

import control_injector.ByteCodeWriter;
import control_injector.JiveControlLogger;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.instrument.Instrumentation;
import java.lang.reflect.Method;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.logging.Logger;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

public class ControlInjectorAgent
implements Opcodes {
    private static String FQinjectorClass = "control_injector.ControlInjector_InMemory";
    private static String FQinjectorClassSl = "control_injector/ControlInjector_InMemory";
    private static String FQClassName;
    private static String FQClassNameSl;
    private static String ExecutionTraceFilePath;
    private static HashSet<String> packagesToInstrument;
    private static HashSet<String> ClassesToInstrument;
    private static HashSet<String> methodsToInstrument;
    private static HashSet<String> classes_methodsToInstrument;
    private static String instrumentationLevel;
    private static boolean debug;
    private static int write_count;
    private static String debug_string;
    private static boolean sysEnd;
    private static final Logger LOGGER;

    public static void premain(String agentArgs, Instrumentation inst) {
        try {
            JiveControlLogger.setup();
            if (debug) {
                LOGGER.info("Control Agent Args : " + agentArgs);
            }
            if (agentArgs != null) {
                int i;
                String[] args = agentArgs.split(";");
                for (i = 0; i < args.length; ++i) {
                    args[i] = args[i].trim();
                }
                if (debug) {
                    LOGGER.info("args length = " + args.length);
                    if (args.length > 0) {
                        i = 0;
                        for (String arg : args) {
                            LOGGER.info("arg " + i + " :" + arg);
                            ++i;
                        }
                    } else {
                        LOGGER.warning("Using default execution trace file : " + ExecutionTraceFilePath);
                    }
                }
                if (args.length >= 1 && args[0].length() > 0) {
                    ArrayList lines = Collections.emptyList();
                    try {
                        lines = (ArrayList)Files.readAllLines(Paths.get(args[0], new String[0]));
                        for (String line : lines) {
                            int i2;
                            String[] instrument = line.split(":");
                            for (i2 = 0; i2 < instrument.length; ++i2) {
                                instrument[i2] = instrument[i2].trim();
                            }
                            if (instrument.length == 2) {
                                String toInstrument = instrument[1].replaceAll("[.]", "/");
                                if (instrument[0].equalsIgnoreCase("package")) {
                                    packagesToInstrument.add(toInstrument);
                                    continue;
                                }
                                if (instrument[0].equalsIgnoreCase("class")) {
                                    ClassesToInstrument.add(toInstrument);
                                    continue;
                                }
                                if (instrument[0].equalsIgnoreCase("method")) {
                                    i2 = toInstrument.lastIndexOf("/");
                                    String clazz = toInstrument.substring(0, i2);
                                    classes_methodsToInstrument.add(clazz);
                                    methodsToInstrument.add(toInstrument);
                                    continue;
                                }
                                LOGGER.warning("Incorrect format found for instrumentation config - " + line);
                                continue;
                            }
                            LOGGER.warning("Incorrect format found for instrumentation config - " + line);
                        }
                    }
                    catch (IOException e) {
                        System.err.println("Exception in getting package names : " + e.getMessage());
                        e.printStackTrace();
                    }
                } else {
                    System.err.println("Provide the package or class name(s) to instrument");
                    System.exit(0);
                }
                if (args.length >= 2 && args[1].length() > 0) {
                    ExecutionTraceFilePath = args[1];
                }
                if (args.length >= 3) {
                    if (args[2].length() > 0) {
                        debug = Boolean.parseBoolean(args[2]);
                    }
                    debug_string = args[2];
                }
                if (args.length >= 4 && args[3].length() > 0) {
                    write_count = Integer.parseInt(args[3]);
                }
                if (args.length == 5 && args[4].length() > 0) {
                    sysEnd = Boolean.parseBoolean(args[4]);
                }
            } else {
                System.err.println("Provide the package or class name(s) and other required arguments to instrument");
                System.exit(0);
            }
            inst.addTransformer(new ClassFileTransformer(){

                @Override
                public byte[] transform(ClassLoader classLoader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {
                    try {
                        Class<?> injectorClass = null;
                        if (className == null) {
                            return classfileBuffer;
                        }
                        FQClassNameSl = className;
                        FQClassName = FQClassNameSl.replace("/", ".");
                        if (debug) {
                            LOGGER.info("Got class - " + className);
                            if (classLoader != null) {
                                LOGGER.info("\tLoaded by - " + classLoader.toString());
                            }
                        }
                        boolean instrument = false;
                        String classToInstrument = null;
                        String pkgToInstrument = null;
                        if (!classes_methodsToInstrument.isEmpty() && classes_methodsToInstrument.contains(className)) {
                            instrument = true;
                            classToInstrument = className;
                            instrumentationLevel = "method";
                        } else if (!ClassesToInstrument.isEmpty() && ClassesToInstrument.contains(className)) {
                            instrument = true;
                            classToInstrument = className;
                            instrumentationLevel = "class";
                        } else if (!packagesToInstrument.isEmpty()) {
                            for (String pack : packagesToInstrument) {
                                if (!className.startsWith(pack)) continue;
                                instrument = true;
                                pkgToInstrument = pack;
                                instrumentationLevel = "package";
                                break;
                            }
                        }
                        if (!instrument) {
                            return classfileBuffer;
                        }
                        if (debug) {
                            if (pkgToInstrument != null) {
                                LOGGER.info("Got matched with package :" + pkgToInstrument);
                                System.out.println("Instrumentation level : " + instrumentationLevel + "   Got class - " + className + " matched with package :" + pkgToInstrument);
                            } else {
                                LOGGER.info("Got matched with class :" + classToInstrument);
                                System.out.println("Instrumentation level : " + instrumentationLevel + "   Got class - " + className + " matched with classEntry :" + classToInstrument);
                            }
                        }
                        try {
                            injectorClass = Class.forName(FQinjectorClass);
                            if (debug) {
                                LOGGER.info("Injector class loaded");
                            }
                            if (ExecutionTraceFilePath.length() > 0) {
                                Method method = injectorClass.getMethod("jiveBciSetup", String.class, String.class, Integer.class);
                                if (debug) {
                                    LOGGER.info("Invoking JiveBCISetup for initializing trace file");
                                }
                                method.invoke(null, ExecutionTraceFilePath, debug_string, new Integer(write_count));
                            }
                        }
                        catch (Exception e1) {
                            LOGGER.info("Error in loading Injector class required for instrumentation : " + e1.getMessage());
                            e1.printStackTrace();
                        }
                        ClassReader reader = new ClassReader(classfileBuffer);
                        ModifiedClassVisitor mcvisitor = null;
                        ByteCodeWriter writer = null;
                        if (instrumentationLevel == "package") {
                            writer = new ByteCodeWriter(reader, classLoader, packagesToInstrument, true, 2);
                            mcvisitor = new ModifiedClassVisitor(393216, className, writer, packagesToInstrument, methodsToInstrument, instrumentationLevel);
                        } else if (instrumentationLevel == "class") {
                            writer = new ByteCodeWriter(reader, classLoader, ClassesToInstrument, false, 2);
                            mcvisitor = new ModifiedClassVisitor(393216, className, writer, ClassesToInstrument, methodsToInstrument, instrumentationLevel);
                        } else if (instrumentationLevel == "method") {
                            writer = new ByteCodeWriter(reader, classLoader, classes_methodsToInstrument, false, 2);
                            mcvisitor = new ModifiedClassVisitor(393216, className, writer, classes_methodsToInstrument, methodsToInstrument, instrumentationLevel);
                        }
                        reader.accept(mcvisitor, 0);
                        byte[] newClass = writer.toByteArray();
                        File newClassFile = null;
                        try {
                            newClassFile = new File("bci_cfi/" + className + ".class");
                            newClassFile.delete();
                            newClassFile.getParentFile().mkdirs();
                        }
                        catch (Exception e) {
                            System.out.println("Unable to create file : " + e.getMessage());
                        }
                        try (FileOutputStream fileOuputStream = new FileOutputStream(newClassFile);){
                            fileOuputStream.write(newClass);
                        }
                        catch (IOException e) {
                            System.out.println("Error in fos :" + e.getMessage());
                            e.printStackTrace();
                        }
                        return writer.toByteArray();
                    }
                    catch (Exception e) {
                        System.out.println(e.getMessage());
                        e.printStackTrace();
                        throw e;
                    }
                }
            });
        }
        catch (Exception e) {
            System.out.println(e.getMessage());
            e.printStackTrace();
        }
    }

    static {
        ExecutionTraceFilePath = "../JIVE_Execution_Trace1.csv";
        packagesToInstrument = new HashSet();
        ClassesToInstrument = new HashSet();
        methodsToInstrument = new HashSet();
        classes_methodsToInstrument = new HashSet();
        instrumentationLevel = "package";
        debug = false;
        write_count = 10;
        debug_string = "false";
        sysEnd = true;
        LOGGER = Logger.getLogger("CONTROL_LOGGER");
    }

    public static class ModifiedMethodVisitor
    extends MethodVisitor {
        int lineno;
        String sourceFile;
        String methodDesc;
        String className;
        String methodName;
        String superClassName;
        boolean isStaticMethod = false;
        boolean enteredMain = false;
        boolean isObjectInitialised = false;
        HashSet<String> packagesToInstrument;
        HashSet<String> classesToInstrument;
        boolean isPackage;

        public ModifiedMethodVisitor(int api, String owner, String superClassName, int access, String name, String desc, MethodVisitor mv, String sourceFile, HashSet<String> instrumentList, boolean isPkg) {
            super(api, mv);
            this.className = owner;
            this.sourceFile = sourceFile;
            this.methodName = name;
            this.methodDesc = desc;
            this.superClassName = superClassName;
            this.isPackage = isPkg;
            if (isPkg) {
                this.packagesToInstrument = instrumentList;
            } else {
                this.classesToInstrument = instrumentList;
            }
            int bitmask = 15;
            int static_specifier = access & bitmask;
            if (static_specifier >= 8) {
                this.isStaticMethod = true;
                if (debug) {
                    LOGGER.info("Modifying static method " + this.methodName);
                }
            }
        }

        @Override
        public void visitCode() {
            this.mv.visitMethodInsn(184, FQinjectorClassSl, "getCurrentMethodName", "()V", false);
            this.mv.visitCode();
        }

        @Override
        public void visitInsn(int opcode) {
            if (opcode >= 172 && opcode <= 177 && this.methodName.equalsIgnoreCase("main") && sysEnd) {
                if (debug) {
                    LOGGER.info("calling finish...");
                }
                this.mv.visitMethodInsn(184, FQinjectorClassSl, "jiveBciFinish", "()V", false);
            }
            super.visitInsn(opcode);
        }
    }

    public static class ModifiedClassVisitor
    extends ClassVisitor {
        private int api;
        boolean visitedStaticBlock = false;
        boolean isInterface = false;
        boolean isAbstract = false;
        String sourceFile = null;
        String className;
        HashSet<String> packagesToInstrument;
        HashSet<String> classesToInstrument;
        HashSet<String> methodsToInstrument;
        String instrumentationLevel;
        String superClassName = null;

        public ModifiedClassVisitor(int api, String className, ClassWriter cw, HashSet<String> instrumentList, HashSet<String> instrumentMethods, String insLevel) {
            super(api, cw);
            this.api = api;
            this.className = className;
            this.instrumentationLevel = insLevel;
            if (this.instrumentationLevel.equalsIgnoreCase("package")) {
                this.packagesToInstrument = instrumentList;
            } else {
                this.classesToInstrument = instrumentList;
            }
            this.methodsToInstrument = instrumentMethods;
        }

        @Override
        public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
            this.superClassName = superName;
            if (debug) {
                LOGGER.info("Super Classname : " + superName + " \n Interfaces :" + interfaces.length);
                for (int i = 0; i < interfaces.length; ++i) {
                    LOGGER.info("     " + interfaces[i]);
                }
            }
            int bitmask = 3840;
            int abstract_interface_specifier = access & bitmask;
            switch (abstract_interface_specifier) {
                case 1024: {
                    this.isAbstract = true;
                    break;
                }
                case 512: {
                    this.isInterface = true;
                    break;
                }
                case 1536: {
                    this.isAbstract = true;
                    this.isInterface = true;
                }
            }
            if (this.isInterface) {
                super.visit(version, access, name, signature, superName, interfaces);
                return;
            }
            super.visit(version, access, name, signature, superName, interfaces);
        }

        @Override
        public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
            int bitmask;
            int abstract_specifier;
            if (debug) {
                LOGGER.info("Modifying method - access =" + access + "   name =" + name + "  desc=" + desc + "  sign=" + signature + "   exceptions =");
            }
            MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
            if ("<clinit>".equals(name) && !this.visitedStaticBlock) {
                this.visitedStaticBlock = true;
            }
            if ((abstract_specifier = access & (bitmask = 3840)) == 1024) {
                return mv;
            }
            if (debug) {
                LOGGER.info("Calling method modifier for method - " + this.className + ":" + name);
            }
            ModifiedMethodVisitor mmw = null;
            if (this.instrumentationLevel.equalsIgnoreCase("package")) {
                mmw = new ModifiedMethodVisitor(this.api, this.className, this.superClassName, access, name, desc, mv, this.sourceFile, this.packagesToInstrument, true);
            } else if (this.instrumentationLevel.equalsIgnoreCase("class")) {
                mmw = new ModifiedMethodVisitor(this.api, this.className, this.superClassName, access, name, desc, mv, this.sourceFile, this.classesToInstrument, false);
            } else {
                String method = this.className + "/" + name;
                if (this.methodsToInstrument.contains(method)) {
                    mmw = new ModifiedMethodVisitor(this.api, this.className, this.superClassName, access, name, desc, mv, this.sourceFile, this.classesToInstrument, false);
                } else {
                    return mv;
                }
            }
            return mmw;
        }

        @Override
        public void visitSource(String source, String debug1) {
            if (debug) {
                LOGGER.info("Visiting Source: " + source + "  debug = " + debug1);
            }
            this.sourceFile = source;
            super.visitSource(source, debug1);
        }

        @Override
        public void visitEnd() {
            super.visitEnd();
            if (debug) {
                LOGGER.info("\n\nVisit completed !!! \n\n\n");
            }
        }
    }
}

