From 0a33c421767f6e4587f8adac19169f184d845548 Mon Sep 17 00:00:00 2001 From: Lukas Rytz Date: Fri, 2 Oct 2015 12:27:43 +0200 Subject: Simplify post inlining requests Clean up inliner test --- .../tools/nsc/backend/jvm/opt/CallGraph.scala | 14 +- .../scala/tools/nsc/backend/jvm/opt/Inliner.scala | 36 +++-- .../nsc/backend/jvm/opt/InlinerHeuristics.scala | 6 +- .../scala/tools/nsc/backend/jvm/CodeGenTools.scala | 3 + .../tools/nsc/backend/jvm/opt/CallGraphTest.scala | 13 +- .../tools/nsc/backend/jvm/opt/InlinerTest.scala | 152 +++++---------------- 6 files changed, 74 insertions(+), 150 deletions(-) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala index a8f1e43071..801296908f 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala @@ -67,6 +67,8 @@ class CallGraph[BT <: BTypes](val btypes: BT) { callsites(callsite.callsiteMethod) = methodCallsites + (callsite.callsiteInstruction -> callsite) } + def containsCallsite(callsite: Callsite): Boolean = callsites(callsite.callsiteMethod) contains callsite.callsiteInstruction + def removeClosureInstantiation(indy: InvokeDynamicInsnNode, methodNode: MethodNode): Option[ClosureInstantiation] = { val methodClosureInits = closureInstantiations(methodNode) val newClosureInits = methodClosureInits - indy @@ -130,8 +132,8 @@ class CallGraph[BT <: BTypes](val btypes: BT) { val callee: Either[OptimizerWarning, Callee] = for { (method, declarationClass) <- byteCodeRepository.methodNode(call.owner, call.name, call.desc): Either[OptimizerWarning, (MethodNode, InternalName)] (declarationClassNode, source) <- byteCodeRepository.classNodeAndSource(declarationClass): Either[OptimizerWarning, (ClassNode, Source)] - declarationClassBType = classBTypeFromClassNode(declarationClassNode) } yield { + val declarationClassBType = classBTypeFromClassNode(declarationClassNode) val CallsiteInfo(safeToInline, safeToRewrite, annotatedInline, annotatedNoInline, samParamTypes, warning) = analyzeCallsite(method, declarationClassBType, call.owner, source) Callee( callee = method, @@ -347,6 +349,12 @@ class CallGraph[BT <: BTypes](val btypes: BT) { final case class Callsite(callsiteInstruction: MethodInsnNode, callsiteMethod: MethodNode, callsiteClass: ClassBType, callee: Either[OptimizerWarning, Callee], argInfos: IntMap[ArgInfo], callsiteStackHeight: Int, receiverKnownNotNull: Boolean, callsitePosition: Position) { + /** + * Contains callsites that were created during inlining by cloning this callsite. Used to find + * corresponding callsites when inlining post-inline requests. + */ + val inlinedClones = mutable.Set.empty[Callsite] + override def toString = "Invocation of" + s" ${callee.map(_.calleeDeclarationClass.internalName).getOrElse("?")}.${callsiteInstruction.name + callsiteInstruction.desc}" + @@ -399,6 +407,10 @@ class CallGraph[BT <: BTypes](val btypes: BT) { * graph when re-writing a closure invocation to the body method. */ final case class ClosureInstantiation(lambdaMetaFactoryCall: LambdaMetaFactoryCall, ownerMethod: MethodNode, ownerClass: ClassBType, capturedArgInfos: IntMap[ArgInfo]) { + /** + * Contains closure instantiations that were created during inlining by cloning this instantiation. + */ + val inlinedClones = mutable.Set.empty[ClosureInstantiation] override def toString = s"ClosureInstantiation($lambdaMetaFactoryCall, ${ownerMethod.name + ownerMethod.desc}, $ownerClass)" } final case class LambdaMetaFactoryCall(indy: InvokeDynamicInsnNode, samMethodType: Type, implMethod: Handle, instantiatedMethodType: Type) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala index 9a5341e131..4e1ecea217 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala @@ -272,14 +272,12 @@ class Inliner[BT <: BTypes](val btypes: BT) { def inline(request: InlineRequest): List[CannotInlineWarning] = canInlineBody(request.callsite) match { case Some(w) => List(w) case None => - val instructionsMap = inlineCallsite(request.callsite) + inlineCallsite(request.callsite) val postRequests = request.post.flatMap(post => { - // the post-request invocation instruction might not exist anymore: it might have been - // inlined itself, or eliminated by DCE. - for { - inlinedInvocationInstr <- instructionsMap.get(post.callsiteInstruction).map(_.asInstanceOf[MethodInsnNode]) - inlinedCallsite <- callGraph.callsites(request.callsite.callsiteMethod).get(inlinedInvocationInstr) - } yield InlineRequest(inlinedCallsite, post.post) + post.callsite.inlinedClones.find(cs => cs.callsiteMethod == request.callsite.callsiteMethod) match { + case Some(inlinedPostCallsite) if callGraph.containsCallsite(inlinedPostCallsite) => Some(InlineRequest(inlinedPostCallsite, post.post)) + case _ => None + } }) postRequests flatMap inline } @@ -296,7 +294,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { * @return A map associating instruction nodes of the callee with the corresponding cloned * instruction in the callsite method. */ - def inlineCallsite(callsite: Callsite): Map[AbstractInsnNode, AbstractInsnNode] = { + def inlineCallsite(callsite: Callsite): Unit = { import callsite.{callsiteClass, callsiteMethod, callsiteInstruction, receiverKnownNotNull, callsiteStackHeight} val Right(callsiteCallee) = callsite.callee import callsiteCallee.{callee, calleeDeclarationClass} @@ -451,7 +449,7 @@ class Inliner[BT <: BTypes](val btypes: BT) { callGraph.callsites(callee).valuesIterator foreach { originalCallsite => val newCallsiteIns = instructionMap(originalCallsite.callsiteInstruction).asInstanceOf[MethodInsnNode] val argInfos = originalCallsite.argInfos flatMap mapArgInfo - callGraph.addCallsite(Callsite( + val newCallsite = Callsite( callsiteInstruction = newCallsiteIns, callsiteMethod = callsiteMethod, callsiteClass = callsiteClass, @@ -460,19 +458,21 @@ class Inliner[BT <: BTypes](val btypes: BT) { callsiteStackHeight = callsiteStackHeight + originalCallsite.callsiteStackHeight, receiverKnownNotNull = originalCallsite.receiverKnownNotNull, callsitePosition = originalCallsite.callsitePosition - )) + ) + originalCallsite.inlinedClones += newCallsite + callGraph.addCallsite(newCallsite) } callGraph.closureInstantiations(callee).valuesIterator foreach { originalClosureInit => val newIndy = instructionMap(originalClosureInit.lambdaMetaFactoryCall.indy).asInstanceOf[InvokeDynamicInsnNode] val capturedArgInfos = originalClosureInit.capturedArgInfos flatMap mapArgInfo - callGraph.addClosureInstantiation( - ClosureInstantiation( - originalClosureInit.lambdaMetaFactoryCall.copy(indy = newIndy), - callsiteMethod, - callsiteClass, - capturedArgInfos) - ) + val newClosureInit = ClosureInstantiation( + originalClosureInit.lambdaMetaFactoryCall.copy(indy = newIndy), + callsiteMethod, + callsiteClass, + capturedArgInfos) + originalClosureInit.inlinedClones += newClosureInit + callGraph.addClosureInstantiation(newClosureInit) } // Remove the elided invocation from the call graph @@ -480,8 +480,6 @@ class Inliner[BT <: BTypes](val btypes: BT) { // Inlining a method body can render some code unreachable, see example above (in runInliner). unreachableCodeEliminated -= callsiteMethod - - instructionMap } /** diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/InlinerHeuristics.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/InlinerHeuristics.scala index d8f12ffb11..52627e77e6 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/InlinerHeuristics.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/InlinerHeuristics.scala @@ -19,8 +19,10 @@ class InlinerHeuristics[BT <: BTypes](val bTypes: BT) { import inliner._ import callGraph._ - case class InlineRequest(callsite: Callsite, post: List[PostInlineRequest]) - case class PostInlineRequest(callsiteInstruction: MethodInsnNode, post: List[PostInlineRequest]) + case class InlineRequest(callsite: Callsite, post: List[InlineRequest]) { + // invariant: all post inline requests denote callsites in the callee of the main callsite + for (pr <- post) assert(pr.callsite.callsiteMethod == callsite.callee.get.callee, s"Callsite method mismatch: main $callsite - post ${pr.callsite}") + } /** * Select callsites from the call graph that should be inlined, grouped by the containing method. diff --git a/test/junit/scala/tools/nsc/backend/jvm/CodeGenTools.scala b/test/junit/scala/tools/nsc/backend/jvm/CodeGenTools.scala index 769236ae49..1f2ec274d3 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/CodeGenTools.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/CodeGenTools.scala @@ -169,6 +169,9 @@ object CodeGenTools { def getSingleMethod(classNode: ClassNode, name: String): Method = convertMethod(classNode.methods.asScala.toList.find(_.name == name).get) + def findAsmMethods(c: ClassNode, p: String => Boolean) = c.methods.iterator.asScala.filter(m => p(m.name)).toList.sortBy(_.name) + def findAsmMethod(c: ClassNode, name: String) = findAsmMethods(c, _ == name).head + /** * Instructions that match `query` when textified. * If `query` starts with a `+`, the next instruction is returned. diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala index efd88f10c3..f329a43b30 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala @@ -50,9 +50,6 @@ class CallGraphTest extends ClearAfterClass { compileClasses(compiler)(code, allowMessage = allowMessage).map(c => byteCodeRepository.classNode(c.name).get) } - def getMethods(c: ClassNode, p: String => Boolean) = c.methods.iterator.asScala.filter(m => p(m.name)).toList.sortBy(_.name) - def getMethod(c: ClassNode, name: String) = getMethods(c, _ == name).head - def callsInMethod(methodNode: MethodNode): List[MethodInsnNode] = methodNode.instructions.iterator.asScala.collect({ case call: MethodInsnNode => call }).toList @@ -121,10 +118,10 @@ class CallGraphTest extends ClearAfterClass { val List(cCls, cMod, dCls, testCls) = compile(code, checkMsg) assert(msgCount == 6, msgCount) - val List(cf1, cf2, cf3, cf4, cf5, cf6, cf7) = getMethods(cCls, _.startsWith("f")) - val List(df1, df3) = getMethods(dCls, _.startsWith("f")) - val g1 = getMethod(cMod, "g1") - val List(t1, t2) = getMethods(testCls, _.startsWith("t")) + val List(cf1, cf2, cf3, cf4, cf5, cf6, cf7) = findAsmMethods(cCls, _.startsWith("f")) + val List(df1, df3) = findAsmMethods(dCls, _.startsWith("f")) + val g1 = findAsmMethod(cMod, "g1") + val List(t1, t2) = findAsmMethods(testCls, _.startsWith("t")) val List(cf1Call, cf2Call, cf3Call, cf4Call, cf5Call, cf6Call, cf7Call, cg1Call) = callsInMethod(t1) val List(df1Call, df2Call, df3Call, df4Call, df5Call, df6Call, df7Call, dg1Call) = callsInMethod(t2) @@ -160,7 +157,7 @@ class CallGraphTest extends ClearAfterClass { |} """.stripMargin val List(c) = compile(code) - val m = getMethod(c, "m") + val m = findAsmMethod(c, "m") val List(fn) = callsInMethod(m) val forNameMeth = byteCodeRepository.methodNode("java/lang/Class", "forName", "(Ljava/lang/String;)Ljava/lang/Class;").get._1 val classTp = classBTypeFromInternalName("java/lang/Class") diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala index 81c6dd2ce2..cdba4073f2 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/InlinerTest.scala @@ -6,17 +6,11 @@ import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.junit.Test import scala.collection.generic.Clearable -import scala.collection.immutable.IntMap -import scala.collection.mutable.ListBuffer -import scala.reflect.internal.util.{NoPosition, BatchSourceFile} import scala.tools.asm.Opcodes._ import org.junit.Assert._ import scala.tools.asm.tree._ -import scala.tools.asm.tree.analysis._ -import scala.tools.nsc.io._ import scala.tools.nsc.reporters.StoreReporter -import scala.tools.testing.AssertUtil._ import CodeGenTools._ import scala.tools.partest.ASMConverters @@ -69,6 +63,7 @@ class InlinerTest extends ClearAfterClass { val compiler = InlinerTest.compiler import compiler.genBCode.bTypes._ import compiler.genBCode.bTypes.backendUtils._ + import inlinerHeuristics._ def compile(scalaCode: String, javaCode: List[(String, String)] = Nil, allowMessage: StoreReporter#Info => Boolean = _ => false): List[ClassNode] = { InlinerTest.notPerRun.foreach(_.clear()) @@ -88,55 +83,27 @@ class InlinerTest extends ClearAfterClass { assert(callsite.callee.get.callee == callee, callsite.callee.get.callee.name) } - def makeInlineRequest( callsiteInstruction: MethodInsnNode, callsiteMethod: MethodNode, callsiteClass: ClassBType, - callee: MethodNode, calleeDeclarationClass: ClassBType, - callsiteStackHeight: Int, receiverKnownNotNull: Boolean, - post: List[inlinerHeuristics.PostInlineRequest] = Nil) = inlinerHeuristics.InlineRequest( - callsite = callGraph.Callsite( - callsiteInstruction = callsiteInstruction, - callsiteMethod = callsiteMethod, - callsiteClass = callsiteClass, - callee = Right(callGraph.Callee(callee = callee, calleeDeclarationClass = calleeDeclarationClass, safeToInline = true, safeToRewrite = false, annotatedInline = false, annotatedNoInline = false, samParamTypes = IntMap.empty, calleeInfoWarning = None)), - argInfos = IntMap.empty, - callsiteStackHeight = callsiteStackHeight, - receiverKnownNotNull = receiverKnownNotNull, - callsitePosition = NoPosition), - post = post) - - def inlineRequest(code: String, mod: ClassNode => Unit = _ => ()): (inlinerHeuristics.InlineRequest, MethodNode) = { - val List(cls) = compile(code) - mod(cls) - val clsBType = classBTypeFromParsedClassfile(cls.name) - - val List(f, g) = cls.methods.asScala.filter(m => Set("f", "g")(m.name)).toList.sortBy(_.name) - val fCall = g.instructions.iterator.asScala.collect({ case i: MethodInsnNode if i.name == "f" => i }).next() - - val analyzer = new AsmAnalyzer(g, clsBType.internalName) - - val request = makeInlineRequest( - callsiteInstruction = fCall, - callsiteMethod = g, - callsiteClass = clsBType, - callee = f, - calleeDeclarationClass = clsBType, - callsiteStackHeight = analyzer.frameAt(fCall).getStackSize, - receiverKnownNotNull = true - ) - (request, g) - } + def getCallsite(method: MethodNode, calleeName: String) = callGraph.callsites(method).valuesIterator.find(_.callee.get.callee.name == calleeName).get - // inline first invocation of f into g in class C - def inlineTest(code: String, mod: ClassNode => Unit = _ => ()): MethodNode = { - val (request, g) = inlineRequest(code, mod) - inliner.inline(request) - g + def gMethAndFCallsite(code: String, mod: ClassNode => Unit = _ => ()) = { + val List(c) = compile(code) + mod(c) + val gMethod = findAsmMethod(c, "g") + val fCall = getCallsite(gMethod, "f") + (gMethod, fCall) } def canInlineTest(code: String, mod: ClassNode => Unit = _ => ()): Option[OptimizerWarning] = { - val cs = inlineRequest(code, mod)._1.callsite + val cs = gMethAndFCallsite(code, mod)._2 inliner.earlyCanInlineCheck(cs) orElse inliner.canInlineBody(cs) } + def inlineTest(code: String, mod: ClassNode => Unit = _ => ()): MethodNode = { + val (gMethod, fCall) = gMethAndFCallsite(code, mod) + inliner.inline(InlineRequest(fCall, Nil)) + gMethod + } + @Test def simpleInlineOK(): Unit = { val code = @@ -251,29 +218,9 @@ class InlinerTest extends ClearAfterClass { """.stripMargin val List(c, d) = compile(code) - - val cTp = classBTypeFromParsedClassfile(c.name) - val dTp = classBTypeFromParsedClassfile(d.name) - - val g = c.methods.asScala.find(_.name == "g").get - val h = d.methods.asScala.find(_.name == "h").get - val gCall = h.instructions.iterator.asScala.collect({ - case m: MethodInsnNode if m.name == "g" => m - }).next() - - val analyzer = new AsmAnalyzer(h, dTp.internalName) - - val request = makeInlineRequest( - callsiteInstruction = gCall, - callsiteMethod = h, - callsiteClass = dTp, - callee = g, - calleeDeclarationClass = cTp, - callsiteStackHeight = analyzer.frameAt(gCall).getStackSize, - receiverKnownNotNull = true - ) - - val r = inliner.canInlineBody(request.callsite) + val hMeth = findAsmMethod(d, "h") + val gCall = getCallsite(hMeth, "g") + val r = inliner.canInlineBody(gCall) assert(r.nonEmpty && r.get.isInstanceOf[IllegalAccessInstruction], r) } @@ -411,28 +358,14 @@ class InlinerTest extends ClearAfterClass { """.stripMargin val List(c) = compile(code) - val f = c.methods.asScala.find(_.name == "f").get - val callsiteIns = f.instructions.iterator().asScala.collect({ case c: MethodInsnNode => c }).next() - val clsBType = classBTypeFromParsedClassfile(c.name) - val analyzer = new AsmAnalyzer(f, clsBType.internalName) - - val integerClassBType = classBTypeFromInternalName("java/lang/Integer") - val lowestOneBitMethod = byteCodeRepository.methodNode(integerClassBType.internalName, "lowestOneBit", "(I)I").get._1 - - val request = makeInlineRequest( - callsiteInstruction = callsiteIns, - callsiteMethod = f, - callsiteClass = clsBType, - callee = lowestOneBitMethod, - calleeDeclarationClass = integerClassBType, - callsiteStackHeight = analyzer.frameAt(callsiteIns).getStackSize, - receiverKnownNotNull = false - ) - - val warning = inliner.canInlineBody(request.callsite) + val fMeth = findAsmMethod(c, "f") + val call = getCallsite(fMeth, "lowestOneBit") + + val warning = inliner.canInlineBody(call) assert(warning.isEmpty, warning) - inliner.inline(request) - val ins = instructionsFromMethod(f) + + inliner.inline(InlineRequest(call, Nil)) + val ins = instructionsFromMethod(fMeth) // no invocations, lowestOneBit is inlined assertNoInvoke(ins) @@ -1078,36 +1011,15 @@ class InlinerTest extends ClearAfterClass { """.stripMargin val List(c) = compile(code) + val hMeth = findAsmMethod(c, "h") + val gMeth = findAsmMethod(c, "g") + val gCall = getCallsite(hMeth, "g") + val fCall = getCallsite(gMeth, "f") - val cTp = classBTypeFromParsedClassfile(c.name) - - val f = c.methods.asScala.find(_.name == "f").get - val g = c.methods.asScala.find(_.name == "g").get - val h = c.methods.asScala.find(_.name == "h").get - - val gCall = h.instructions.iterator.asScala.collect({ - case m: MethodInsnNode if m.name == "g" => m - }).next() - val fCall = g.instructions.iterator.asScala.collect({ - case m: MethodInsnNode if m.name == "f" => m - }).next() - - val analyzer = new AsmAnalyzer(h, cTp.internalName) - - val request = makeInlineRequest( - callsiteInstruction = gCall, - callsiteMethod = h, - callsiteClass = cTp, - callee = g, - calleeDeclarationClass = cTp, - callsiteStackHeight = analyzer.frameAt(gCall).getStackSize, - receiverKnownNotNull = false, - post = List(inlinerHeuristics.PostInlineRequest(fCall, Nil)) - ) - - val warning = inliner.canInlineBody(request.callsite) + val warning = inliner.canInlineBody(gCall) assert(warning.isEmpty, warning) - inliner.inline(request) + + inliner.inline(InlineRequest(gCall, List(InlineRequest(fCall, Nil)))) assertNoInvoke(getSingleMethod(c, "h")) // no invoke in h: first g is inlined, then the inlined call to f is also inlined assertInvoke(getSingleMethod(c, "g"), "C", "f") // g itself still has the call to f } -- cgit v1.2.3