aboutsummaryrefslogblamecommitdiff
path: root/src/main/java/org/glavo/javah/JNIGenerator.java
blob: c1a9d40b3cf2942b82b8e51aed4ed50598349f30 (plain) (tree)






























































                                                                                                   
                                                          
                                                  

                                                        











































































                                                                                                                         
package org.glavo.javah;

import org.objectweb.asm.*;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Generates the body of a JNI C header for a single compiled class, in the layout used by the
 * JDK's {@code javah} tool.
 *
 * <p>The class file is parsed with ASM. Static fields that carry a non-String constant value
 * become {@code #define} macros, and every {@code native} method becomes a {@code JNIEXPORT}
 * prototype. {@link #generate()} does not itself write {@link #FILE_HEADER} or
 * {@link #FILE_END}; presumably the caller emits them around the generated section — verify at
 * the call site.
 */
public class JNIGenerator extends ClassVisitor {
    /** Header preamble; not written by {@link #generate()}. */
    public static final String FILE_HEADER =
            "/* DO NOT EDIT THIS FILE - it is machine generated */\n" +
                    "#include <jni.h>\n";

    /** Header trailer; not written by {@link #generate()}. */
    public static final String FILE_END = "\n";

    private final PrintWriter output;
    private final Path classFile;

    /** Internal (slash-separated) name of the processed class, set by {@link #visit}. */
    private String className;

    /** Number of native methods seen per method name; a count above 1 means overloaded. */
    private final Map<String, Integer> methodNameCount = new HashMap<>();
    private final List<NativeMethod> methods = new ArrayList<>();
    private final List<ConstantField> constants = new ArrayList<>();

    /**
     * @param output    destination for the generated text; never flushed or closed by this class
     * @param classFile path of the {@code .class} file to read
     * @throws NullPointerException if either argument is {@code null}
     */
    public JNIGenerator(PrintWriter output, Path classFile) {
        super(Opcodes.ASM7);
        this.output = Objects.requireNonNull(output, "output");
        this.classFile = Objects.requireNonNull(classFile, "classFile");
    }

    /**
     * Reads {@link #classFile} and writes the guarded JNI declarations for it to the output.
     *
     * <p>Any state left over from a previous call is discarded first, so calling this method
     * repeatedly produces the same text each time instead of duplicated declarations.
     *
     * @throws UncheckedIOException if the class file cannot be read
     */
    public void generate() {
        resetState();
        readClass().accept(this, ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        String name = Utils.encode(className).replace('/', '_');
        // The guard must be a single C identifier. The raw internal name contains '/', which the
        // preprocessor would cut at the first slash ("_Included_org"), making headers of classes
        // in the same package suppress each other.
        String guard = "_Included_" + name;

        output.println("/* Header for class " + name + " */\n\n");
        output.println("#ifndef " + guard);
        output.println("#define " + guard);
        output.println("#ifdef __cplusplus");
        output.println("extern \"C\" {");
        output.println("#endif");

        for (ConstantField constant : constants) {
            writeConstant(name, constant);
        }
        for (NativeMethod method : methods) {
            writeMethod(name, method);
        }

        output.println("#ifdef __cplusplus");
        output.println("}");
        output.println("#endif");
        output.println("#endif");
    }

    /** Clears everything collected by a previous {@link #generate()} call. */
    private void resetState() {
        className = null;
        methodNameCount.clear();
        methods.clear();
        constants.clear();
    }

    /** Parses {@link #classFile}, wrapping read failures in {@link UncheckedIOException}. */
    private ClassReader readClass() {
        try {
            return new ClassReader(Files.readAllBytes(classFile));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Writes the {@code #undef}/{@code #define} pair for one constant field. */
    private void writeConstant(String mangledClass, ConstantField constant) {
        String macro = mangledClass + "_" + Utils.encode(constant.name);
        output.println("#undef " + macro);
        output.println("#define " + macro + " " + toCLiteral(constant.value));
    }

    /**
     * Renders a field's constant value as a C literal.
     *
     * <p>Floats get an {@code f} suffix, doubles are printed as-is, longs use the standard
     * C99/C++11 {@code LL} suffix (the MSVC-only {@code i64} suffix does not compile with GCC or
     * Clang), and every other value gets an {@code L} suffix. Per the ASM docs, int-like fields
     * (int, short, byte, char, boolean) arrive as {@link Integer}.
     *
     * <p>NOTE(review): NaN and infinite float/double values render as Java's {@code "NaN"} /
     * {@code "Infinity"} text, which is not a valid C literal if the macro is ever expanded.
     */
    private static String toCLiteral(Object value) {
        if (value instanceof Float) {
            return value + "f";
        }
        if (value instanceof Double) {
            return value.toString();
        }
        if (value instanceof Long) {
            long v = (Long) value;
            // 9223372036854775808 does not fit in a long long, so MIN_VALUE must be an expression.
            return v == Long.MIN_VALUE ? "(-9223372036854775807LL - 1)" : v + "LL";
        }
        return value + "L";
    }

    /** Writes the comment block and {@code JNIEXPORT} prototype for one native method. */
    private void writeMethod(String mangledClass, NativeMethod method) {
        String symbol = "Java_" + mangledClass + "_" + Utils.encode(method.name);
        // JNI only needs the long, argument-mangled symbol when a native method is overloaded
        // by another native method; methodNameCount counts native methods only.
        if (methodNameCount.getOrDefault(method.name, 1) > 1) {
            symbol = symbol + "__" + Utils.mangledArgSignature(method.type);
        }
        String returnType = Utils.mapToNativeType(method.type.getReturnType());
        String args = Stream.concat(
                        Stream.of("JNIEnv *", method.isStatic ? "jclass" : "jobject"),
                        Arrays.stream(method.type.getArgumentTypes()).map(Utils::mapToNativeType))
                .collect(Collectors.joining(", "));

        output.println("/*");
        output.println(" * Class:     " + mangledClass);
        output.println(" * Method:    " + symbol);
        output.println(" * Signature: " + method.type);
        output.println(" */");
        output.println("JNIEXPORT " + returnType + " JNICALL " + symbol);
        output.println("  (" + args + ");");
        output.println();
    }

    /** Records the internal name of the class being read. */
    @Override
    public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        className = name;
    }

    /** Collects {@code native} methods; method bodies are never needed, so no visitor is returned. */
    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        if ((access & Opcodes.ACC_NATIVE) != 0) {
            methodNameCount.merge(name, 1, Integer::sum);
            methods.add(new NativeMethod(name, Type.getType(descriptor), (access & Opcodes.ACC_STATIC) != 0));
        }
        return null;
    }

    /**
     * Collects static fields that carry a non-String constant value.
     *
     * <p>A ConstantValue attribute on a non-static field is ignored by the JVM (JVMS 4.7.2), so
     * such fields are not exported even when a compiler attaches one to a final instance field.
     */
    @Override
    public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
        boolean isStatic = (access & Opcodes.ACC_STATIC) != 0;
        if (isStatic && value != null && !(value instanceof String)) {
            constants.add(new ConstantField(name, Type.getType(descriptor), value));
        }
        return null;
    }

    /** A {@code native} method declared by the processed class. */
    private static final class NativeMethod {
        final String name;
        final Type type;
        final boolean isStatic;

        NativeMethod(String name, Type type, boolean isStatic) {
            this.name = name;
            this.type = type;
            this.isStatic = isStatic;
        }
    }

    /** A static field with a non-String constant value. */
    private static final class ConstantField {
        final String name;
        /** Field type from the descriptor; not currently consulted when rendering the value. */
        final Type type;
        final Object value;

        ConstantField(String name, Type type, Object value) {
            this.name = name;
            this.type = type;
            this.value = value;
        }
    }
}