diff options
4 files changed, 98 insertions, 49 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala index 6442b81721..5eb4d033e6 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/CallGraph.scala @@ -136,7 +136,7 @@ class CallGraph[BT <: BTypes](val btypes: BT) { calleeInfoWarning = warning) } - val argInfos = computeArgInfos(callee, call, methodNode, definingClass, Some(prodCons)) + val argInfos = computeArgInfos(callee, call, prodCons) val receiverNotNull = call.getOpcode == Opcodes.INVOKESTATIC || { val numArgs = Type.getArgumentTypes(call.desc).length @@ -155,10 +155,13 @@ class CallGraph[BT <: BTypes](val btypes: BT) { ) case LambdaMetaFactoryCall(indy, samMethodType, implMethod, instantiatedMethodType) => + val lmf = LambdaMetaFactoryCall(indy, samMethodType, implMethod, instantiatedMethodType) + val capturedArgInfos = computeCapturedArgInfos(lmf, prodCons) methodClosureInstantiations += indy -> ClosureInstantiation( - LambdaMetaFactoryCall(indy, samMethodType, implMethod, instantiatedMethodType), + lmf, methodNode, - definingClass) + definingClass, + capturedArgInfos) case _ => } @@ -166,48 +169,47 @@ class CallGraph[BT <: BTypes](val btypes: BT) { callsites(methodNode) = methodCallsites closureInstantiations(methodNode) = methodClosureInstantiations } - - def computeArgInfos( - callee: Either[OptimizerWarning, Callee], - callsiteInsn: MethodInsnNode, callsiteMethod: MethodNode, callsiteClass: ClassBType, - methodProdCons: => Option[ProdConsAnalyzer] = None): IntMap[ArgInfo] = { + + def computeArgInfos(callee: Either[OptimizerWarning, Callee], callsiteInsn: MethodInsnNode, prodCons: => ProdConsAnalyzer): IntMap[ArgInfo] = { if (callee.isLeft) IntMap.empty else { - if (callee.get.samParamTypes.nonEmpty) { - - val prodCons = methodProdCons.getOrElse({ - localOpt.minimalRemoveUnreachableCode(callsiteMethod, callsiteClass.internalName) - new ProdConsAnalyzer(callsiteMethod, callsiteClass.internalName) - }) - - // TODO: use type analysis instead - should be more efficient than prodCons - // some random thoughts: - // - assign special types to parameters and indy-lambda-functions to track them - // - upcast should not change type flow analysis: don't lose information. - // - can we do something about factory calls? Foo(x) for case class foo gives a Foo. - // inline the factory? analysis across method boundry? - - lazy val callFrame = prodCons.frameAt(callsiteInsn) - val receiverOrFirstArgSlot = { - val numArgs = Type.getArgumentTypes(callsiteInsn.desc).length + (if (callsiteInsn.getOpcode == Opcodes.INVOKESTATIC) 0 else 1) - callFrame.stackTop - numArgs + 1 - } - callee.get.samParamTypes flatMap { - case (index, paramType) => - val prods = prodCons.initialProducersForValueAt(callsiteInsn, receiverOrFirstArgSlot + index) - if (prods.size != 1) None - else { - val argInfo = prods.head match { - case LambdaMetaFactoryCall(_, _, _, _) => Some(FunctionLiteral) - case ParameterProducer(local) => Some(ForwardedParam(local)) - case _ => None - } - argInfo.map((index, _)) - } + lazy val numArgs = Type.getArgumentTypes(callsiteInsn.desc).length + (if (callsiteInsn.getOpcode == Opcodes.INVOKESTATIC) 0 else 1) + argInfosForSams(callee.get.samParamTypes, callsiteInsn, numArgs, prodCons) + } + } + + def computeCapturedArgInfos(lmf: LambdaMetaFactoryCall, prodCons: => ProdConsAnalyzer): IntMap[ArgInfo] = { + val capturedSams = capturedSamTypes(lmf) + val numCaptures = Type.getArgumentTypes(lmf.indy.desc).length + argInfosForSams(capturedSams, lmf.indy, numCaptures, prodCons) + } + + private def argInfosForSams(sams: IntMap[ClassBType], consumerInsn: AbstractInsnNode, numConsumed: => Int, prodCons: => ProdConsAnalyzer): IntMap[ArgInfo] = { + // TODO: use type analysis instead of ProdCons - should be more efficient + // some random thoughts: + // - assign special types to parameters and indy-lambda-functions to track them + // - upcast should not change type flow analysis: don't lose information. + // - can we do something about factory calls? Foo(x) for case class foo gives a Foo. + // inline the factory? analysis across method boundary? + + // assign to a lazy val to prevent repeated evaluation of the by-name arg + lazy val prodConsI = prodCons + lazy val firstConsumedSlot = { + val consumerFrame = prodConsI.frameAt(consumerInsn) + consumerFrame.stackTop - numConsumed + 1 + } + sams flatMap { + case (index, _) => + val prods = prodConsI.initialProducersForValueAt(consumerInsn, firstConsumedSlot + index) + if (prods.size != 1) None + else { + val argInfo = prods.head match { + case LambdaMetaFactoryCall(_, _, _, _) => Some(FunctionLiteral) + case ParameterProducer(local) => Some(ForwardedParam(local)) + case _ => None + } + argInfo.map((index, _)) } - } else { - IntMap.empty - } } } @@ -377,7 +379,16 @@ class CallGraph[BT <: BTypes](val btypes: BT) { override def toString = s"Callee($calleeDeclarationClass.${callee.name})" } - final case class ClosureInstantiation(lambdaMetaFactoryCall: LambdaMetaFactoryCall, ownerMethod: MethodNode, ownerClass: ClassBType) { + /** + * Metadata about a closure instantiation, stored in the call graph + * + * @param lambdaMetaFactoryCall the InvokeDynamic instruction + * @param ownerMethod the method where the closure is allocated + * @param ownerClass the class containing the above method + * @param capturedArgInfos information about captured arguments. Used for updating the call + * graph when re-writing a closure invocation to the body method. + */ + final case class ClosureInstantiation(lambdaMetaFactoryCall: LambdaMetaFactoryCall, ownerMethod: MethodNode, ownerClass: ClassBType, capturedArgInfos: IntMap[ArgInfo]) { override def toString = s"ClosureInstantiation($lambdaMetaFactoryCall, ${ownerMethod.name + ownerMethod.desc}, $ownerClass)" } final case class LambdaMetaFactoryCall(indy: InvokeDynamicInsnNode, samMethodType: Type, implMethod: Handle, instantiatedMethodType: Type) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala index f211c00c80..30b7f2edad 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala @@ -204,8 +204,8 @@ class ClosureOptimizer[BT <: BTypes](val btypes: BT) { insertLoadOps(invocation, ownerMethod, argumentLocalsList) // update maxStack - val capturesStackSize = localsForCapturedValues.size - val invocationStackHeight = stackHeight + capturesStackSize - 1 // -1 because the closure is gone + val numCapturedValues = localsForCapturedValues.locals.length // not `localsForCapturedValues.size`: every value takes 1 slot on the stack (also long / double), JVMS 2.6.2 + val invocationStackHeight = stackHeight + numCapturedValues - 1 // -1 because the closure is gone if (invocationStackHeight > ownerMethod.maxStack) ownerMethod.maxStack = invocationStackHeight @@ -249,12 +249,15 @@ class ClosureOptimizer[BT <: BTypes](val btypes: BT) { samParamTypes = callGraph.samParamTypes(bodyMethodNode, bodyDeclClassType), calleeInfoWarning = None) }) + val argInfos = closureInit.capturedArgInfos ++ originalCallsite.map(cs => cs.argInfos map { + case (index, info) => (index + numCapturedValues, info) + }).getOrElse(IntMap.empty) val bodyMethodCallsite = Callsite( callsiteInstruction = bodyInvocation, callsiteMethod = ownerMethod, callsiteClass = closureInit.ownerClass, callee = callee, - argInfos = computeArgInfos(callee, bodyInvocation, ownerMethod, closureInit.ownerClass), + argInfos = argInfos, callsiteStackHeight = invocationStackHeight, receiverKnownNotNull = true, // see below (*) callsitePosition = originalCallsite.map(_.callsitePosition).getOrElse(NoPosition) diff --git a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala index 1b017bb0da..1550f942c7 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/opt/Inliner.scala @@ -410,20 +410,26 @@ class Inliner[BT <: BTypes](val btypes: BT) { callsiteMethod.tryCatchBlocks.addAll(cloneTryCatchBlockNodes(callee, labelsMap).asJava) callsiteMethod.maxLocals += returnType.getSize + callee.maxLocals - val numStoredArgs = calleeParamTypes.length + (if (isStaticMethod(callee)) 0 else 1) + val numStoredArgs = calleeParamTypes.length + (if (isStaticMethod(callee)) 0 else 1) // every value takes 1 slot on the stack (also long / double), JVMS 2.6.2 callsiteMethod.maxStack = math.max(callsiteMethod.maxStack, callee.maxStack + callsiteStackHeight - numStoredArgs) callGraph.addIfMissing(callee, calleeDeclarationClass) + def mapArgInfo(argInfo: (Int, ArgInfo)): Option[(Int, ArgInfo)] = argInfo match { + case lit @ (_, FunctionLiteral) => Some(lit) + case (argIndex, ForwardedParam(paramIndex)) => callsite.argInfos.get(paramIndex).map((argIndex, _)) + } + // Add all invocation instructions and closure instantiations that were inlined to the call graph callGraph.callsites(callee).valuesIterator foreach { originalCallsite => val newCallsiteIns = instructionMap(originalCallsite.callsiteInstruction).asInstanceOf[MethodInsnNode] + val argInfos = originalCallsite.argInfos flatMap mapArgInfo callGraph.addCallsite(Callsite( callsiteInstruction = newCallsiteIns, callsiteMethod = callsiteMethod, callsiteClass = callsiteClass, callee = originalCallsite.callee, - argInfos = computeArgInfos(originalCallsite.callee, newCallsiteIns, callsiteMethod, callsiteClass), // TODO: try to re-build argInfos from the original callsite's + argInfos = argInfos, callsiteStackHeight = callsiteStackHeight + originalCallsite.callsiteStackHeight, receiverKnownNotNull = originalCallsite.receiverKnownNotNull, callsitePosition = originalCallsite.callsitePosition @@ -432,8 +438,13 @@ class Inliner[BT <: BTypes](val btypes: BT) { callGraph.closureInstantiations(callee).valuesIterator foreach { originalClosureInit => val newIndy = instructionMap(originalClosureInit.lambdaMetaFactoryCall.indy).asInstanceOf[InvokeDynamicInsnNode] + val capturedArgInfos = originalClosureInit.capturedArgInfos flatMap mapArgInfo callGraph.addClosureInstantiation( - ClosureInstantiation(originalClosureInit.lambdaMetaFactoryCall.copy(indy = newIndy), callsiteMethod, callsiteClass) + ClosureInstantiation( + originalClosureInit.lambdaMetaFactoryCall.copy(indy = newIndy), + callsiteMethod, + callsiteClass, + capturedArgInfos) ) } diff --git a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala index b518cbdc50..995e008912 100644 --- a/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala +++ b/test/junit/scala/tools/nsc/backend/jvm/opt/CallGraphTest.scala @@ -185,4 +185,28 @@ class CallGraphTest extends ClearAfterClass { val selfSamCall = callIn("selfSamCall") assertEquals(selfSamCall.argInfos.toList, List((0,ForwardedParam(0)))) } + + @Test + def argInfoAfterInlining(): Unit = { + val code = + """class C { + | def foo(f: Int => Int) = f(1) // not inlined + | @inline final def bar(g: Int => Int) = foo(g) // forwarded param 1 + | @inline final def baz = foo(x => x + 1) // literal + | + | def t1 = bar(x => x + 1) // call to foo should have argInfo literal + | def t2(x: Int, f: Int => Int) = x + bar(f) // call to foo should have argInfo forwarded param 2 + | def t3 = baz // call to foo should have argInfo literal + | def someFun: Int => Int = null + | def t4(x: Int) = x + bar(someFun) // call to foo has empty argInfo + |} + """.stripMargin + + compile(code) + def callIn(m: String) = callGraph.callsites.find(_._1.name == m).get._2.values.head + assertEquals(callIn("t1").argInfos.toList, List((1, FunctionLiteral))) + assertEquals(callIn("t2").argInfos.toList, List((1, ForwardedParam(2)))) + assertEquals(callIn("t3").argInfos.toList, List((1, FunctionLiteral))) + assertEquals(callIn("t4").argInfos.toList, Nil) + } } |