summaryrefslogtreecommitdiff
path: root/src/scalap/scala/tools/scalap/scalasig/ClassFileParser.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/scalap/scala/tools/scalap/scalasig/ClassFileParser.scala')
-rw-r--r--src/scalap/scala/tools/scalap/scalasig/ClassFileParser.scala240
1 files changed, 240 insertions, 0 deletions
diff --git a/src/scalap/scala/tools/scalap/scalasig/ClassFileParser.scala b/src/scalap/scala/tools/scalap/scalasig/ClassFileParser.scala
new file mode 100644
index 0000000000..ed438be7f2
--- /dev/null
+++ b/src/scalap/scala/tools/scalap/scalasig/ClassFileParser.scala
@@ -0,0 +1,240 @@
+package scala.tools.scalap.scalasig
+
+import scala.tools.scalap.rules.{ Success, Failure, ~, RulesWithState }
+
+object ByteCode {
+ def apply(bytes : Array[Byte]) = new ByteCode(bytes, 0, bytes.length)
+
+ def forClass(clazz : Class[_]) = {
+ val name = clazz.getName
+ val subPath = name.substring(name.lastIndexOf('.') + 1) + ".class"
+ val in = clazz.getResourceAsStream(subPath)
+
+ try {
+ var rest = in.available()
+ val bytes = new Array[Byte](rest)
+ while (rest > 0) {
+ val res = in.read(bytes, bytes.length - rest, rest)
+ if (res == -1) throw new java.io.IOException("read error")
+ rest -= res
+ }
+ ByteCode(bytes)
+
+ } finally {
+ in.close()
+ }
+ }
+}
+
+/** Represents a chunk of raw bytecode. Used as input for the parsers. */
+class ByteCode(val bytes : Array[Byte], val pos : Int, val length : Int) {
+
+ assert(pos >= 0 && length >= 0 && pos + length <= bytes.length)
+
+ def nextByte = if (length == 0) Failure else Success(drop(1), bytes(pos))
+ def next(n : Int) = if (length >= n) Success(drop(n), take(n)) else Failure
+
+ def take(n : Int) = new ByteCode(bytes, pos, n)
+ def drop(n : Int) = new ByteCode(bytes, pos + n, length - n)
+
+ def fold[X](x : X)(f : (X, Byte) => X) : X = {
+ var result = x
+ var i = pos
+ while (i < pos + length) {
+ result = f(result, bytes(i))
+ i += 1
+ }
+ result
+ }
+
+ override def toString = length + " bytes"
+
+ def toInt = fold(0) { (x, b) => (x << 8) + (b & 0xFF)}
+ def toLong = fold(0L) { (x, b) => (x << 8) + (b & 0xFF)}
+
+ /**
+ * Transforms array subsequence of the current buffer into the UTF8 String and
+ * stores and array of bytes for the decompiler
+ */
+ def fromUTF8StringAndBytes = {
+ val chunk: Array[Byte] = new Array[Byte](length)
+ System.arraycopy(bytes, pos, chunk, 0, length)
+ val str = new String(io.Codec.fromUTF8(bytes, pos, length))
+
+ StringBytesPair(str, chunk)
+ }
+
+ def byte(i : Int) = bytes(pos) & 0xFF
+}
+
+/**
+ * The wrapper for decode UTF-8 string
+ */
+case class StringBytesPair(string: String, bytes: Array[Byte])
+
+/** Provides rules for parsing byte-code.
+*/
+trait ByteCodeReader extends RulesWithState {
+ type S = ByteCode
+ type Parser[A] = Rule[A, String]
+
+ val byte = apply(_.nextByte)
+
+ val u1 = byte ^^ (_ & 0xFF)
+ val u2 = bytes(2) ^^ (_.toInt)
+ val u4 = bytes(4) ^^ (_.toInt) // should map to Long??
+
+ def bytes(n : Int) = apply(_ next n)
+}
+
+object ClassFileParser extends ByteCodeReader {
+ def parse(byteCode : ByteCode) = expect(classFile)(byteCode)
+ def parseAnnotations(byteCode: ByteCode) = expect(annotations)(byteCode)
+
+ val magicNumber = (u4 filter (_ == 0xCAFEBABE)) | error("Not a valid class file")
+ val version = u2 ~ u2 ^^ { case minor ~ major => (major, minor) }
+ val constantPool = (u2 ^^ ConstantPool) >> repeatUntil(constantPoolEntry)(_.isFull)
+
+ // NOTE currently most constants just evaluate to a string description
+ // TODO evaluate to useful values
+ val utf8String = (u2 >> bytes) ^^ add1 { raw => pool => raw.fromUTF8StringAndBytes }
+ val intConstant = u4 ^^ add1 { x => pool => x }
+ val floatConstant = bytes(4) ^^ add1 { raw => pool => "Float: TODO" }
+ val longConstant = bytes(8) ^^ add2 { raw => pool => raw.toLong }
+ val doubleConstant = bytes(8) ^^ add2 { raw => pool => "Double: TODO" }
+ val classRef = u2 ^^ add1 { x => pool => "Class: " + pool(x) }
+ val stringRef = u2 ^^ add1 { x => pool => "String: " + pool(x) }
+ val fieldRef = memberRef("Field")
+ val methodRef = memberRef("Method")
+ val interfaceMethodRef = memberRef("InterfaceMethod")
+ val nameAndType = u2 ~ u2 ^^ add1 { case name ~ descriptor => pool => "NameAndType: " + pool(name) + ", " + pool(descriptor) }
+
+ val constantPoolEntry = u1 >> {
+ case 1 => utf8String
+ case 3 => intConstant
+ case 4 => floatConstant
+ case 5 => longConstant
+ case 6 => doubleConstant
+ case 7 => classRef
+ case 8 => stringRef
+ case 9 => fieldRef
+ case 10 => methodRef
+ case 11 => interfaceMethodRef
+ case 12 => nameAndType
+ }
+
+ val interfaces = u2 >> u2.times
+
+ // bytes are parametrizes by the length, declared in u4 section
+ val attribute = u2 ~ (u4 >> bytes) ^~^ Attribute
+ // parse attributes u2 times
+ val attributes = u2 >> attribute.times
+
+ // parse runtime-visible annotations
+ abstract class ElementValue
+ case class AnnotationElement(elementNameIndex: Int, elementValue: ElementValue)
+ case class ConstValueIndex(index: Int) extends ElementValue
+ case class EnumConstValue(typeNameIndex: Int, constNameIndex: Int) extends ElementValue
+ case class ClassInfoIndex(index: Int) extends ElementValue
+ case class Annotation(typeIndex: Int, elementValuePairs: Seq[AnnotationElement]) extends ElementValue
+ case class ArrayValue(values: Seq[ElementValue]) extends ElementValue
+
+ def element_value: Parser[ElementValue] = u1 >> {
+ case 'B'|'C'|'D'|'F'|'I'|'J'|'S'|'Z'|'s' => u2 ^^ ConstValueIndex
+ case 'e' => u2 ~ u2 ^~^ EnumConstValue
+ case 'c' => u2 ^^ ClassInfoIndex
+ case '@' => annotation //nested annotation
+ case '[' => u2 >> element_value.times ^^ ArrayValue
+ }
+
+ val element_value_pair = u2 ~ element_value ^~^ AnnotationElement
+ val annotation: Parser[Annotation] = u2 ~ (u2 >> element_value_pair.times) ^~^ Annotation
+ val annotations = u2 >> annotation.times
+
+ val field = u2 ~ u2 ~ u2 ~ attributes ^~~~^ Field
+ val fields = u2 >> field.times
+
+ val method = u2 ~ u2 ~ u2 ~ attributes ^~~~^ Method
+ val methods = u2 >> method.times
+
+ val header = magicNumber -~ u2 ~ u2 ~ constantPool ~ u2 ~ u2 ~ u2 ~ interfaces ^~~~~~~^ ClassFileHeader
+ val classFile = header ~ fields ~ methods ~ attributes ~- !u1 ^~~~^ ClassFile
+
+ // TODO create a useful object, not just a string
+ def memberRef(description : String) = u2 ~ u2 ^^ add1 {
+ case classRef ~ nameAndTypeRef => pool => description + ": " + pool(classRef) + ", " + pool(nameAndTypeRef)
+ }
+
+ def add1[T](f : T => ConstantPool => Any)(raw : T)(pool : ConstantPool) = pool add f(raw)
+ def add2[T](f : T => ConstantPool => Any)(raw : T)(pool : ConstantPool) = pool add f(raw) add { pool => "<empty>" }
+}
+
+case class ClassFile(
+ header : ClassFileHeader,
+ fields : Seq[Field],
+ methods : Seq[Method],
+ attributes : Seq[Attribute]) {
+
+ def majorVersion = header.major
+ def minorVersion = header.minor
+
+ def className = constant(header.classIndex)
+ def superClass = constant(header.superClassIndex)
+ def interfaces = header.interfaces.map(constant)
+
+ def constant(index : Int) = header.constants(index) match {
+ case StringBytesPair(str, _) => str
+ case z => z
+ }
+
+ def constantWrapped(index: Int) = header.constants(index)
+
+ def attribute(name : String) = attributes.find {attrib => constant(attrib.nameIndex) == name }
+
+ val RUNTIME_VISIBLE_ANNOTATIONS = "RuntimeVisibleAnnotations"
+ def annotations = (attributes.find(attr => constant(attr.nameIndex) == RUNTIME_VISIBLE_ANNOTATIONS)
+ .map(attr => ClassFileParser.parseAnnotations(attr.byteCode)))
+
+ def annotation(name: String) = annotations.flatMap(seq => seq.find(annot => constant(annot.typeIndex) == name))
+}
+
+case class Attribute(nameIndex : Int, byteCode : ByteCode)
+case class Field(flags : Int, nameIndex : Int, descriptorIndex : Int, attributes : Seq[Attribute])
+case class Method(flags : Int, nameIndex : Int, descriptorIndex : Int, attributes : Seq[Attribute])
+
+case class ClassFileHeader(
+ minor : Int,
+ major : Int,
+ constants : ConstantPool,
+ flags : Int,
+ classIndex : Int,
+ superClassIndex : Int,
+ interfaces : Seq[Int]) {
+
+ def constant(index : Int) = constants(index)
+}
+
+case class ConstantPool(len : Int) {
+ val size = len - 1
+
+ private val buffer = new scala.collection.mutable.ArrayBuffer[ConstantPool => Any]
+ private val values = Array.fill[Option[Any]](size)(None)
+
+ def isFull = buffer.length >= size
+
+ def apply(index : Int) = {
+ // Note constant pool indices are 1-based
+ val i = index - 1
+ values(i) getOrElse {
+ val value = buffer(i)(this)
+ buffer(i) = null
+ values(i) = Some(value)
+ value
+ }
+ }
+
+ def add(f : ConstantPool => Any) = {
+ buffer += f
+ this
+ }
+}