diff options
7 files changed, 74 insertions, 21 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/AsmUtils.scala b/src/compiler/scala/tools/nsc/backend/jvm/AsmUtils.scala index d3f09217cd..0df1b2029d 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/AsmUtils.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/AsmUtils.scala @@ -7,8 +7,8 @@ package scala.tools.nsc.backend.jvm import scala.tools.asm.tree.{InsnList, AbstractInsnNode, ClassNode, MethodNode} import java.io.{StringWriter, PrintWriter} -import scala.tools.asm.util.{TraceClassVisitor, TraceMethodVisitor, Textifier} -import scala.tools.asm.{Attribute, ClassReader} +import scala.tools.asm.util.{CheckClassAdapter, TraceClassVisitor, TraceMethodVisitor, Textifier} +import scala.tools.asm.{ClassWriter, Attribute, ClassReader} import scala.collection.convert.decorateAsScala._ import scala.tools.nsc.backend.jvm.opt.InlineInfoAttributePrototype @@ -106,4 +106,18 @@ object AsmUtils { * Returns a human-readable representation of the given instruction sequence. */ def textify(insns: InsnList): String = textify(insns.iterator().asScala) + + /** + * Run ASM's CheckClassAdapter over a class. Returns None if no problem is found, otherwise + * Some(msg) with the verifier's error message. + */ + def checkClass(classNode: ClassNode): Option[String] = { + val cw = new ClassWriter(ClassWriter.COMPUTE_MAXS) + classNode.accept(cw) + val sw = new StringWriter() + val pw = new PrintWriter(sw) + CheckClassAdapter.verify(new ClassReader(cw.toByteArray), false, pw) + val res = sw.toString + if (res.isEmpty) None else Some(res) + } } diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala index 51a17b7fe4..872d1cc522 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BTypes.scala @@ -16,11 +16,11 @@ import opt.OptimizerReporting._ import scala.collection.convert.decorateAsScala._ /** - * The BTypes component defines The BType class hierarchy. BTypes encapsulate all type information + * The BTypes component defines The BType class hierarchy. A BType stores all type information * that is required after building the ASM nodes. This includes optimizations, generation of * InnerClass attributes and generation of stack map frames. * - * This representation is immutable and independent of the compiler data structures, hence it can + * The representation is immutable and independent of the compiler data structures, hence it can * be queried by concurrent threads. */ abstract class BTypes { diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/BytecodeUtils.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/BytecodeUtils.scala index e221eef636..d2658bcd2a 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/BytecodeUtils.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/BytecodeUtils.scala @@ -296,10 +296,13 @@ object BytecodeUtils { )).toList } - class BasicAnalyzer(methodNode: MethodNode, classInternalName: InternalName) { - val analyzer = new Analyzer(new BasicInterpreter) + /** + * A wrapper to make ASM's Analyzer a bit easier to use. + */ + class AsmAnalyzer[V <: Value](methodNode: MethodNode, classInternalName: InternalName, interpreter: Interpreter[V] = new BasicInterpreter) { + val analyzer = new Analyzer(interpreter) analyzer.analyze(classInternalName, methodNode) - def frameAt(instruction: AbstractInsnNode): Frame[BasicValue] = analyzer.getFrames()(methodNode.instructions.indexOf(instruction)) + def frameAt(instruction: AbstractInsnNode): Frame[V] = analyzer.getFrames()(methodNode.instructions.indexOf(instruction)) } implicit class `frame extensions`[V <: Value](val frame: Frame[V]) extends AnyVal { diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala index 020db738e8..18b95184e5 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala @@ -10,7 +10,7 @@ package opt import scala.tools.asm.tree._ import scala.collection.convert.decorateAsScala._ import scala.tools.nsc.backend.jvm.BTypes.InternalName -import scala.tools.nsc.backend.jvm.opt.BytecodeUtils.BasicAnalyzer +import scala.tools.nsc.backend.jvm.opt.BytecodeUtils.AsmAnalyzer import ByteCodeRepository.{Source, CompilationUnit} class CallGraph[BT <: BTypes](val btypes: BT) { @@ -68,7 +68,7 @@ class CallGraph[BT <: BTypes](val btypes: BT) { // TODO: for now we run a basic analyzer to get the stack height at the call site. // once we run a more elaborate analyzer (types, nullness), we can get the stack height out of there. - val analyzer = new BasicAnalyzer(methodNode, definingClass.internalName) + val analyzer = new AsmAnalyzer(methodNode, definingClass.internalName) methodNode.instructions.iterator.asScala.collect({ case call: MethodInsnNode => diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala index 970cc6803a..b2459862ea 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala @@ -17,6 +17,7 @@ import AsmUtils._ import BytecodeUtils._ import OptimizerReporting._ import collection.mutable +import scala.tools.asm.tree.analysis.{SourceInterpreter, Analyzer} class Inliner[BT <: BTypes](val btypes: BT) { import btypes._ @@ -92,19 +93,38 @@ class Inliner[BT <: BTypes](val btypes: BT) { val traitMethodArgumentTypes = asm.Type.getArgumentTypes(callee.desc) - val selfParamTypeName = calleeDeclarationClass.info.inlineInfo.traitImplClassSelfType.getOrElse(calleeDeclarationClass.internalName) - val selfParamType = asm.Type.getObjectType(selfParamTypeName) + val selfParamType = calleeDeclarationClass.info.inlineInfo.traitImplClassSelfType match { + case Some(internalName) => classBTypeFromParsedClassfile(internalName) + case None => Some(calleeDeclarationClass) + } - val implClassMethodDescriptor = asm.Type.getMethodDescriptor(asm.Type.getReturnType(callee.desc), selfParamType +: traitMethodArgumentTypes: _*) val implClassInternalName = calleeDeclarationClass.internalName + "$class" // The rewrite reading the implementation class and the implementation method from the bytecode // repository. If either of the two fails, the rewrite is not performed. for { - // TODO: inline warnings if impl class or method cannot be found - (implClassMethod, _) <- byteCodeRepository.methodNode(implClassInternalName, callee.name, implClassMethodDescriptor) - implClassBType <- classBTypeFromParsedClassfile(implClassInternalName) + // TODO: inline warnings if selfClassType, impl class or impl method cannot be found + selfType <- selfParamType + implClassMethodDescriptor = asm.Type.getMethodDescriptor(asm.Type.getReturnType(callee.desc), selfType.toASMType +: traitMethodArgumentTypes: _*) + (implClassMethod, _) <- byteCodeRepository.methodNode(implClassInternalName, callee.name, implClassMethodDescriptor) + implClassBType <- classBTypeFromParsedClassfile(implClassInternalName) } yield { + + // The self parameter type may be incompatible with the trait type. + // trait T { self: S => def foo = 1 } + // The $self parameter type of T$class.foo is S, which may be unrelated to T. If we re-write + // a call to T.foo to T$class.foo, we need to cast the receiver to S, otherwise we get a + // VerifyError. We run a `SourceInterpreter` to find all producer instructions of the + // receiver value and add a cast to the self type after each. + if (!calleeDeclarationClass.isSubtypeOf(selfType)) { + val analyzer = new AsmAnalyzer(callsite.callsiteMethod, callsite.callsiteClass.internalName, new SourceInterpreter) + val receiverValue = analyzer.frameAt(callsite.callsiteInstruction).peekDown(traitMethodArgumentTypes.length) + for (i <- receiverValue.insns.asScala) { + val cast = new TypeInsnNode(CHECKCAST, selfType.internalName) + callsite.callsiteMethod.instructions.insert(i, cast) + } + } + val newCallsiteInstruction = new MethodInsnNode(INVOKESTATIC, implClassInternalName, callee.name, implClassMethodDescriptor, false) callsite.callsiteMethod.instructions.insert(callsite.callsiteInstruction, newCallsiteInstruction) callsite.callsiteMethod.instructions.remove(callsite.callsiteInstruction) @@ -291,7 +311,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { // We run an interpreter to know the stack height at each xRETURN instruction and the sizes // of the values on the stack. - val analyzer = new BasicAnalyzer(callee, calleeDeclarationClass.internalName) + val analyzer = new AsmAnalyzer(callee, calleeDeclarationClass.internalName) for (originalReturn <- callee.instructions.iterator().asScala if isReturn(originalReturn)) { val frame = analyzer.frameAt(originalReturn) diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala index 16f09db189..d7344ae61f 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala @@ -11,7 +11,6 @@ import org.junit.Assert._ import scala.tools.asm.tree._ import scala.tools.asm.tree.analysis._ -import scala.tools.nsc.backend.jvm.opt.BytecodeUtils.BasicAnalyzer import scala.tools.testing.AssertUtil._ import CodeGenTools._ diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala index 694dff8dee..7f58f77b15 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala @@ -13,7 +13,7 @@ import org.junit.Assert._ import scala.tools.asm.tree._ import scala.tools.asm.tree.analysis._ -import scala.tools.nsc.backend.jvm.opt.BytecodeUtils.BasicAnalyzer +import scala.tools.nsc.backend.jvm.opt.BytecodeUtils.AsmAnalyzer import scala.tools.nsc.io._ import scala.tools.testing.AssertUtil._ @@ -84,7 +84,7 @@ class InlinerTest extends ClearAfterClass { val List(f, g) = cls.methods.asScala.filter(m => Set("f", "g")(m.name)).toList.sortBy(_.name) val fCall = g.instructions.iterator.asScala.collect({ case i: MethodInsnNode if i.name == "f" => i }).next() - val analyzer = new BasicAnalyzer(g, clsBType.internalName) + val analyzer = new AsmAnalyzer(g, clsBType.internalName) val r = inliner.inline( fCall, @@ -222,7 +222,7 @@ class InlinerTest extends ClearAfterClass { case m: MethodInsnNode if m.name == "g" => m }).next() - val analyzer = new BasicAnalyzer(h, dTp.internalName) + val analyzer = new AsmAnalyzer(h, dTp.internalName) val r = inliner.inline( gCall, @@ -374,7 +374,7 @@ class InlinerTest extends ClearAfterClass { val f = c.methods.asScala.find(_.name == "f").get val callsiteIns = f.instructions.iterator().asScala.collect({ case c: MethodInsnNode => c }).next() val clsBType = classBTypeFromParsedClassfile(c.name).get - val analyzer = new BasicAnalyzer(f, clsBType.internalName) + val analyzer = new AsmAnalyzer(f, clsBType.internalName) val integerClassBType = classBTypeFromInternalName("java/lang/Integer") val lowestOneBitMethod = byteCodeRepository.methodNode(integerClassBType.internalName, "lowestOneBit", "(I)I").get._1 @@ -720,4 +720,21 @@ class InlinerTest extends ClearAfterClass { assertNoInvoke(getSingleMethod(d, "m")) assertNoInvoke(getSingleMethod(c, "m")) } + + @Test + def inlineTraitCastReceiverToSelf(): Unit = { + val code = + """class C { def foo(x: Int) = x } + |trait T { self: C => + | @inline final def f(x: Int) = foo(x) + | def t1 = f(1) + | def t2(t: T) = t.f(2) + |} + """.stripMargin + val List(c, t, tc) = compile(code) + val t1 = getSingleMethod(tc, "t1") + val t2 = getSingleMethod(tc, "t2") + val cast = TypeOp(CHECKCAST, "C") + Set(t1, t2).foreach(m => assert(m.instructions.contains(cast), m.instructions)) + } } |