aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/org/glavo/javah/JNIGenerator.java
blob: c1a9d40b3cf2942b82b8e51aed4ed50598349f30 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package org.glavo.javah;

import org.objectweb.asm.*;

import java.io.IOException;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Generates a JNI C header (in the style of {@code javah}) for a single compiled class file.
 *
 * <p>The class file is parsed with ASM. The {@code visit*} callbacks collect native methods and
 * primitive constant fields; {@link #generate()} then writes the guarded header body to
 * {@link #output}.
 */
public class JNIGenerator extends ClassVisitor {
    /** Preamble for a generated header; {@link #generate()} does not write it itself. */
    public static final String FILE_HEADER =
            "/* DO NOT EDIT THIS FILE - it is machine generated */\n" +
                    "#include <jni.h>\n";

    /** Trailer for a generated header; {@link #generate()} does not write it itself. */
    public static final String FILE_END = "\n";

    private final PrintWriter output;
    private final Path classFile;

    /** Internal (slash-separated) name of the visited class, as reported by ASM's {@code visit}. */
    private String className;

    /** Number of native methods seen per simple name; a count above 1 means the name is overloaded. */
    private final Map<String, Integer> methodNameCount = new HashMap<>();
    private final List<NativeMethod> methods = new ArrayList<>();
    private final List<ConstantField> constants = new ArrayList<>();

    /**
     * Creates a generator that reads {@code classFile} and writes the header to {@code output}.
     *
     * @param output    destination for the generated header text
     * @param classFile path of the {@code .class} file to parse
     * @throws NullPointerException if either argument is {@code null}
     */
    public JNIGenerator(PrintWriter output, Path classFile) {
        super(Opcodes.ASM7);
        this.output = Objects.requireNonNull(output, "output");
        this.classFile = Objects.requireNonNull(classFile, "classFile");
    }

    /**
     * Parses {@link #classFile} and writes the header body to {@link #output}.
     *
     * <p>The body consists of the include guard, a {@code #define} per constant field and a
     * prototype per native method. State collected by an earlier call is discarded first, so
     * calling this more than once does not duplicate declarations.
     *
     * @throws UncheckedIOException if the class file cannot be read
     */
    public void generate() {
        byte[] bytes;
        try {
            bytes = Files.readAllBytes(classFile);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }

        // Reset collected state so repeated calls don't emit duplicate declarations.
        className = null;
        methodNameCount.clear();
        methods.clear();
        constants.clear();

        // Only member signatures and constant values are needed; skip bodies and debug info.
        new ClassReader(bytes).accept(this,
                ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        String name = Utils.encode(className).replace('/', '_');
        // The guard must be a valid C identifier, so it is built from the mangled name rather
        // than the raw internal name (which contains '/').
        String guard = "_Included_" + name;

        output.println("/* Header for class " + className + " */\n\n");
        output.println("#ifndef " + guard);
        output.println("#define " + guard);
        output.println("#ifdef __cplusplus");
        output.println("extern \"C\" {");
        output.println("#endif");

        for (ConstantField constant : constants) {
            String cm = name + "_" + Utils.encode(constant.name);
            output.println("#undef " + cm);
            output.println("#define " + cm + " " + toCLiteral(constant.value));
        }
        for (NativeMethod method : methods) {
            writeNativeMethod(name, method);
        }

        output.println("#ifdef __cplusplus");
        output.println("}");
        output.println("#endif");
        output.println("#endif");
    }

    /**
     * Renders a primitive constant-field value as a C literal.
     *
     * <p>{@code Float} values get an {@code f} suffix and {@code Double} values no suffix.
     * {@code Long} values get {@code LL}. Everything else gets {@code L}; this covers
     * {@code Integer}, the class-file representation of int, short, char, byte and boolean
     * constants.
     *
     * @param value a non-null, non-String constant value collected by {@link #visitField}
     * @return the C literal text
     */
    private static String toCLiteral(Object value) {
        // NOTE(review): non-finite floats/doubles render via toString() (e.g. "Infinityf", "NaN"),
        // which are not valid C tokens. TODO decide on a representation for them.
        if (value instanceof Float) {
            return value + "f";
        } else if (value instanceof Long) {
            // "i64" is an MSVC-only suffix that GCC and Clang reject. "LL" is standard C99/C++11
            // and is also accepted by MSVC.
            return value + "LL";
        } else if (value instanceof Double) {
            return value.toString();
        } else {
            return value + "L";
        }
    }

    /**
     * Writes the descriptive comment and the JNI function prototype for one native method.
     *
     * @param mangledClassName the class name already encoded for use in C identifiers
     * @param method           the native method to declare
     */
    private void writeNativeMethod(String mangledClassName, NativeMethod method) {
        String mm = "Java_" + mangledClassName + "_" + Utils.encode(method.name);
        // Overloaded native methods need the long JNI name: the short name plus "__" and the
        // mangled argument signature.
        if (methodNameCount.getOrDefault(method.name, 1) > 1) {
            mm = mm + "__" + Utils.mangledArgSignature(method.type);
        }
        String r = Utils.mapToNativeType(method.type.getReturnType());
        // Every JNI function receives the env pointer, then the class (static) or the receiver
        // (instance), followed by the Java arguments.
        String args = Stream.concat(Stream.of("JNIEnv *", method.isStatic ? "jclass" : "jobject"),
                Arrays.stream(method.type.getArgumentTypes()).map(Utils::mapToNativeType))
                .collect(Collectors.joining(","));

        output.println("/*");
        output.println(" * Class:     " + mangledClassName);
        output.println(" * Method:    " + mm);
        output.println(" * Signature: " + method.type);
        output.println(" */");
        output.println("JNIEXPORT " + r + " JNICALL " + mm);
        output.println("  (" + args + ");");
        output.println();
    }

    /** Records the internal name of the class being visited. */
    @Override
    public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        className = name;
    }

    /** Collects native methods and counts them per name so overloads can be detected. */
    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        if ((access & Opcodes.ACC_NATIVE) != 0) {
            methodNameCount.merge(name, 1, Integer::sum);
            methods.add(new NativeMethod(name, Type.getType(descriptor), (access & Opcodes.ACC_STATIC) != 0));
        }
        return null;
    }

    /**
     * Collects {@code static final} fields that carry a primitive constant value.
     *
     * <p>String constants are skipped. Non-static fields are also skipped: the JVM ignores a
     * ConstantValue attribute on an instance field, and javah does not emit them either.
     */
    @Override
    public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
        int staticFinal = Opcodes.ACC_STATIC | Opcodes.ACC_FINAL;
        if ((access & staticFinal) == staticFinal && value != null && !(value instanceof String)) {
            constants.add(new ConstantField(name, Type.getType(descriptor), value));
        }
        return null;
    }

    /** A native method: simple name, method type (descriptor) and whether it is static. */
    private static final class NativeMethod {
        final String name;
        final Type type;
        final boolean isStatic;

        NativeMethod(String name, Type type, boolean isStatic) {
            this.name = name;
            this.type = type;
            this.isStatic = isStatic;
        }
    }

    /** A constant field: simple name, field type and its ConstantValue as reported by ASM. */
    private static final class ConstantField {
        final String name;
        final Type type;
        final Object value;

        ConstantField(String name, Type type, Object value) {
            this.name = name;
            this.type = type;
            this.value = value;
        }
    }
}