diff options
Diffstat (limited to 'test/junit/scala/tools/testing/BytecodeTesting.scala')
-rw-r--r-- | test/junit/scala/tools/testing/BytecodeTesting.scala | 312 |
1 files changed, 312 insertions, 0 deletions
diff --git a/test/junit/scala/tools/testing/BytecodeTesting.scala b/test/junit/scala/tools/testing/BytecodeTesting.scala new file mode 100644 index 0000000000..c0fdb8010f --- /dev/null +++ b/test/junit/scala/tools/testing/BytecodeTesting.scala @@ -0,0 +1,312 @@ +package scala.tools.testing + +import junit.framework.AssertionFailedError +import org.junit.Assert._ + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer +import scala.reflect.internal.util.BatchSourceFile +import scala.reflect.io.VirtualDirectory +import scala.tools.asm.Opcodes +import scala.tools.asm.tree.{AbstractInsnNode, ClassNode, MethodNode} +import scala.tools.cmd.CommandLineParser +import scala.tools.nsc.backend.jvm.AsmUtils +import scala.tools.nsc.backend.jvm.AsmUtils._ +import scala.tools.nsc.backend.jvm.opt.BytecodeUtils +import scala.tools.nsc.io.AbstractFile +import scala.tools.nsc.reporters.StoreReporter +import scala.tools.nsc.{Global, Settings} +import scala.tools.partest.ASMConverters._ + +trait BytecodeTesting extends ClearAfterClass { + def compilerArgs = "" // to be overridden + val compiler = cached("compiler", () => BytecodeTesting.newCompiler(extraArgs = compilerArgs)) +} + +class Compiler(val global: Global) { + import BytecodeTesting._ + + def resetOutput(): Unit = { + global.settings.outputDirs.setSingleOutput(new VirtualDirectory("(memory)", None)) + } + + def newRun: global.Run = { + global.reporter.reset() + resetOutput() + new global.Run() + } + + private def reporter = global.reporter.asInstanceOf[StoreReporter] + + def checkReport(allowMessage: StoreReporter#Info => Boolean = _ => false): Unit = { + val disallowed = reporter.infos.toList.filter(!allowMessage(_)) // toList prevents an infer-non-wildcard-existential warning. + if (disallowed.nonEmpty) { + val msg = disallowed.mkString("\n") + assert(false, "The compiler issued non-allowed warnings or errors:\n" + msg) + } + } + + def compileToBytes(scalaCode: String, javaCode: List[(String, String)] = Nil, allowMessage: StoreReporter#Info => Boolean = _ => false): List[(String, Array[Byte])] = { + val run = newRun + run.compileSources(makeSourceFile(scalaCode, "unitTestSource.scala") :: javaCode.map(p => makeSourceFile(p._1, p._2))) + checkReport(allowMessage) + getGeneratedClassfiles(global.settings.outputDirs.getSingleOutput.get) + } + + def compileClasses(code: String, javaCode: List[(String, String)] = Nil, allowMessage: StoreReporter#Info => Boolean = _ => false): List[ClassNode] = { + readAsmClasses(compileToBytes(code, javaCode, allowMessage)) + } + + def compileClass(code: String, javaCode: List[(String, String)] = Nil, allowMessage: StoreReporter#Info => Boolean = _ => false): ClassNode = { + val List(c) = compileClasses(code, javaCode, allowMessage) + c + } + + def compileToBytesTransformed(scalaCode: String, javaCode: List[(String, String)] = Nil, beforeBackend: global.Tree => global.Tree): List[(String, Array[Byte])] = { + import global._ + settings.stopBefore.value = "jvm" :: Nil + val run = newRun + val scalaUnit = newCompilationUnit(scalaCode, "unitTestSource.scala") + val javaUnits = javaCode.map(p => newCompilationUnit(p._1, p._2)) + val units = scalaUnit :: javaUnits + run.compileUnits(units, run.parserPhase) + settings.stopBefore.value = Nil + scalaUnit.body = beforeBackend(scalaUnit.body) + checkReport(_ => false) + val run1 = newRun + run1.compileUnits(units, run1.phaseNamed("jvm")) + checkReport(_ => false) + getGeneratedClassfiles(settings.outputDirs.getSingleOutput.get) + } + + def compileClassesTransformed(scalaCode: String, javaCode: List[(String, String)] = Nil, beforeBackend: global.Tree => global.Tree): List[ClassNode] = + readAsmClasses(compileToBytesTransformed(scalaCode, javaCode, beforeBackend)) + + def compileAsmMethods(code: String, allowMessage: StoreReporter#Info => Boolean = _ => false): List[MethodNode] = { + val c = compileClass(s"class C { $code }", allowMessage = allowMessage) + getAsmMethods(c, _ != "<init>") + } + + def compileAsmMethod(code: String, allowMessage: StoreReporter#Info => Boolean = _ => false): MethodNode = { + val List(m) = compileAsmMethods(code, allowMessage) + m + } + + def compileMethods(code: String, allowMessage: StoreReporter#Info => Boolean = _ => false): List[Method] = + compileAsmMethods(code, allowMessage).map(convertMethod) + + def compileMethod(code: String, allowMessage: StoreReporter#Info => Boolean = _ => false): Method = { + val List(m) = compileMethods(code, allowMessage = allowMessage) + m + } + + def compileInstructions(code: String, allowMessage: StoreReporter#Info => Boolean = _ => false): List[Instruction] = { + val List(m) = compileMethods(code, allowMessage = allowMessage) + m.instructions + } +} + +object BytecodeTesting { + def genMethod(flags: Int = Opcodes.ACC_PUBLIC, + name: String = "m", + descriptor: String = "()V", + genericSignature: String = null, + throwsExceptions: Array[String] = null, + handlers: List[ExceptionHandler] = Nil, + localVars: List[LocalVariable] = Nil)(body: Instruction*): MethodNode = { + val node = new MethodNode(flags, name, descriptor, genericSignature, throwsExceptions) + applyToMethod(node, Method(body.toList, handlers, localVars)) + node + } + + def wrapInClass(method: MethodNode): ClassNode = { + val cls = new ClassNode() + cls.visit(Opcodes.V1_6, Opcodes.ACC_PUBLIC, "C", null, "java/lang/Object", null) + cls.methods.add(method) + cls + } + + def newCompiler(defaultArgs: String = "-usejavacp", extraArgs: String = ""): Compiler = { + val compiler = newCompilerWithoutVirtualOutdir(defaultArgs, extraArgs) + compiler.resetOutput() + compiler + } + + def newCompilerWithoutVirtualOutdir(defaultArgs: String = "-usejavacp", extraArgs: String = ""): Compiler = { + def showError(s: String) = throw new Exception(s) + val settings = new Settings(showError) + val args = (CommandLineParser tokenize defaultArgs) ++ (CommandLineParser tokenize extraArgs) + val (_, nonSettingsArgs) = settings.processArguments(args, processAll = true) + if (nonSettingsArgs.nonEmpty) showError("invalid compiler flags: " + nonSettingsArgs.mkString(" ")) + new Compiler(new Global(settings, new StoreReporter)) + } + + def makeSourceFile(code: String, filename: String): BatchSourceFile = new BatchSourceFile(filename, code) + + def getGeneratedClassfiles(outDir: AbstractFile): List[(String, Array[Byte])] = { + def files(dir: AbstractFile): List[(String, Array[Byte])] = { + val res = ListBuffer.empty[(String, Array[Byte])] + for (f <- dir.iterator) { + if (!f.isDirectory) res += ((f.name, f.toByteArray)) + else if (f.name != "." && f.name != "..") res ++= files(f) + } + res.toList + } + files(outDir) + } + + /** + * Compile multiple Scala files separately into a single output directory. + * + * Note that a new compiler instance is created for compiling each file because symbols survive + * across runs. This makes separate compilation slower. + * + * The output directory is a physical directory, I have not figured out if / how it's possible to + * add a VirtualDirectory to the classpath of a compiler. + */ + def compileToBytesSeparately(codes: List[String], extraArgs: String = "", allowMessage: StoreReporter#Info => Boolean = _ => false, afterEach: AbstractFile => Unit = _ => ()): List[(String, Array[Byte])] = { + val outDir = AbstractFile.getDirectory(TempDir.createTempDir()) + val outDirPath = outDir.canonicalPath + val argsWithOutDir = extraArgs + s" -d $outDirPath -cp $outDirPath" + + for (code <- codes) { + val compiler = newCompilerWithoutVirtualOutdir(extraArgs = argsWithOutDir) + new compiler.global.Run().compileSources(List(makeSourceFile(code, "unitTestSource.scala"))) + compiler.checkReport(allowMessage) + afterEach(outDir) + } + + val classfiles = getGeneratedClassfiles(outDir) + outDir.delete() + classfiles + } + + def compileClassesSeparately(codes: List[String], extraArgs: String = "", allowMessage: StoreReporter#Info => Boolean = _ => false, afterEach: AbstractFile => Unit = _ => ()): List[ClassNode] = { + readAsmClasses(compileToBytesSeparately(codes, extraArgs, allowMessage, afterEach)) + } + + def readAsmClasses(classfiles: List[(String, Array[Byte])]) = classfiles.map(p => AsmUtils.readClass(p._2)).sortBy(_.name) + + def assertSameCode(method: Method, expected: List[Instruction]): Unit = assertSameCode(method.instructions.dropNonOp, expected) + def assertSameCode(actual: List[Instruction], expected: List[Instruction]): Unit = { + assert(actual === expected, s"\nExpected: $expected\nActual : $actual") + } + + def assertSameSummary(method: Method, expected: List[Any]): Unit = assertSameSummary(method.instructions, expected) + def assertSameSummary(actual: List[Instruction], expected: List[Any]): Unit = { + def expectedString = expected.map({ + case s: String => s""""$s"""" + case i: Int => opcodeToString(i, i) + }).mkString("List(", ", ", ")") + assert(actual.summary == expected, s"\nFound : ${actual.summaryText}\nExpected: $expectedString") + } + + def assertNoInvoke(m: Method): Unit = assertNoInvoke(m.instructions) + def assertNoInvoke(ins: List[Instruction]): Unit = { + assert(!ins.exists(_.isInstanceOf[Invoke]), ins.stringLines) + } + + def assertInvoke(m: Method, receiver: String, method: String): Unit = assertInvoke(m.instructions, receiver, method) + def assertInvoke(l: List[Instruction], receiver: String, method: String): Unit = { + assert(l.exists { + case Invoke(_, `receiver`, `method`, _, _) => true + case _ => false + }, l.stringLines) + } + + def assertDoesNotInvoke(m: Method, method: String): Unit = assertDoesNotInvoke(m.instructions, method) + def assertDoesNotInvoke(l: List[Instruction], method: String): Unit = { + assert(!l.exists { + case i: Invoke => i.name == method + case _ => false + }, l.stringLines) + } + + def assertInvokedMethods(m: Method, expected: List[String]): Unit = assertInvokedMethods(m.instructions, expected) + def assertInvokedMethods(l: List[Instruction], expected: List[String]): Unit = { + def quote(l: List[String]) = l.map(s => s""""$s"""").mkString("List(", ", ", ")") + val actual = l collect { case i: Invoke => i.owner + "." + i.name } + assert(actual == expected, s"\nFound : ${quote(actual)}\nExpected: ${quote(expected)}") + } + + def assertNoIndy(m: Method): Unit = assertNoIndy(m.instructions) + def assertNoIndy(l: List[Instruction]) = { + val indy = l collect { case i: InvokeDynamic => i } + assert(indy.isEmpty, indy) + } + + def findClass(cs: List[ClassNode], name: String): ClassNode = { + val List(c) = cs.filter(_.name == name) + c + } + + def getAsmMethods(c: ClassNode, p: String => Boolean): List[MethodNode] = + c.methods.iterator.asScala.filter(m => p(m.name)).toList.sortBy(_.name) + + def getAsmMethods(c: ClassNode, name: String): List[MethodNode] = + getAsmMethods(c, _ == name) + + def getAsmMethod(c: ClassNode, name: String): MethodNode = { + val methods = getAsmMethods(c, name) + def fail() = { + val allNames = getAsmMethods(c, _ => true).map(_.name) + throw new AssertionFailedError(s"Could not find method named $name among ${allNames}") + } + methods match { + case List(m) => m + case ms @ List(m1, m2) if BytecodeUtils.isInterface(c) => + val (statics, nonStatics) = ms.partition(BytecodeUtils.isStaticMethod) + (statics, nonStatics) match { + case (List(staticMethod), List(_)) => m1 // prefer the static method of the pair if methods in traits + case _ => fail() + } + case ms => fail() + } + } + + def getMethods(c: ClassNode, name: String): List[Method] = + getAsmMethods(c, name).map(convertMethod) + + def getMethod(c: ClassNode, name: String): Method = + convertMethod(getAsmMethod(c, name)) + + def getInstructions(c: ClassNode, name: String): List[Instruction] = + getMethod(c, name).instructions + + /** + * Instructions that match `query` when textified. + * If `query` starts with a `+`, the next instruction is returned. + */ + def findInstrs(method: MethodNode, query: String): List[AbstractInsnNode] = { + val useNext = query(0) == '+' + val instrPart = if (useNext) query.drop(1) else query + val insns = method.instructions.iterator.asScala.filter(i => textify(i) contains instrPart).toList + if (useNext) insns.map(_.getNext) else insns + } + + /** + * Instruction that matches `query` when textified. + * If `query` starts with a `+`, the next instruction is returned. + */ + def findInstr(method: MethodNode, query: String): AbstractInsnNode = { + val List(i) = findInstrs(method, query) + i + } + + def assertHandlerLabelPostions(h: ExceptionHandler, instructions: List[Instruction], startIndex: Int, endIndex: Int, handlerIndex: Int): Unit = { + val insVec = instructions.toVector + assertTrue(h.start == insVec(startIndex) && h.end == insVec(endIndex) && h.handler == insVec(handlerIndex)) + } + + import scala.language.implicitConversions + + implicit def aliveInstruction(ins: Instruction): (Instruction, Boolean) = (ins, true) + + implicit class MortalInstruction(val ins: Instruction) extends AnyVal { + def dead: (Instruction, Boolean) = (ins, false) + } + + implicit class listStringLines[T](val l: List[T]) extends AnyVal { + def stringLines = l.mkString("\n") + } +} |