diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2014-09-16 14:57:23 +1000 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2014-09-16 14:57:23 +1000 |
commit | 0e2be38f05a6d3fadd0a6c800604503c72401117 (patch) | |
tree | a9b836b0ec7048d6182448d1157f99711703bf05 | |
parent | 1b9806171940d304b41442b788717d2425764cbf (diff) | |
parent | 9132efa4a8511e267c808c95df4d2e3de68277e6 (diff) | |
download | scala-0e2be38f05a6d3fadd0a6c800604503c72401117.tar.gz scala-0e2be38f05a6d3fadd0a6c800604503c72401117.tar.bz2 scala-0e2be38f05a6d3fadd0a6c800604503c72401117.zip |
Merge pull request #3971 from lrytz/opt/dce
GenBCode: eliminate unreachable code
29 files changed, 1303 insertions, 88 deletions
@@ -970,6 +970,7 @@ TODO: <pathelement location="${test.junit.classes}"/> <path refid="quick.compiler.build.path"/> <path refid="quick.repl.build.path"/> + <path refid="quick.partest-extras.build.path"/> <path refid="junit.classpath"/> </path> diff --git a/src/asm/scala/tools/asm/MethodWriter.java b/src/asm/scala/tools/asm/MethodWriter.java index 0c4130e499..d30e04c625 100644 --- a/src/asm/scala/tools/asm/MethodWriter.java +++ b/src/asm/scala/tools/asm/MethodWriter.java @@ -37,7 +37,7 @@ package scala.tools.asm; * @author Eric Bruneton * @author Eugene Kuleshov */ -class MethodWriter extends MethodVisitor { +public class MethodWriter extends MethodVisitor { /** * Pseudo access flag used to denote constructors. @@ -235,11 +235,19 @@ class MethodWriter extends MethodVisitor { */ private int maxStack; + public int getMaxStack() { + return maxStack; + } + /** * Maximum number of local variables for this method. */ private int maxLocals; + public int getMaxLocals() { + return maxLocals; + } + /** * Number of local variables in the current stack map frame. */ diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala index 397171049f..daf36ce374 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala @@ -9,7 +9,6 @@ package tools.nsc package backend package jvm -import scala.collection.{ mutable, immutable } import scala.annotation.switch import scala.tools.asm @@ -283,9 +282,10 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { val Local(tk, _, idx, isSynth) = locals.getOrMakeLocal(sym) if (rhs == EmptyTree) { emitZeroOf(tk) } else { genLoad(rhs, tk) } + val localVarStart = currProgramPoint() bc.store(idx, tk) if (!isSynth) { // there are case <synthetic> ValDef's emitted by patmat - varsInScope ::= (sym -> currProgramPoint()) + varsInScope ::= (sym -> localVarStart) } generatedType = UNIT @@ -815,7 +815,51 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { case _ => bc.emitT2T(from, to) } } else if (from.isNothingType) { - emit(asm.Opcodes.ATHROW) // ICode enters here into enterIgnoreMode, we'll rely instead on DCE at ClassNode level. + /* There are two possibilities for from.isNothingType: emitting a "throw e" expressions and + * loading a (phantom) value of type Nothing. + * + * The Nothing type in Scala's type system does not exist in the JVM. In bytecode, Nothing + * is mapped to scala.runtime.Nothing$. To the JVM, a call to Predef.??? looks like it would + * return an object of type Nothing$. We need to do something with that phantom object on + * the stack. "Phantom" because it never exists: such methods always throw, but the JVM does + * not know that. + * + * Note: The two verifiers (old: type inference, new: type checking) have different + * requirements. Very briefly: + * + * Old (http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.10.2.1): at + * each program point, no matter what branches were taken to get there + * - Stack is same size and has same typed values + * - Local and stack values need to have consistent types + * - In practice, the old verifier seems to ignore unreachable code and accept any + * instructions after an ATHROW. For example, there can be another ATHROW (without + * loading another throwable first). + * + * New (http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.10.1) + * - Requires consistent stack map frames. GenBCode generates stack frames if -target:jvm-1.6 + * or higher. + * - In practice: the ASM library computes stack map frames for us (ClassWriter). Emitting + * correct frames after an ATHROW is probably complex, so ASM uses the following strategy: + * - Every time when generating an ATHROW, a new basic block is started. + * - During classfile writing, such basic blocks are found to be dead: no branches go there + * - Eliminating dead code would probably require complex shifts in the output byte buffer + * - But there's an easy solution: replace all code in the dead block with with + * `nop; nop; ... nop; athrow`, making sure the bytecode size stays the same + * - The corresponding stack frame can be easily generated: on entering a dead the block, + * the frame requires a single Throwable on the stack. + * - Since there are no branches to the dead block, the frame requirements are never violated. + * + * To summarize the above: it does matter what we emit after an ATHROW. + * + * NOW: if we end up here because we emitted a load of a (phantom) value of type Nothing$, + * there was no ATHROW emitted. So, we have to make the verifier happy and do something + * with that value. Since Nothing$ extends Throwable, the easiest is to just emit an ATHROW. + * + * If we ended up here because we generated a "throw e" expression, we know the last + * emitted instruction was an ATHROW. As explained above, it is OK to emit a second ATHROW, + * the verifiers will be happy. + */ + emit(asm.Opcodes.ATHROW) } else if (from.isNullType) { bc drop from emit(asm.Opcodes.ACONST_NULL) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala index 5670715cd3..14bffd67ab 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeHelpers.scala @@ -782,14 +782,9 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters { } } // end of trait BCClassGen - /* basic functionality for class file building of plain, mirror, and beaninfo classes. */ - abstract class JBuilder extends BCInnerClassGen { - - } // end of class JBuilder - /* functionality for building plain and mirror classes */ abstract class JCommonBuilder - extends JBuilder + extends BCInnerClassGen with BCAnnotGen with BCForwardersGen with BCPickles { } @@ -851,7 +846,7 @@ abstract class BCodeHelpers extends BCodeIdiomatic with BytecodeWriters { } // end of class JMirrorBuilder /* builder of bean info classes */ - class JBeanInfoBuilder extends JBuilder { + class JBeanInfoBuilder extends BCInnerClassGen { /* * Generate a bean info class that describes the given class. diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala index 4592031a31..03bc32061b 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeSkelBuilder.scala @@ -346,6 +346,13 @@ abstract class BCodeSkelBuilder extends BCodeHelpers { /* * Bookkeeping for method-local vars and method-params. + * + * TODO: use fewer slots. local variable slots are never re-used in separate blocks. + * In the following example, x and y could use the same slot. + * def foo() = { + * { val x = 1 } + * { val y = "a" } + * } */ object locals { diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala index 7bf61b4f51..53ac5bfdc7 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala @@ -713,6 +713,8 @@ abstract class BTypes { // TODO @lry I don't really understand the reasoning here. // Both this and other are classes. The code takes (transitively) all superclasses and // finds the first common one. + // MOST LIKELY the answer can be found here, see the comments and links by Miguel: + // - https://issues.scala-lang.org/browse/SI-3872 firstCommonSuffix(this :: this.superClassesTransitive, other :: other.superClassesTransitive) } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BackendStats.scala b/src/compiler/scala/tools/nsc/backend/jvm/BackendStats.scala new file mode 100644 index 0000000000..4b9383c67c --- /dev/null +++ b/src/compiler/scala/tools/nsc/backend/jvm/BackendStats.scala @@ -0,0 +1,24 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2014 LAMP/EPFL + * @author Martin Odersky + */ + +package scala.tools.nsc +package backend.jvm + +import scala.reflect.internal.util.Statistics + +object BackendStats { + import Statistics.{newTimer, newSubTimer} + val bcodeTimer = newTimer("time in backend", "jvm") + + val bcodeInitTimer = newSubTimer("bcode initialization", bcodeTimer) + val bcodeGenStat = newSubTimer("code generation", bcodeTimer) + val bcodeDceTimer = newSubTimer("dead code elimination", bcodeTimer) + val bcodeWriteTimer = newSubTimer("classfile writing", bcodeTimer) + + def timed[T](timer: Statistics.Timer)(body: => T): T = { + val start = Statistics.startTimer(timer) + try body finally Statistics.stopTimer(timer, start) + } +} diff --git a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala index 0a7c894a69..ba94a9c44c 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/GenBCode.scala @@ -11,8 +11,10 @@ package jvm import scala.collection.{ mutable, immutable } import scala.annotation.switch +import scala.reflect.internal.util.Statistics import scala.tools.asm +import scala.tools.asm.tree.ClassNode /* * Prepare in-memory representations of classfiles using the ASM Tree API, and serialize them to disk. @@ -213,6 +215,14 @@ abstract class GenBCode extends BCodeSyncAndTry { * - converting the plain ClassNode to byte array and placing it on queue-3 */ class Worker2 { + def localOptimizations(classNode: ClassNode): Unit = { + def dce(): Boolean = BackendStats.timed(BackendStats.bcodeDceTimer) { + if (settings.YoptUnreachableCode) opt.LocalOpt.removeUnreachableCode(classNode) + else false + } + + dce() + } def run() { while (true) { @@ -222,8 +232,10 @@ abstract class GenBCode extends BCodeSyncAndTry { return } else { - try { addToQ3(item) } - catch { + try { + localOptimizations(item.plain) + addToQ3(item) + } catch { case ex: Throwable => ex.printStackTrace() error(s"Error while emitting ${item.plain.name}\n${ex.getMessage}") @@ -272,10 +284,13 @@ abstract class GenBCode extends BCodeSyncAndTry { * */ override def run() { + val bcodeStart = Statistics.startTimer(BackendStats.bcodeTimer) + val initStart = Statistics.startTimer(BackendStats.bcodeInitTimer) arrivalPos = 0 // just in case scalaPrimitives.init() bTypes.intializeCoreBTypes() + Statistics.stopTimer(BackendStats.bcodeInitTimer, initStart) // initBytecodeWriter invokes fullName, thus we have to run it before the typer-dependent thread is activated. bytecodeWriter = initBytecodeWriter(cleanup.getEntryPoints) @@ -287,6 +302,7 @@ abstract class GenBCode extends BCodeSyncAndTry { // closing output files. bytecodeWriter.close() + Statistics.stopTimer(BackendStats.bcodeTimer, bcodeStart) /* TODO Bytecode can be verified (now that all classfiles have been written to disk) * @@ -312,9 +328,15 @@ abstract class GenBCode extends BCodeSyncAndTry { private def buildAndSendToDisk(needsOutFolder: Boolean) { feedPipeline1() + val genStart = Statistics.startTimer(BackendStats.bcodeGenStat) (new Worker1(needsOutFolder)).run() + Statistics.stopTimer(BackendStats.bcodeGenStat, genStart) + (new Worker2).run() + + val writeStart = Statistics.startTimer(BackendStats.bcodeWriteTimer) drainQ3() + Statistics.stopTimer(BackendStats.bcodeWriteTimer, writeStart) } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/LocalOpt.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/LocalOpt.scala new file mode 100644 index 0000000000..3acd2d6154 --- /dev/null +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/LocalOpt.scala @@ -0,0 +1,190 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2014 LAMP/EPFL + * @author Martin Odersky + */ + +package scala.tools.nsc +package backend.jvm +package opt + +import scala.tools.asm.{Opcodes, MethodWriter, ClassWriter} +import scala.tools.asm.tree.analysis.{Analyzer, BasicValue, BasicInterpreter} +import scala.tools.asm.tree._ +import scala.collection.convert.decorateAsScala._ +import scala.collection.{ mutable => m } + +/** + * Intra-Method optimizations. + */ +object LocalOpt { + /** + * Remove unreachable instructions from all (non-abstract) methods. + * + * @param clazz The class whose methods are optimized + * @return `true` if unreachable code was elminated in some method, `false` otherwise. + */ + def removeUnreachableCode(clazz: ClassNode): Boolean = { + clazz.methods.asScala.foldLeft(false) { + case (changed, method) => removeUnreachableCode(method, clazz.name) || changed + } + } + + /** + * Remove unreachable code from a method. + * We rely on dead code elimination provided by the ASM framework, as described in the ASM User + * Guide (http://asm.ow2.org/index.html), Section 8.2.1. It runs a data flow analysis, which only + * computes Frame information for reachable instructions. Instructions for which no Frame data is + * available after the analyis are unreachable. + * + * TODO doc: it also removes empty handlers, unused local vars + * + * Returns `true` if dead code in `method` has been eliminated. + */ + private def removeUnreachableCode(method: MethodNode, ownerClassName: String): Boolean = { + if (method.instructions.size == 0) return false // fast path for abstract methods + + val codeRemoved = removeUnreachableCodeImpl(method, ownerClassName) + + // unreachable-code also removes unused local variable nodes and empty exception handlers. + // This is required for correctness: such nodes are not allowed to refer to instruction offsets + // that don't exist (because they have been eliminated). + val localsRemoved = removeUnusedLocalVariableNodes(method) + val handlersRemoved = removeEmptyExceptionHandlers(method) + + // When eliminating a handler, the catch block becomes unreachable. The recursive invocation + // removes these blocks. + // Note that invoking removeUnreachableCode*Impl* a second time is not enough: removing the dead + // catch block can render other handlers empty, which also have to be removed in turn. + if (handlersRemoved) removeUnreachableCode(method, ownerClassName) + + // assert that we can leave local variable annotations as-is + def nullOrEmpty[T](l: java.util.List[T]) = l == null || l.isEmpty + assert(nullOrEmpty(method.visibleLocalVariableAnnotations), method.visibleLocalVariableAnnotations) + assert(nullOrEmpty(method.invisibleLocalVariableAnnotations), method.invisibleLocalVariableAnnotations) + + codeRemoved || localsRemoved || handlersRemoved + } + + private def removeUnreachableCodeImpl(method: MethodNode, ownerClassName: String): Boolean = { + val initialSize = method.instructions.size + if (initialSize == 0) return false + + // The data flow analysis requires the maxLocals / maxStack fields of the method to be computed. + computeMaxLocalsMaxStack(method) + val a = new Analyzer[BasicValue](new BasicInterpreter) + a.analyze(ownerClassName, method) + val frames = a.getFrames + + var i = 0 + val itr = method.instructions.iterator() + while (itr.hasNext) { + val ins = itr.next() + // Don't remove label nodes: they might be referenced for example in a LocalVariableNode + if (frames(i) == null && !ins.isInstanceOf[LabelNode]) { + // Instruction iterators allow removing during iteration. + // Removing is O(1): instructions are doubly linked list elements. + itr.remove() + } + i += 1 + } + + method.instructions.size != initialSize + } + + /** + * Remove exception handlers that cover empty code blocks from all methods of `clazz`. + * Returns `true` if any exception handler was eliminated. + */ + def removeEmptyExceptionHandlers(clazz: ClassNode): Boolean = { + clazz.methods.asScala.foldLeft(false) { + case (changed, method) => removeEmptyExceptionHandlers(method) || changed + } + } + + /** + * Remove exception handlers that cover empty code blocks. A block is considered empty if it + * consist only of labels, frames, line numbers, nops and gotos. + * + * Note that no instructions are eliminated. + * + * @return `true` if some exception handler was eliminated. + */ + def removeEmptyExceptionHandlers(method: MethodNode): Boolean = { + /** True if there exists code between start and end. */ + def containsExecutableCode(start: AbstractInsnNode, end: LabelNode): Boolean = { + start != end && (start.getOpcode match { + // FrameNode, LabelNode and LineNumberNode have opcode == -1. + case -1 | Opcodes.NOP | Opcodes.GOTO => containsExecutableCode(start.getNext, end) + case _ => true + }) + } + + val initialNumberHandlers = method.tryCatchBlocks.size + val handlersIter = method.tryCatchBlocks.iterator() + while(handlersIter.hasNext) { + val handler = handlersIter.next() + if (!containsExecutableCode(handler.start, handler.end)) handlersIter.remove() + } + method.tryCatchBlocks.size != initialNumberHandlers + } + + /** + * Remove all non-parameter entries from the local variable table which denote variables that are + * not actually read or written. + * + * Note that each entry in the local variable table has a start, end and index. Two entries with + * the same index, but distinct start / end ranges are different variables, they may have not the + * same type or name. + * + * TODO: also re-allocate locals to occupy fewer slots after eliminating unused ones + */ + def removeUnusedLocalVariableNodes(method: MethodNode): Boolean = { + def variableIsUsed(start: AbstractInsnNode, end: LabelNode, varIndex: Int): Boolean = { + start != end && (start match { + case v: VarInsnNode => v.`var` == varIndex + case _ => variableIsUsed(start.getNext, end, varIndex) + }) + } + + val initialNumVars = method.localVariables.size + val localsIter = method.localVariables.iterator() + // The parameters and `this` (for instance methods) have the lowest indices in the local variables + // table. Note that double / long fields occupy two slots, so we sum up the sizes. Since getSize + // returns 0 for void, we have to add `max 1`. + val paramsSize = scala.tools.asm.Type.getArgumentTypes(method.desc).map(_.getSize max 1).sum + val thisSize = if ((method.access & Opcodes.ACC_STATIC) == 0) 1 else 0 + val endParamIndex = paramsSize + thisSize + while (localsIter.hasNext) { + val local = localsIter.next() + // parameters and `this` have the lowest indices, starting at 0 + val used = local.index < endParamIndex || variableIsUsed(local.start, local.end, local.index) + if (!used) + localsIter.remove() + } + method.localVariables.size == initialNumVars + } + + + /** + * In order to run an Analyzer, the maxLocals / maxStack fields need to be available. The ASM + * framework only computes these values during bytecode generation. + * + * Sicne there's currently no better way, we run a bytecode generator on the method and extract + * the computed values. This required changes to the ASM codebase: + * - the [[MethodWriter]] class was made public + * - accessors for maxLocals / maxStack were added to the MethodWriter class + * + * We could probably make this faster (and allocate less memory) by hacking the ASM framework + * more: create a subclass of MethodWriter with a /dev/null byteVector. Another option would be + * to create a separate visitor for computing those values, duplicating the functionality from the + * MethodWriter. + */ + private def computeMaxLocalsMaxStack(method: MethodNode) { + val cw = new ClassWriter(ClassWriter.COMPUTE_MAXS) + val excs = method.exceptions.asScala.toArray + val mw = cw.visitMethod(method.access, method.name, method.desc, method.signature, excs).asInstanceOf[MethodWriter] + method.accept(mw) + method.maxLocals = mw.getMaxLocals + method.maxStack = mw.getMaxStack + } +} diff --git a/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala b/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala index 91b03869e5..466e397dd7 100644 --- a/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala +++ b/src/compiler/scala/tools/nsc/settings/ScalaSettings.scala @@ -209,9 +209,35 @@ trait ScalaSettings extends AbsScalaSettings // the current standard is "inline" but we are moving towards "method" val Ydelambdafy = ChoiceSetting ("-Ydelambdafy", "strategy", "Strategy used for translating lambdas into JVM code.", List("inline", "method"), "inline") + object YoptChoices extends MultiChoiceEnumeration { + val unreachableCode = Choice("unreachable-code", "Eliminate unreachable code") + + val lNone = Choice("l:none", "Don't enable any optimizations") + + private val defaultChoices = List(unreachableCode) + val lDefault = Choice("l:default", "Enable default optimizations: "+ defaultChoices.mkString(","), expandsTo = defaultChoices) + + private val methodChoices = List(lDefault) + val lMethod = Choice("l:method", "Intra-method optimizations: "+ methodChoices.mkString(","), expandsTo = methodChoices) + + private val projectChoices = List(lMethod) + val lProject = Choice("l:project", "Cross-method optimizations within the current project: "+ projectChoices.mkString(","), expandsTo = projectChoices) + + private val classpathChoices = List(lProject) + val lClasspath = Choice("l:classpath", "Cross-method optmizations across the entire classpath: "+ classpathChoices.mkString(","), expandsTo = classpathChoices) + } + + val Yopt = MultiChoiceSetting( + name = "-Yopt", + helpArg = "optimization", + descr = "Enable optimizations", + domain = YoptChoices) + + def YoptUnreachableCode: Boolean = !Yopt.isSetByUser || Yopt.contains(YoptChoices.unreachableCode) + private def removalIn212 = "This flag is scheduled for removal in 2.12. If you have a case where you need this flag then please report a bug." - object YstatisticsPhases extends MultiChoiceEnumeration { val parser, typer, patmat, erasure, cleanup = Value } + object YstatisticsPhases extends MultiChoiceEnumeration { val parser, typer, patmat, erasure, cleanup, jvm = Value } val Ystatistics = { val description = "Print compiler statistics for specific phases" MultiChoiceSetting( diff --git a/src/partest-extras/scala/tools/partest/ASMConverters.scala b/src/partest-extras/scala/tools/partest/ASMConverters.scala index d618e086f4..67a4e8ae01 100644 --- a/src/partest-extras/scala/tools/partest/ASMConverters.scala +++ b/src/partest-extras/scala/tools/partest/ASMConverters.scala @@ -2,70 +2,216 @@ package scala.tools.partest import scala.collection.JavaConverters._ import scala.tools.asm -import asm.tree.{ClassNode, MethodNode, InsnList} +import asm.{tree => t} /** Makes using ASM from ByteCodeTests more convenient. * * Wraps ASM instructions in case classes so that equals and toString work * for the purpose of bytecode diffing and pretty printing. */ -trait ASMConverters { - // wrap ASM's instructions so we get case class-style `equals` and `toString` - object instructions { - def fromMethod(meth: MethodNode): List[Instruction] = { - val insns = meth.instructions - val asmToScala = new AsmToScala{ def labelIndex(l: asm.tree.AbstractInsnNode) = insns.indexOf(l) } - - asmToScala.mapOver(insns.iterator.asScala.toList).asInstanceOf[List[Instruction]] +object ASMConverters { + + /** + * Transform the instructions of an ASM Method into a list of [[Instruction]]s. + */ + def instructionsFromMethod(meth: t.MethodNode): List[Instruction] = new AsmToScala(meth).instructions + + def convertMethod(meth: t.MethodNode): Method = new AsmToScala(meth).method + + implicit class RichInstructionLists(val self: List[Instruction]) extends AnyVal { + def === (other: List[Instruction]) = equivalentBytecode(self, other) + + def dropLinesFrames = self.filterNot(i => i.isInstanceOf[LineNumber] || i.isInstanceOf[FrameEntry]) + + private def referencedLabels(instruction: Instruction): Set[Instruction] = instruction match { + case Jump(op, label) => Set(label) + case LookupSwitch(op, dflt, keys, labels) => (dflt :: labels).toSet + case TableSwitch(op, min, max, dflt, labels) => (dflt :: labels).toSet + case LineNumber(line, start) => Set(start) + case _ => Set.empty } - sealed abstract class Instruction { def opcode: String } - case class Field (opcode: String, desc: String, name: String, owner: String) extends Instruction - case class Incr (opcode: String, incr: Int, `var`: Int) extends Instruction - case class Op (opcode: String) extends Instruction - case class IntOp (opcode: String, operand: Int) extends Instruction - case class Jump (opcode: String, label: Label) extends Instruction - case class Ldc (opcode: String, cst: Any) extends Instruction - case class LookupSwitch (opcode: String, dflt: Label, keys: List[Integer], labels: List[Label]) extends Instruction - case class TableSwitch (opcode: String, dflt: Label, max: Int, min: Int, labels: List[Label]) extends Instruction - case class Method (opcode: String, desc: String, name: String, owner: String) extends Instruction - case class NewArray (opcode: String, desc: String, dims: Int) extends Instruction - case class TypeOp (opcode: String, desc: String) extends Instruction - case class VarOp (opcode: String, `var`: Int) extends Instruction - case class Label (offset: Int) extends Instruction { def opcode: String = "" } - case class FrameEntry (local: List[Any], stack: List[Any]) extends Instruction { def opcode: String = "" } - case class LineNumber (line: Int, start: Label) extends Instruction { def opcode: String = "" } + def dropStaleLabels = { + val definedLabels: Set[Instruction] = self.filter(_.isInstanceOf[Label]).toSet + val usedLabels: Set[Instruction] = self.flatMap(referencedLabels)(collection.breakOut) + self.filterNot(definedLabels diff usedLabels) + } + + def dropNonOp = dropLinesFrames.dropStaleLabels + } + + sealed abstract class Instruction extends Product { + def opcode: Int + + // toString such that the first field, "opcode: Int", is printed textually. + final override def toString() = { + import scala.tools.asm.util.Printer.OPCODES + def opString(op: Int) = if (OPCODES.isDefinedAt(op)) OPCODES(op) else "?" + val printOpcode = opcode != -1 + + productPrefix + ( + if (printOpcode) Iterator(opString(opcode)) ++ productIterator.drop(1) + else productIterator + ).mkString("(", ", ", ")") + } } - abstract class AsmToScala { - import instructions._ + case class Method(instructions: List[Instruction], handlers: List[ExceptionHandler], localVars: List[LocalVariable]) + + case class Field (opcode: Int, owner: String, name: String, desc: String) extends Instruction + case class Incr (opcode: Int, `var`: Int, incr: Int) extends Instruction + case class Op (opcode: Int) extends Instruction + case class IntOp (opcode: Int, operand: Int) extends Instruction + case class Jump (opcode: Int, label: Label) extends Instruction + case class Ldc (opcode: Int, cst: Any) extends Instruction + case class LookupSwitch(opcode: Int, dflt: Label, keys: List[Int], labels: List[Label]) extends Instruction + case class TableSwitch (opcode: Int, min: Int, max: Int, dflt: Label, labels: List[Label]) extends Instruction + case class Invoke (opcode: Int, owner: String, name: String, desc: String, itf: Boolean) extends Instruction + case class NewArray (opcode: Int, desc: String, dims: Int) extends Instruction + case class TypeOp (opcode: Int, desc: String) extends Instruction + case class VarOp (opcode: Int, `var`: Int) extends Instruction + case class Label (offset: Int) extends Instruction { def opcode: Int = -1 } + case class FrameEntry (`type`: Int, local: List[Any], stack: List[Any]) extends Instruction { def opcode: Int = -1 } + case class LineNumber (line: Int, start: Label) extends Instruction { def opcode: Int = -1 } + + case class ExceptionHandler(start: Label, end: Label, handler: Label, desc: Option[String]) + case class LocalVariable(name: String, desc: String, signature: Option[String], start: Label, end: Label, index: Int) + + class AsmToScala(asmMethod: t.MethodNode) { + + def instructions: List[Instruction] = asmMethod.instructions.iterator.asScala.toList map apply + + def method: Method = Method(instructions, convertHandlers(asmMethod), convertLocalVars(asmMethod)) - def labelIndex(l: asm.tree.AbstractInsnNode): Int + private def labelIndex(l: t.LabelNode): Int = asmMethod.instructions.indexOf(l) + + private def op(i: t.AbstractInsnNode): Int = i.getOpcode - def mapOver(is: List[Any]): List[Any] = is map { - case i: asm.tree.AbstractInsnNode => apply(i) + private def lst[T](xs: java.util.List[T]): List[T] = if (xs == null) Nil else xs.asScala.toList + + // Heterogenous List[Any] is used in FrameNode: type information about locals / stack values + // are stored in a List[Any] (Integer, String or LabelNode), see Javadoc of MethodNode#visitFrame. + // Opcodes (eg Opcodes.INTEGER) and Reference types (eg "java/lang/Object") are returned unchanged, + // LabelNodes are mapped to their LabelEntry. + private def mapOverFrameTypes(is: List[Any]): List[Any] = is map { + case i: t.LabelNode => applyLabel(i) case x => x } - def op(i: asm.tree.AbstractInsnNode) = if (asm.util.Printer.OPCODES.isDefinedAt(i.getOpcode)) asm.util.Printer.OPCODES(i.getOpcode) else "?" - def lst[T](xs: java.util.List[T]): List[T] = if (xs == null) Nil else xs.asScala.toList - def apply(l: asm.tree.LabelNode): Label = this(l: asm.tree.AbstractInsnNode).asInstanceOf[Label] - def apply(x: asm.tree.AbstractInsnNode): Instruction = x match { - case i: asm.tree.FieldInsnNode => Field (op(i), i.desc: String, i.name: String, i.owner: String) - case i: asm.tree.IincInsnNode => Incr (op(i), i.incr: Int, i.`var`: Int) - case i: asm.tree.InsnNode => Op (op(i)) - case i: asm.tree.IntInsnNode => IntOp (op(i), i.operand: Int) - case i: asm.tree.JumpInsnNode => Jump (op(i), this(i.label)) - case i: asm.tree.LdcInsnNode => Ldc (op(i), i.cst: Any) - case i: asm.tree.LookupSwitchInsnNode => LookupSwitch (op(i), this(i.dflt), lst(i.keys), mapOver(lst(i.labels)).asInstanceOf[List[Label]]) - case i: asm.tree.TableSwitchInsnNode => TableSwitch (op(i), this(i.dflt), i.max: Int, i.min: Int, mapOver(lst(i.labels)).asInstanceOf[List[Label]]) - case i: asm.tree.MethodInsnNode => Method (op(i), i.desc: String, i.name: String, i.owner: String) - case i: asm.tree.MultiANewArrayInsnNode => NewArray (op(i), i.desc: String, i.dims: Int) - case i: asm.tree.TypeInsnNode => TypeOp (op(i), i.desc: String) - case i: asm.tree.VarInsnNode => VarOp (op(i), i.`var`: Int) - case i: asm.tree.LabelNode => Label (labelIndex(x)) - case i: asm.tree.FrameNode => FrameEntry (mapOver(lst(i.local)), mapOver(lst(i.stack))) - case i: asm.tree.LineNumberNode => LineNumber (i.line: Int, this(i.start): Label) + // avoids some casts + private def applyLabel(l: t.LabelNode) = this(l: t.AbstractInsnNode).asInstanceOf[Label] + + private def apply(x: t.AbstractInsnNode): Instruction = x match { + case i: t.FieldInsnNode => Field (op(i), i.owner, i.name, i.desc) + case i: t.IincInsnNode => Incr (op(i), i.`var`, i.incr) + case i: t.InsnNode => Op (op(i)) + case i: t.IntInsnNode => IntOp (op(i), i.operand) + case i: t.JumpInsnNode => Jump (op(i), applyLabel(i.label)) + case i: t.LdcInsnNode => Ldc (op(i), i.cst: Any) + case i: t.LookupSwitchInsnNode => LookupSwitch (op(i), applyLabel(i.dflt), lst(i.keys) map (x => x: Int), lst(i.labels) map applyLabel) + case i: t.TableSwitchInsnNode => TableSwitch (op(i), i.min, i.max, applyLabel(i.dflt), lst(i.labels) map applyLabel) + case i: t.MethodInsnNode => Invoke (op(i), i.owner, i.name, i.desc, i.itf) + case i: t.MultiANewArrayInsnNode => NewArray (op(i), i.desc, i.dims) + case i: t.TypeInsnNode => TypeOp (op(i), i.desc) + case i: t.VarInsnNode => VarOp (op(i), i.`var`) + case i: t.LabelNode => Label (labelIndex(i)) + case i: t.FrameNode => FrameEntry (i.`type`, mapOverFrameTypes(lst(i.local)), mapOverFrameTypes(lst(i.stack))) + case i: t.LineNumberNode => LineNumber (i.line, applyLabel(i.start)) + } + + private def convertHandlers(method: t.MethodNode): List[ExceptionHandler] = { + method.tryCatchBlocks.asScala.map(h => ExceptionHandler(applyLabel(h.start), applyLabel(h.end), applyLabel(h.handler), Option(h.`type`)))(collection.breakOut) + } + + private def convertLocalVars(method: t.MethodNode): List[LocalVariable] = { + method.localVariables.asScala.map(v => LocalVariable(v.name, v.desc, Option(v.signature), applyLabel(v.start), applyLabel(v.end), v.index))(collection.breakOut) + } + } + + import collection.mutable.{Map => MMap} + + /** + * Bytecode is equal modula local variable numbering and label numbering. + */ + def equivalentBytecode(as: List[Instruction], bs: List[Instruction], varMap: MMap[Int, Int] = MMap(), labelMap: MMap[Int, Int] = MMap()): Boolean = { + def same(v1: Int, v2: Int, m: MMap[Int, Int]) = { + if (m contains v1) m(v1) == v2 + else if (m.valuesIterator contains v2) false // v2 is already associated with some different value v1 + else { m(v1) = v2; true } + } + def sameVar(v1: Int, v2: Int) = same(v1, v2, varMap) + def sameLabel(l1: Label, l2: Label) = same(l1.offset, l2.offset, labelMap) + def sameLabels(ls1: List[Label], ls2: List[Label]) = (ls1 corresponds ls2)(sameLabel) + + def sameFrameTypes(ts1: List[Any], ts2: List[Any]) = (ts1 corresponds ts2) { + case (t1: Label, t2: Label) => sameLabel(t1, t2) + case (x, y) => x == y + } + + if (as.isEmpty) bs.isEmpty + else if (bs.isEmpty) false + else ((as.head, bs.head) match { + case (VarOp(op1, v1), VarOp(op2, v2)) => op1 == op2 && sameVar(v1, v2) + case (Incr(op1, v1, inc1), Incr(op2, v2, inc2)) => op1 == op2 && sameVar(v1, v2) && inc1 == inc2 + + case (l1 @ Label(_), l2 @ Label(_)) => sameLabel(l1, l2) + case (Jump(op1, l1), Jump(op2, l2)) => op1 == op2 && sameLabel(l1, l2) + case (LookupSwitch(op1, l1, keys1, ls1), LookupSwitch(op2, l2, keys2, ls2)) => op1 == op2 && sameLabel(l1, l2) && keys1 == keys2 && sameLabels(ls1, ls2) + case (TableSwitch(op1, min1, max1, l1, ls1), TableSwitch(op2, min2, max2, l2, ls2)) => op1 == op2 && min1 == min2 && max1 == max2 && sameLabel(l1, l2) && sameLabels(ls1, ls2) + case (LineNumber(line1, l1), LineNumber(line2, l2)) => line1 == line2 && sameLabel(l1, l2) + case (FrameEntry(tp1, loc1, stk1), FrameEntry(tp2, loc2, stk2)) => tp1 == tp2 && sameFrameTypes(loc1, loc2) && sameFrameTypes(stk1, stk2) + + // this needs to go after the above. For example, Label(1) may not equal Label(1), if before + // the left 1 was associated with another right index. + case (a, b) if a == b => true + + case _ => false + }) && equivalentBytecode(as.tail, bs.tail, varMap, labelMap) + } + + def applyToMethod(method: t.MethodNode, instructions: List[Instruction]): Unit = { + val asmLabel = createLabelNodes(instructions) + instructions.foreach(visitMethod(method, _, asmLabel)) + } + + /** + * Convert back a [[Method]] to ASM land. The code is emitted into the parameter `asmMethod`. + */ + def applyToMethod(asmMethod: t.MethodNode, method: Method): Unit = { + val asmLabel = createLabelNodes(method.instructions) + method.instructions.foreach(visitMethod(asmMethod, _, asmLabel)) + method.handlers.foreach(h => asmMethod.visitTryCatchBlock(asmLabel(h.start), asmLabel(h.end), asmLabel(h.handler), h.desc.orNull)) + method.localVars.foreach(v => asmMethod.visitLocalVariable(v.name, v.desc, v.signature.orNull, asmLabel(v.start), asmLabel(v.end), v.index)) + } + + private def createLabelNodes(instructions: List[Instruction]): Map[Label, asm.Label] = { + val labels = instructions collect { + case l: Label => l } + assert(labels.distinct == labels, s"Duplicate labels in: $labels") + labels.map(l => (l, new asm.Label())).toMap + } + + private def frameTypesToAsm(l: List[Any], asmLabel: Map[Label, asm.Label]): List[Object] = l map { + case l: Label => asmLabel(l) + case x => x.asInstanceOf[Object] + } + + private def visitMethod(method: t.MethodNode, instruction: Instruction, asmLabel: Map[Label, asm.Label]): Unit = instruction match { + case Field(op, owner, name, desc) => method.visitFieldInsn(op, owner, name, desc) + case Incr(op, vr, incr) => method.visitIincInsn(vr, incr) + case Op(op) => method.visitInsn(op) + case IntOp(op, operand) => method.visitIntInsn(op, operand) + case Jump(op, label) => method.visitJumpInsn(op, asmLabel(label)) + case Ldc(op, cst) => method.visitLdcInsn(cst) + case LookupSwitch(op, dflt, keys, labels) => method.visitLookupSwitchInsn(asmLabel(dflt), keys.toArray, (labels map asmLabel).toArray) + case TableSwitch(op, min, max, dflt, labels) => method.visitTableSwitchInsn(min, max, asmLabel(dflt), (labels map asmLabel).toArray: _*) + case Invoke(op, owner, name, desc, itf) => method.visitMethodInsn(op, owner, name, desc, itf) + case NewArray(op, desc, dims) => method.visitMultiANewArrayInsn(desc, dims) + case TypeOp(op, desc) => method.visitTypeInsn(op, desc) + case VarOp(op, vr) => method.visitVarInsn(op, vr) + case l: Label => method.visitLabel(asmLabel(l)) + case FrameEntry(tp, local, stack) => method.visitFrame(tp, local.length, frameTypesToAsm(local, asmLabel).toArray, stack.length, frameTypesToAsm(stack, asmLabel).toArray) + case LineNumber(line, start) => method.visitLineNumber(line, asmLabel(start)) } -}
\ No newline at end of file +} diff --git a/src/partest-extras/scala/tools/partest/BytecodeTest.scala b/src/partest-extras/scala/tools/partest/BytecodeTest.scala index 1e4362fcde..3261cada37 100644 --- a/src/partest-extras/scala/tools/partest/BytecodeTest.scala +++ b/src/partest-extras/scala/tools/partest/BytecodeTest.scala @@ -3,7 +3,7 @@ package scala.tools.partest import scala.tools.nsc.util.JavaClassPath import scala.collection.JavaConverters._ import scala.tools.asm.{ClassWriter, ClassReader} -import scala.tools.asm.tree.{ClassNode, MethodNode, InsnList} +import scala.tools.asm.tree._ import java.io.{FileOutputStream, FileInputStream, File => JFile, InputStream} import AsmNode._ @@ -28,8 +28,8 @@ import AsmNode._ * See test/files/jvm/bytecode-test-example for an example of bytecode test. * */ -abstract class BytecodeTest extends ASMConverters { - import instructions._ +abstract class BytecodeTest { + import ASMConverters._ /** produce the output to be compared against a checkfile */ protected def show(): Unit @@ -38,8 +38,8 @@ abstract class BytecodeTest extends ASMConverters { // asserts def sameBytecode(methA: MethodNode, methB: MethodNode) = { - val isa = instructions.fromMethod(methA) - val isb = instructions.fromMethod(methB) + val isa = instructionsFromMethod(methA) + val isb = instructionsFromMethod(methB) if (isa == isb) println("bytecode identical") else diffInstructions(isa, isb) } @@ -81,18 +81,16 @@ abstract class BytecodeTest extends ASMConverters { } } - // bytecode is equal modulo local variable numbering - def equalsModuloVar(a: Instruction, b: Instruction) = (a, b) match { - case _ if a == b => true - case (VarOp(op1, _), VarOp(op2, _)) if op1 == op2 => true - case _ => false - } - - def similarBytecode(methA: MethodNode, methB: MethodNode, similar: (Instruction, Instruction) => Boolean) = { - val isa = fromMethod(methA) - val isb = fromMethod(methB) + /** + * Compare the bytecodes of two methods. + * + * For the `similar` function, you probably want to pass [[ASMConverters.equivalentBytecode]]. + */ + def similarBytecode(methA: MethodNode, methB: MethodNode, similar: (List[Instruction], List[Instruction]) => Boolean) = { + val isa = instructionsFromMethod(methA) + val isb = instructionsFromMethod(methB) if (isa == isb) println("bytecode identical") - else if ((isa, isb).zipped.forall { case (a, b) => similar(a, b) }) println("bytecode similar") + else if (similar(isa, isb)) println("bytecode similar") else diffInstructions(isa, isb) } diff --git a/test/files/jvm/t6941/test.scala b/test/files/jvm/t6941/test.scala index 248617f71f..fceb54487f 100644 --- a/test/files/jvm/t6941/test.scala +++ b/test/files/jvm/t6941/test.scala @@ -1,4 +1,4 @@ -import scala.tools.partest.BytecodeTest +import scala.tools.partest.{BytecodeTest, ASMConverters} import scala.tools.nsc.util.JavaClassPath import java.io.InputStream @@ -10,6 +10,6 @@ import scala.collection.JavaConverters._ object Test extends BytecodeTest { def show: Unit = { val classNode = loadClassNode("SameBytecode") - similarBytecode(getMethod(classNode, "a"), getMethod(classNode, "b"), equalsModuloVar) + similarBytecode(getMethod(classNode, "a"), getMethod(classNode, "b"), ASMConverters.equivalentBytecode(_, _)) } } diff --git a/test/files/jvm/t7253/test.scala b/test/files/jvm/t7253/test.scala index 7fe08e8813..a3f1e86e65 100644 --- a/test/files/jvm/t7253/test.scala +++ b/test/files/jvm/t7253/test.scala @@ -1,4 +1,4 @@ -import scala.tools.partest.BytecodeTest +import scala.tools.partest.{BytecodeTest, ASMConverters} import scala.tools.nsc.util.JavaClassPath import java.io.InputStream @@ -8,10 +8,10 @@ import asm.tree.{ClassNode, InsnList} import scala.collection.JavaConverters._ object Test extends BytecodeTest { - import instructions._ + import ASMConverters._ def show: Unit = { - val instrBaseSeqs = Seq("ScalaClient_1", "JavaClient_1") map (name => instructions.fromMethod(getMethod(loadClassNode(name), "foo"))) + val instrBaseSeqs = Seq("ScalaClient_1", "JavaClient_1") map (name => instructionsFromMethod(getMethod(loadClassNode(name), "foo"))) val instrSeqs = instrBaseSeqs map (_ filter isInvoke) cmpInstructions(instrSeqs(0), instrSeqs(1)) } diff --git a/test/files/jvm/unreachable.flags b/test/files/jvm/unreachable.flags deleted file mode 100644 index 49f2d2c4c8..0000000000 --- a/test/files/jvm/unreachable.flags +++ /dev/null @@ -1 +0,0 @@ --Ybackend:GenASM diff --git a/test/files/run/nothingTypeDce.flags b/test/files/run/nothingTypeDce.flags new file mode 100644 index 0000000000..d85321ca0e --- /dev/null +++ b/test/files/run/nothingTypeDce.flags @@ -0,0 +1 @@ +-target:jvm-1.6 -Ybackend:GenBCode -Yopt:unreachable-code diff --git a/test/files/run/nothingTypeDce.scala b/test/files/run/nothingTypeDce.scala new file mode 100644 index 0000000000..5f3692fd33 --- /dev/null +++ b/test/files/run/nothingTypeDce.scala @@ -0,0 +1,63 @@ +// See comment in BCodeBodyBuilder + +// -target:jvm-1.6 -Ybackend:GenBCode -Yopt:unreachable-code +// target enables stack map frames generation + +class C { + // can't just emit a call to ???, that returns value of type Nothing$ (not Int). + def f1: Int = ??? + + def f2: Int = throw new Error("") + + def f3(x: Boolean) = { + var y = 0 + // cannot assign an object of type Nothing$ to Int + if (x) y = ??? + else y = 1 + y + } + + def f4(x: Boolean) = { + var y = 0 + // tests that whatever is emitted after the throw is valid (what? depends on opts, presence of stack map frames) + if (x) y = throw new Error("") + else y = 1 + y + } + + def f5(x: Boolean) = { + // stack heights need to be the same. ??? looks to the jvm like returning a value of + // type Nothing$, need to drop or throw it. + println( + if (x) { ???; 10 } + else 20 + ) + } + + def f6(x: Boolean) = { + println( + if (x) { throw new Error(""); 10 } + else 20 + ) + } + + def f7(x: Boolean) = { + println( + if (x) throw new Error("") + else 20 + ) + } + + def f8(x: Boolean) = { + println( + if (x) throw new Error("") + else 20 + ) + } +} + +object Test extends App { + // creating an instance is enough to trigger bytecode verification for all methods, + // no need to invoke the methods. + new C() +} diff --git a/test/files/run/nothingTypeNoFramesNoDce.check b/test/files/run/nothingTypeNoFramesNoDce.check new file mode 100644 index 0000000000..b1d08b45ff --- /dev/null +++ b/test/files/run/nothingTypeNoFramesNoDce.check @@ -0,0 +1 @@ +warning: -target:jvm-1.5 is deprecated: use target for Java 1.6 or above. diff --git a/test/files/run/nothingTypeNoFramesNoDce.flags b/test/files/run/nothingTypeNoFramesNoDce.flags new file mode 100644 index 0000000000..a035c86179 --- /dev/null +++ b/test/files/run/nothingTypeNoFramesNoDce.flags @@ -0,0 +1 @@ +-target:jvm-1.5 -Ybackend:GenBCode -Yopt:l:none -deprecation diff --git a/test/files/run/nothingTypeNoFramesNoDce.scala b/test/files/run/nothingTypeNoFramesNoDce.scala new file mode 100644 index 0000000000..3d1298303a --- /dev/null +++ b/test/files/run/nothingTypeNoFramesNoDce.scala @@ -0,0 +1,61 @@ +// See comment in BCodeBodyBuilder + +// -target:jvm-1.5 -Ybackend:GenBCode -Yopt:l:none +// target disables stack map frame generation. in this mode, the ClssWriter just emits dead code as is. + +class C { + // can't just emit a call to ???, that returns value of type Nothing$ (not Int). + def f1: Int = ??? + + def f2: Int = throw new Error("") + + def f3(x: Boolean) = { + var y = 0 + // cannot assign an object of type Nothing$ to Int + if (x) y = ??? + else y = 1 + y + } + + def f4(x: Boolean) = { + var y = 0 + // tests that whatever is emitted after the throw is valid (what? depends on opts, presence of stack map frames) + if (x) y = throw new Error("") + else y = 1 + y + } + + def f5(x: Boolean) = { + // stack heights need to be the smae. ??? looks to the jvm like returning a value of + // type Nothing$, need to drop or throw it. + println( + if (x) { ???; 10 } + else 20 + ) + } + + def f6(x: Boolean) = { + println( + if (x) { throw new Error(""); 10 } + else 20 + ) + } + + def f7(x: Boolean) = { + println( + if (x) throw new Error("") + else 20 + ) + } + + def f8(x: Boolean) = { + println( + if (x) throw new Error("") + else 20 + ) + } +} + +object Test extends App { + new C() +} diff --git a/test/files/run/nothingTypeNoOpt.flags b/test/files/run/nothingTypeNoOpt.flags new file mode 100644 index 0000000000..b3b518051b --- /dev/null +++ b/test/files/run/nothingTypeNoOpt.flags @@ -0,0 +1 @@ +-target:jvm-1.6 -Ybackend:GenBCode -Yopt:l:none diff --git a/test/files/run/nothingTypeNoOpt.scala b/test/files/run/nothingTypeNoOpt.scala new file mode 100644 index 0000000000..5c5a20fa3b --- /dev/null +++ b/test/files/run/nothingTypeNoOpt.scala @@ -0,0 +1,61 @@ +// See comment in BCodeBodyBuilder + +// -target:jvm-1.6 -Ybackend:GenBCode -Yopt:l:none +// target enables stack map frame generation + +class C { + // can't just emit a call to ???, that returns value of type Nothing$ (not Int). + def f1: Int = ??? + + def f2: Int = throw new Error("") + + def f3(x: Boolean) = { + var y = 0 + // cannot assign an object of type Nothing$ to Int + if (x) y = ??? + else y = 1 + y + } + + def f4(x: Boolean) = { + var y = 0 + // tests that whatever is emitted after the throw is valid (what? depends on opts, presence of stack map frames) + if (x) y = throw new Error("") + else y = 1 + y + } + + def f5(x: Boolean) = { + // stack heights need to be the smae. ??? looks to the jvm like returning a value of + // type Nothing$, need to drop or throw it. + println( + if (x) { ???; 10 } + else 20 + ) + } + + def f6(x: Boolean) = { + println( + if (x) { throw new Error(""); 10 } + else 20 + ) + } + + def f7(x: Boolean) = { + println( + if (x) throw new Error("") + else 20 + ) + } + + def f8(x: Boolean) = { + println( + if (x) throw new Error("") + else 20 + ) + } +} + +object Test extends App { + new C() +} diff --git a/test/junit/scala/tools/nsc/backend/jvm/BTypesTest.scala b/test/junit/scala/tools/nsc/backend/jvm/BTypesTest.scala index cb7e7050b0..221aad6536 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/BTypesTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/BTypesTest.scala @@ -1,7 +1,6 @@ package scala.tools.nsc package backend.jvm -import scala.tools.testing.AssertUtil._ import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test diff --git a/test/junit/scala/tools/nsc/backend/jvm/CodeGenTools.scala b/test/junit/scala/tools/nsc/backend/jvm/CodeGenTools.scala new file mode 100644 index 0000000000..15bc1f427d --- /dev/null +++ b/test/junit/scala/tools/nsc/backend/jvm/CodeGenTools.scala @@ -0,0 +1,79 @@ +package scala.tools.nsc.backend.jvm + +import org.junit.Assert._ + +import scala.reflect.internal.util.BatchSourceFile +import scala.reflect.io.VirtualDirectory +import scala.tools.asm.Opcodes +import scala.tools.asm.tree.{AbstractInsnNode, LabelNode, ClassNode, MethodNode} +import scala.tools.cmd.CommandLineParser +import scala.tools.nsc.{Settings, Global} +import scala.tools.partest.ASMConverters +import scala.collection.JavaConverters._ + +object CodeGenTools { + import ASMConverters._ + + def genMethod( flags: Int = Opcodes.ACC_PUBLIC, + name: String = "m", + descriptor: String = "()V", + genericSignature: String = null, + throwsExceptions: Array[String] = null, + handlers: List[ExceptionHandler] = Nil, + localVars: List[LocalVariable] = Nil)(body: Instruction*): MethodNode = { + val node = new MethodNode(flags, name, descriptor, genericSignature, throwsExceptions) + applyToMethod(node, Method(body.toList, handlers, localVars)) + node + } + + def wrapInClass(method: MethodNode): ClassNode = { + val cls = new ClassNode() + cls.visit(Opcodes.V1_6, Opcodes.ACC_PUBLIC, "C", null, "java/lang/Object", null) + cls.methods.add(method) + cls + } + + private def resetOutput(compiler: Global): Unit = { + compiler.settings.outputDirs.setSingleOutput(new VirtualDirectory("(memory)", None)) + } + + def newCompiler(defaultArgs: String = "-usejavacp", extraArgs: String = ""): Global = { + val settings = new Settings() + val args = (CommandLineParser tokenize defaultArgs) ++ (CommandLineParser tokenize extraArgs) + settings.processArguments(args, processAll = true) + val compiler = new Global(settings) + resetOutput(compiler) + compiler + } + + def compile(compiler: Global)(code: String): List[(String, Array[Byte])] = { + compiler.reporter.reset() + resetOutput(compiler) + val run = new compiler.Run() + run.compileSources(List(new BatchSourceFile("unitTestSource.scala", code))) + val outDir = compiler.settings.outputDirs.getSingleOutput.get + (for (f <- outDir.iterator if !f.isDirectory) yield (f.name, f.toByteArray)).toList + } + + def compileClasses(compiler: Global)(code: String): List[ClassNode] = { + compile(compiler)(code).map(p => AsmUtils.readClass(p._2)).sortBy(_.name) + } + + def compileMethods(compiler: Global)(code: String): List[MethodNode] = { + compileClasses(compiler)(s"class C { $code }").head.methods.asScala.toList.filterNot(_.name == "<init>") + } + + def singleMethodInstructions(compiler: Global)(code: String): List[Instruction] = { + val List(m) = compileMethods(compiler)(code) + instructionsFromMethod(m) + } + + def singleMethod(compiler: Global)(code: String): Method = { + val List(m) = compileMethods(compiler)(code) + convertMethod(m) + } + + def assertSameCode(actual: List[Instruction], expected: List[Instruction]): Unit = { + assertTrue(s"\nExpected: $expected\nActual : $actual", actual === expected) + } +} diff --git a/test/junit/scala/tools/nsc/backend/jvm/DirectCompileTest.scala b/test/junit/scala/tools/nsc/backend/jvm/DirectCompileTest.scala new file mode 100644 index 0000000000..2fb5bb8052 --- /dev/null +++ b/test/junit/scala/tools/nsc/backend/jvm/DirectCompileTest.scala @@ -0,0 +1,81 @@ +package scala.tools.nsc.backend.jvm + +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Assert._ +import CodeGenTools._ +import scala.tools.asm.Opcodes._ +import scala.tools.partest.ASMConverters._ + +@RunWith(classOf[JUnit4]) +class DirectCompileTest { + val compiler = newCompiler(extraArgs = "-Ybackend:GenBCode") + + @Test + def testCompile(): Unit = { + val List(("C.class", bytes)) = compile(compiler)( + """ + |class C { + | def f = 1 + |} + """.stripMargin) + def s(i: Int, n: Int) = (bytes(i) & 0xff) << n + assertTrue((s(0, 24) | s(1, 16) | s(2, 8) | s(3, 0)) == 0xcafebabe) // mocha java latte machiatto surpreme dark roasted espresso + } + + @Test + def testCompileClasses(): Unit = { + val List(cClass, cModuleClass) = compileClasses(compiler)( + """ + |class C + |object C + """.stripMargin) + + assertTrue(cClass.name == "C") + assertTrue(cModuleClass.name == "C$") + + val List(dMirror, dModuleClass) = compileClasses(compiler)( + """ + |object D + """.stripMargin) + + assertTrue(dMirror.name == "D") + assertTrue(dModuleClass.name == "D$") + } + + @Test + def testCompileMethods(): Unit = { + val List(f, g) = compileMethods(compiler)( + """ + |def f = 10 + |def g = f + """.stripMargin) + assertTrue(f.name == "f") + assertTrue(g.name == "g") + + assertTrue(instructionsFromMethod(f).dropNonOp === + List(IntOp(BIPUSH, 10), Op(IRETURN))) + + assertTrue(instructionsFromMethod(g).dropNonOp === + List(VarOp(ALOAD, 0), Invoke(INVOKEVIRTUAL, "C", "f", "()I", false), Op(IRETURN))) + } + + @Test + def testDropNonOpAliveLabels(): Unit = { + val List(f) = compileMethods(compiler)("""def f(x: Int) = if (x == 0) "a" else "b"""") + assertTrue(instructionsFromMethod(f).dropNonOp === List( + VarOp(ILOAD, 1), + Op(ICONST_0), + Jump(IF_ICMPEQ, Label(6)), + Jump(GOTO, Label(10)), + Label(6), + Ldc(LDC, "a"), + Jump(GOTO, Label(13)), + Label(10), + Ldc(LDC, "b"), + Label(13), + Op(ARETURN) + )) + } +} diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/EmptyExceptionHandlersTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/EmptyExceptionHandlersTest.scala new file mode 100644 index 0000000000..57fa1a7b66 --- /dev/null +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/EmptyExceptionHandlersTest.scala @@ -0,0 +1,92 @@ +package scala.tools.nsc +package backend.jvm +package opt + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test +import scala.tools.asm.Opcodes._ +import org.junit.Assert._ + +import CodeGenTools._ +import scala.tools.partest.ASMConverters +import ASMConverters._ + +@RunWith(classOf[JUnit4]) +class EmptyExceptionHandlersTest { + + val exceptionDescriptor = "java/lang/Exception" + + @Test + def eliminateEmpty(): Unit = { + val handlers = List(ExceptionHandler(Label(1), Label(2), Label(2), Some(exceptionDescriptor))) + val asmMethod = genMethod(handlers = handlers)( + Label(1), + Label(2), + Op(RETURN) + ) + assertTrue(convertMethod(asmMethod).handlers.length == 1) + LocalOpt.removeEmptyExceptionHandlers(asmMethod) + assertTrue(convertMethod(asmMethod).handlers.isEmpty) + } + + @Test + def eliminateHandlersGuardingNops(): Unit = { + val handlers = List(ExceptionHandler(Label(1), Label(2), Label(2), Some(exceptionDescriptor))) + val asmMethod = genMethod(handlers = handlers)( + Label(1), // nops only + Op(NOP), + Op(NOP), + Jump(GOTO, Label(3)), + Op(NOP), + Label(3), + Op(NOP), + Jump(GOTO, Label(4)), + + Label(2), // handler + Op(ACONST_NULL), + Op(ATHROW), + + Label(4), // return + Op(RETURN) + ) + assertTrue(convertMethod(asmMethod).handlers.length == 1) + LocalOpt.removeEmptyExceptionHandlers(asmMethod) + assertTrue(convertMethod(asmMethod).handlers.isEmpty) + } + + val noOptCompiler = newCompiler(extraArgs = "-Ybackend:GenBCode -Yopt:l:none") + val dceCompiler = newCompiler(extraArgs = "-Ybackend:GenBCode -Yopt:unreachable-code") + + @Test + def eliminateUnreachableHandler(): Unit = { + val code = "def f: Unit = try { } catch { case _: Exception => println(0) }; println(1)" + + assertTrue(singleMethod(noOptCompiler)(code).handlers.length == 1) + val optMethod = singleMethod(dceCompiler)(code) + assertTrue(optMethod.handlers.isEmpty) + + val code2 = + """def f: Unit = { + | println(0) + | return + | try { throw new Exception("") } // removed by dce, so handler will be removed as well + | catch { case _: Exception => println(1) } + | println(2) + |}""".stripMargin + + assertTrue(singleMethod(dceCompiler)(code2).handlers.isEmpty) + } + + @Test + def keepAliveHandlers(): Unit = { + val code = + """def f: Int = { + | println(0) + | try { 1 } + | catch { case _: Exception => 2 } + |}""".stripMargin + + assertTrue(singleMethod(dceCompiler)(code).handlers.length == 1) + } +} diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/UnreachableCodeTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/UnreachableCodeTest.scala new file mode 100644 index 0000000000..a3bd7ae6fe --- /dev/null +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/UnreachableCodeTest.scala @@ -0,0 +1,217 @@ +package scala.tools.nsc +package backend.jvm +package opt + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test +import scala.tools.asm.Opcodes._ +import org.junit.Assert._ + +import scala.tools.testing.AssertUtil._ + +import CodeGenTools._ +import scala.tools.partest.ASMConverters +import ASMConverters._ + +@RunWith(classOf[JUnit4]) +class UnreachableCodeTest { + import UnreachableCodeTest._ + + // jvm-1.6 enables emitting stack map frames, which impacts the code generation wrt dead basic blocks, + // see comment in BCodeBodyBuilder + val dceCompiler = newCompiler(extraArgs = "-target:jvm-1.6 -Ybackend:GenBCode -Yopt:unreachable-code") + val noOptCompiler = newCompiler(extraArgs = "-target:jvm-1.6 -Ybackend:GenBCode -Yopt:l:none") + + // jvm-1.5 disables computing stack map frames, and it emits dead code as-is. + val noOptNoFramesCompiler = newCompiler(extraArgs = "-target:jvm-1.5 -Ybackend:GenBCode -Yopt:l:none") + + @Test + def basicElimination(): Unit = { + assertEliminateDead( + Op(ACONST_NULL), + Op(ATHROW), + Op(RETURN).dead + ) + + assertEliminateDead( + Op(RETURN) + ) + + assertEliminateDead( + Op(RETURN), + Op(ACONST_NULL).dead, + Op(ATHROW).dead + ) + } + + @Test + def eliminateNop(): Unit = { + assertEliminateDead( + // not dead, since visited by data flow analysis. need a different opt to eliminate it. + Op(NOP), + Op(RETURN), + Op(NOP).dead + ) + } + + @Test + def eliminateBranchOver(): Unit = { + assertEliminateDead( + Jump(GOTO, Label(1)), + Op(ACONST_NULL).dead, + Op(ATHROW).dead, + Label(1), + Op(RETURN) + ) + + assertEliminateDead( + Jump(GOTO, Label(1)), + Label(1), + Op(RETURN) + ) + } + + @Test + def deadLabelsRemain(): Unit = { + assertEliminateDead( + Op(RETURN), + Jump(GOTO, Label(1)).dead, + // not dead - labels may be referenced from other places in a classfile (eg exceptions table). + // will need a different opt to get rid of them + Label(1) + ) + } + + @Test + def pushPopNotEliminated(): Unit = { + assertEliminateDead( + // not dead, visited by data flow analysis. + Op(ACONST_NULL), + Op(POP), + Op(RETURN) + ) + } + + @Test + def nullnessNotConsidered(): Unit = { + assertEliminateDead( + Op(ACONST_NULL), + Jump(IFNULL, Label(1)), + Op(RETURN), // not dead + Label(1), + Op(RETURN) + ) + } + + @Test + def basicEliminationCompiler(): Unit = { + val code = "def f: Int = { return 1; 2 }" + val withDce = singleMethodInstructions(dceCompiler)(code) + assertSameCode(withDce.dropNonOp, List(Op(ICONST_1), Op(IRETURN))) + + val noDce = singleMethodInstructions(noOptCompiler)(code) + + // The emitted code is ICONST_1, IRETURN, ICONST_2, IRETURN. The latter two are dead. + // + // GenBCode puts the last IRETURN into a new basic block: it emits a label before the second + // IRETURN. This is an implementation detail, it may change; it affects the outcome of this test. + // + // During classfile writing with COMPUTE_FAMES (-target:jvm-1.6 or larger), the ClassfileWriter + // puts the ICONST_2 into a new basic block, because the preceding operation (IRETURN) ends + // the current block. We get something like + // + // L1: ICONST_1; IRETURN + // L2: ICONST_2 << dead + // L3: IRETURN << dead + // + // Finally, instructions in the dead basic blocks are replaced by ATHROW, as explained in + // a comment in BCodeBodyBuilder. + assertSameCode(noDce.dropNonOp, List(Op(ICONST_1), Op(IRETURN), Op(ATHROW), Op(ATHROW))) + + // when NOT computing stack map frames, ASM's ClassWriter does not replace dead code by NOP/ATHROW + val noDceNoFrames = singleMethodInstructions(noOptNoFramesCompiler)(code) + assertSameCode(noDceNoFrames.dropNonOp, List(Op(ICONST_1), Op(IRETURN), Op(ICONST_2), Op(IRETURN))) + } + + @Test + def eliminateDeadCatchBlocks(): Unit = { + val code = "def f: Int = { return 0; try { 1 } catch { case _: Exception => 2 } }" + assertSameCode(singleMethodInstructions(dceCompiler)(code).dropNonOp, + List(Op(ICONST_0), Op(IRETURN))) + + val code2 = "def f: Unit = { try { } catch { case _: Exception => () }; () }" + // DCE only removes dead basic blocks, but not NOPs, and also not useless jumps + assertSameCode(singleMethodInstructions(dceCompiler)(code2).dropNonOp, + List(Op(NOP), Jump(GOTO, Label(33)), Label(33), Op(RETURN))) + + val code3 = "def f: Unit = { try { } catch { case _: Exception => try { } catch { case _: Exception => () } }; () }" + assertSameCode(singleMethodInstructions(dceCompiler)(code3).dropNonOp, + List(Op(NOP), Jump(GOTO, Label(33)), Label(33), Op(RETURN))) + + val code4 = "def f: Unit = { try { try { } catch { case _: Exception => () } } catch { case _: Exception => () }; () }" + assertSameCode(singleMethodInstructions(dceCompiler)(code4).dropNonOp, + List(Op(NOP), Jump(GOTO, Label(4)), Label(4), Jump(GOTO, Label(7)), Label(7), Op(RETURN))) + } + + @Test // test the dce-testing tools + def metaTest(): Unit = { + assertEliminateDead() // no instructions + + assertThrows[AssertionError]( + assertEliminateDead(Op(RETURN).dead), + _.contains("Expected: List()\nActual : List(Op(RETURN))") + ) + + assertThrows[AssertionError]( + assertEliminateDead(Op(RETURN), Op(RETURN)), + _.contains("Expected: List(Op(RETURN), Op(RETURN))\nActual : List(Op(RETURN))") + ) + } + + @Test + def bytecodeEquivalence: Unit = { + assertTrue(List(VarOp(ILOAD, 1)) === + List(VarOp(ILOAD, 2))) + assertTrue(List(VarOp(ILOAD, 1), VarOp(ISTORE, 1)) === + List(VarOp(ILOAD, 2), VarOp(ISTORE, 2))) + + // the first Op will associate 1->2, then the 2->2 will fail + assertFalse(List(VarOp(ILOAD, 1), VarOp(ISTORE, 2)) === + List(VarOp(ILOAD, 2), VarOp(ISTORE, 2))) + + // will associate 1->2 and 2->1, which is OK + assertTrue(List(VarOp(ILOAD, 1), VarOp(ISTORE, 2)) === + List(VarOp(ILOAD, 2), VarOp(ISTORE, 1))) + + assertTrue(List(Label(1), Label(2), Label(1)) === + List(Label(2), Label(4), Label(2))) + assertTrue(List(LineNumber(1, Label(1)), Label(1)) === + List(LineNumber(1, Label(3)), Label(3))) + assertFalse(List(LineNumber(1, Label(1)), Label(1)) === + List(LineNumber(1, Label(3)), Label(1))) + + assertTrue(List(TableSwitch(TABLESWITCH, 1, 3, Label(4), List(Label(5), Label(6))), Label(4), Label(5), Label(6)) === + List(TableSwitch(TABLESWITCH, 1, 3, Label(9), List(Label(3), Label(4))), Label(9), Label(3), Label(4))) + + assertTrue(List(FrameEntry(F_FULL, List(INTEGER, DOUBLE, Label(3)), List("java/lang/Object", Label(4))), Label(3), Label(4)) === + List(FrameEntry(F_FULL, List(INTEGER, DOUBLE, Label(1)), List("java/lang/Object", Label(3))), Label(1), Label(3))) + } +} + +object UnreachableCodeTest { + import scala.language.implicitConversions + implicit def aliveInstruction(ins: Instruction): (Instruction, Boolean) = (ins, true) + + implicit class MortalInstruction(val ins: Instruction) extends AnyVal { + def dead: (Instruction, Boolean) = (ins, false) + } + + def assertEliminateDead(code: (Instruction, Boolean)*): Unit = { + val cls = wrapInClass(genMethod()(code.map(_._1): _*)) + LocalOpt.removeUnreachableCode(cls) + val nonEliminated = instructionsFromMethod(cls.methods.get(0)) + val expectedLive = code.filter(_._2).map(_._1).toList + assertSameCode(nonEliminated, expectedLive) + } +} diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/UnusedLocalVariablesTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/UnusedLocalVariablesTest.scala new file mode 100644 index 0000000000..24a1f9d1c1 --- /dev/null +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/UnusedLocalVariablesTest.scala @@ -0,0 +1,87 @@ +package scala.tools.nsc +package backend.jvm +package opt + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test +import scala.tools.asm.Opcodes._ +import org.junit.Assert._ +import scala.collection.JavaConverters._ + +import CodeGenTools._ +import scala.tools.partest.ASMConverters +import ASMConverters._ + +@RunWith(classOf[JUnit4]) +class UnusedLocalVariablesTest { + val dceCompiler = newCompiler(extraArgs = "-Ybackend:GenBCode -Yopt:unreachable-code") + + @Test + def removeUnusedVar(): Unit = { + val code = """def f(a: Long, b: String, c: Double): Unit = { return; var x = a; var y = x + 10 }""" + assertLocalVarCount(code, 4) // `this, a, b, c` + + val code2 = """def f(): Unit = { var x = if (true) return else () }""" + assertLocalVarCount(code2, 1) // x is eliminated, constant folding in scalac removes the if + + val code3 = """def f: Unit = return""" // paramless method + assertLocalVarCount(code3, 1) // this + } + + @Test + def keepUsedVar(): Unit = { + val code = """def f(a: Long, b: String, c: Double): Unit = { val x = 10 + a; val y = x + 10 }""" + assertLocalVarCount(code, 6) + + val code2 = """def f(a: Long): Unit = { var x = if (a == 0l) return else () }""" + assertLocalVarCount(code2, 3) // remains + } + + @Test + def constructorLocals(): Unit = { + val code = """class C { + | def this(a: Int) = { + | this() + | throw new Exception("") + | val y = 0 + | } + |} + |""".stripMargin + val cls = compileClasses(dceCompiler)(code).head + val m = convertMethod(cls.methods.asScala.toList.find(_.desc == "(I)V").get) + assertTrue(m.localVars.length == 2) // this, a, but not y + + + val code2 = + """class C { + | { + | throw new Exception("") + | val a = 0 + | } + |} + | + |object C { + | { + | throw new Exception("") + | val b = 1 + | } + |} + """.stripMargin + + val clss2 = compileClasses(dceCompiler)(code2) + val cls2 = clss2.find(_.name == "C").get + val companion2 = clss2.find(_.name == "C$").get + + val clsConstr = convertMethod(cls2.methods.asScala.toList.find(_.name == "<init>").get) + val companionConstr = convertMethod(companion2.methods.asScala.toList.find(_.name == "<init>").get) + + assertTrue(clsConstr.localVars.length == 1) // this + assertTrue(companionConstr.localVars.length == 1) // this + } + + def assertLocalVarCount(code: String, numVars: Int): Unit = { + assertTrue(singleMethod(dceCompiler)(code).localVars.length == numVars) + } + +} diff --git a/test/junit/scala/tools/testing/AssertThrowsTest.scala b/test/junit/scala/tools/testing/AssertThrowsTest.scala index a70519e63c..d91e450bac 100644 --- a/test/junit/scala/tools/testing/AssertThrowsTest.scala +++ b/test/junit/scala/tools/testing/AssertThrowsTest.scala @@ -31,4 +31,13 @@ class AssertThrowsTest { } }) -}
\ No newline at end of file + @Test + def errorIfNoThrow: Unit = { + try { + assertThrows[Foo] { () } + } catch { + case e: AssertionError => return + } + assert(false, "assertThrows should error if the tested expression does not throw anything") + } +} |