/* NSC -- new Scala compiler
 * Copyright 2005-2015 LAMP/EPFL
 * @author  Martin Odersky
 */

package scala.tools.nsc
package backend.jvm
package opt

import scala.annotation.switch
import scala.collection.immutable
import scala.reflect.internal.util.NoPosition
import scala.tools.asm.{Type, Opcodes}
import scala.tools.asm.tree._
import scala.tools.nsc.backend.jvm.BTypes.InternalName
import scala.tools.nsc.backend.jvm.analysis.ProdConsAnalyzer
import BytecodeUtils._
import BackendReporting._
import Opcodes._
import scala.tools.nsc.backend.jvm.opt.ByteCodeRepository.CompilationUnit
import scala.collection.convert.decorateAsScala._

/**
 * Rewrites invocations of closures that are allocated and invoked within the same method into
 * direct invocations of the closure body method.
 */
class ClosureOptimizer[BT <: BTypes](val btypes: BT) {
  import btypes._
  import callGraph._

  /**
   * If a closure is allocated and invoked within the same method, re-write the invocation to the
   * closure body method.
   *
   * Note that the closure body method (generated by delambdafy:method) takes additional parameters
   * for the values captured by the closure. The bytecode is transformed from
   *
   *   [generate captured values]
   *   [closure init, capturing values]
   *   [...]
   *   [load closure object]
   *   [generate closure invocation arguments]
   *   [invoke closure.apply]
   *
   * to
   *
   *   [generate captured values]
   *   [store captured values into new locals]
   *   [load the captured values from locals]    // a future optimization will eliminate the closure
   *   [closure init, capturing values]          // instantiation if the closure object becomes unused
   *   [...]
   *   [load closure object]
   *   [generate closure invocation arguments]
   *   [store argument values into new locals]
   *   [drop the closure object]
   *   [load captured values from locals]
   *   [load argument values from locals]
   *   [invoke the closure body method]
   */
  def rewriteClosureApplyInvocations(): Unit = {
    // Total order on closure instantiations: owner class, owner method name, owner method
    // descriptor, and finally the position of the indy instruction within the owner method.
    implicit object closureInitOrdering extends Ordering[ClosureInstantiation] {
      override def compare(x: ClosureInstantiation, y: ClosureInstantiation): Int = {
        val cls = x.ownerClass.internalName compareTo y.ownerClass.internalName
        if (cls != 0) return cls

        val mName = x.ownerMethod.name compareTo y.ownerMethod.name
        if (mName != 0) return mName

        val mDesc = x.ownerMethod.desc compareTo y.ownerMethod.desc
        if (mDesc != 0) return mDesc

        def pos(inst: ClosureInstantiation) = inst.ownerMethod.instructions.indexOf(inst.lambdaMetaFactoryCall.indy)
        pos(x) - pos(y)
      }
    }

    // Grouping the closure instantiations by method allows running the ProdConsAnalyzer only once per
    // method. Also sort the instantiations: If there are multiple closure instantiations in a method,
    // closure invocations need to be re-written in a consistent order for bytecode stability. The local
    // variable slots for storing captured values depends on the order of rewriting.
    val closureInstantiationsByMethod: Map[MethodNode, immutable.TreeSet[ClosureInstantiation]] = {
      closureInstantiations.values.groupBy(_.ownerMethod).mapValues(immutable.TreeSet.empty ++ _)
    }

    // For each closure instantiation, a list of callsites of the closure that can be re-written
    // If a callsite cannot be rewritten, for example because the lambda body method is not accessible,
    // a warning is returned instead.
    val callsitesToRewrite: List[(ClosureInstantiation, List[Either[RewriteClosureApplyToClosureBodyFailed, (MethodInsnNode, Int)]])] = {
      // The map is keyed by `MethodNode`, so its iteration order gives no stable ordering across
      // methods. Sort the per-method groups by their first closure instantiation (every group
      // produced by `groupBy` is non-empty) so that methods are processed, and warnings are
      // issued, in a stable order.
      val methodsInStableOrder = closureInstantiationsByMethod.toList.sortBy(_._2.head)
      methodsInStableOrder.flatMap({
        case (methodNode, closureInits) =>
          // A lazy val to ensure the analysis only runs if necessary (the value is passed by name to `closureCallsites`)
          lazy val prodCons = new ProdConsAnalyzer(methodNode, closureInits.head.ownerClass.internalName)
          // within a method, the instantiations are visited in the order of the sorted set
          closureInits.iterator.map(init => (init, closureCallsites(init, prodCons))).toList
      })
    }

    // Rewrite all closure callsites (or issue inliner warnings for those that cannot be rewritten)
    for ((closureInit, callsites) <- callsitesToRewrite) {
      // Local variables that hold the captured values and the closure invocation arguments.
      // They are lazy vals to ensure that locals for captured values are only allocated if there's
      // actually a callsite to rewrite (and not only warnings to be issued).
      lazy val (localsForCapturedValues, argumentLocalsList) = localsForClosureRewrite(closureInit)
      for (callsite <- callsites) callsite match {
        case Left(warning) =>
          backendReporting.inlinerWarning(warning.pos, warning.toString)

        case Right((invocation, stackHeight)) =>
          rewriteClosureApplyInvocation(closureInit, invocation, stackHeight, localsForCapturedValues, argumentLocalsList)
      }
    }
  }

  /**
   * Insert instructions to store the values captured by a closure instantiation into local variables,
   * and load the values back to the stack.
   *
   * Returns the list of locals holding those captured values, and a list of locals that should be
   * used at the closure invocation callsite to store the arguments passed to the closure invocation.
   */
  private def localsForClosureRewrite(closureInit: ClosureInstantiation): (LocalsList, LocalsList) = {
    val ownerMethod = closureInit.ownerMethod
    val captureLocals = storeCaptures(closureInit)

    // allocate locals for storing the arguments of the closure apply callsites.
    // if there are multiple callsites, the same locals are re-used.
    val argTypes = closureInit.lambdaMetaFactoryCall.samMethodType.getArgumentTypes
    val firstArgLocal = ownerMethod.maxLocals

    // The comment in the unapply method of `LambdaMetaFactoryCall` explains why we have to introduce
    // casts for arguments that have different types in samMethodType and instantiatedMethodType.
    val castLoadTypes = {
      val instantiatedMethodType = closureInit.lambdaMetaFactoryCall.instantiatedMethodType
      (argTypes, instantiatedMethodType.getArgumentTypes).zipped map {
        case (samArgType, instantiatedArgType) if samArgType != instantiatedArgType =>
          // the LambdaMetaFactoryCall extractor ensures that the two types are reference types,
          // so we don't end up casting primitive values.
          Some(instantiatedArgType)
        case _ =>
          None
      }
    }
    val argLocals = LocalsList.fromTypes(firstArgLocal, argTypes, castLoadTypes)
    ownerMethod.maxLocals = firstArgLocal + argLocals.size

    (captureLocals, argLocals)
  }

  /**
   * Find all callsites of a closure within the method where the closure is allocated.
   *
   * Each element is either the invocation instruction together with the stack size of the
   * analyzer frame at that instruction, or a warning explaining why the callsite cannot be
   * re-written.
   */
  private def closureCallsites(closureInit: ClosureInstantiation, prodCons: => ProdConsAnalyzer): List[Either[RewriteClosureApplyToClosureBodyFailed, (MethodInsnNode, Int)]] = {
    val ownerMethod = closureInit.ownerMethod
    val ownerClass = closureInit.ownerClass
    val lambdaBodyHandle = closureInit.lambdaMetaFactoryCall.implMethod

    ownerMethod.instructions.iterator.asScala.collect({
      case invocation: MethodInsnNode if isSamInvocation(invocation, closureInit, prodCons) =>
        // TODO: This is maybe over-cautious.
        // We are checking if the closure body method is accessible at the closure callsite.
        // If the closure allocation has access to the body method, then the callsite (in the same
        // method as the allocation) should have access too.
        val bodyAccessible: Either[OptimizerWarning, Boolean] = for {
          (bodyMethodNode, declClass) <- byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc): Either[OptimizerWarning, (MethodNode, InternalName)]
          isAccessible                <- inliner.memberIsAccessible(bodyMethodNode.access, classBTypeFromParsedClassfile(declClass), classBTypeFromParsedClassfile(lambdaBodyHandle.getOwner), ownerClass)
        } yield {
          isAccessible
        }

        def pos = callGraph.callsites.get(invocation).map(_.callsitePosition).getOrElse(NoPosition)
        val stackSize: Either[RewriteClosureApplyToClosureBodyFailed, Int] = bodyAccessible match {
          case Left(w)      => Left(RewriteClosureAccessCheckFailed(pos, w))
          case Right(false) => Left(RewriteClosureIllegalAccess(pos, ownerClass.internalName))
          case _            => Right(prodCons.frameAt(invocation).getStackSize)
        }

        stackSize.right.map((invocation, _))
    }).toList
  }

  /**
   * Tests whether `invocation` is a non-static call whose name and descriptor match the SAM of
   * `closureInit`, and whose receiver is produced by exactly the indy instruction of
   * `closureInit`.
   */
  private def isSamInvocation(invocation: MethodInsnNode, closureInit: ClosureInstantiation, prodCons: => ProdConsAnalyzer): Boolean = {
    val indy = closureInit.lambdaMetaFactoryCall.indy
    if (invocation.getOpcode == INVOKESTATIC) false
    else {
      def closureIsReceiver = {
        val invocationFrame = prodCons.frameAt(invocation)
        val receiverSlot = {
          // the receiver sits below the arguments on the stack
          val numArgs = Type.getArgumentTypes(invocation.desc).length
          invocationFrame.stackTop - numArgs
        }
        val receiverProducers = prodCons.initialProducersForValueAt(invocation, receiverSlot)
        receiverProducers.size == 1 && receiverProducers.head == indy
      }

      invocation.name == indy.name && {
        val indySamMethodDesc = closureInit.lambdaMetaFactoryCall.samMethodType.getDescriptor
        indySamMethodDesc == invocation.desc
      } &&
        closureIsReceiver // most expensive check last
    }
  }

  /**
   * Replaces the closure invocation `invocation` by a direct invocation of the closure body
   * method and updates `maxStack` of the owner method and the call graph accordingly.
   *
   * @param closureInit             the closure instantiation whose callsite is re-written
   * @param invocation              the SAM invocation to replace
   * @param stackHeight             stack size of the analyzer frame at `invocation`
   * @param localsForCapturedValues locals holding the values captured by the closure
   * @param argumentLocalsList      locals used to temporarily store the invocation arguments
   */
  private def rewriteClosureApplyInvocation(closureInit: ClosureInstantiation, invocation: MethodInsnNode, stackHeight: Int, localsForCapturedValues: LocalsList, argumentLocalsList: LocalsList): Unit = {
    val ownerMethod = closureInit.ownerMethod
    val lambdaBodyHandle = closureInit.lambdaMetaFactoryCall.implMethod
    val isConstructorHandle = lambdaBodyHandle.getTag == H_NEWINVOKESPECIAL

    // store arguments
    insertStoreOps(invocation, ownerMethod, argumentLocalsList)

    // drop the closure from the stack
    ownerMethod.instructions.insertBefore(invocation, new InsnNode(POP))

    // If the body is a constructor, the uninitialized object has to be on the stack *below* the
    // constructor arguments: `invokespecial <init>` pops the arguments first and the object
    // reference last. The duplicated reference remains on the stack as the result.
    if (isConstructorHandle) {
      val insns = ownerMethod.instructions
      insns.insertBefore(invocation, new TypeInsnNode(NEW, lambdaBodyHandle.getOwner))
      insns.insertBefore(invocation, new InsnNode(DUP))
    }

    // load captured values and arguments
    insertLoadOps(invocation, ownerMethod, localsForCapturedValues)
    insertLoadOps(invocation, ownerMethod, argumentLocalsList)

    // update maxStack
    val capturesStackSize = localsForCapturedValues.size
    // two additional slots are occupied by the NEW / DUP references when invoking a constructor
    val newObjectStackSize = if (isConstructorHandle) 2 else 0
    val invocationStackHeight = stackHeight + capturesStackSize + newObjectStackSize - 1 // -1 because the closure is gone
    if (invocationStackHeight > ownerMethod.maxStack)
      ownerMethod.maxStack = invocationStackHeight

    // replace the callsite with a new call to the body method
    val bodyOpcode = (lambdaBodyHandle.getTag: @switch) match {
      case H_INVOKEVIRTUAL    => INVOKEVIRTUAL
      case H_INVOKESTATIC     => INVOKESTATIC
      case H_INVOKESPECIAL    => INVOKESPECIAL
      case H_INVOKEINTERFACE  => INVOKEINTERFACE
      case H_NEWINVOKESPECIAL => INVOKESPECIAL // NEW / DUP have been emitted above, before the loads
    }
    val isInterface = bodyOpcode == INVOKEINTERFACE
    val bodyInvocation = new MethodInsnNode(bodyOpcode, lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc, isInterface)
    ownerMethod.instructions.insertBefore(invocation, bodyInvocation)

    val returnType = Type.getReturnType(lambdaBodyHandle.getDesc)
    fixLoadedNothingOrNullValue(returnType, bodyInvocation, ownerMethod, btypes) // see comment of that method

    ownerMethod.instructions.remove(invocation)

    // update the call graph
    val originalCallsite = callGraph.callsites.remove(invocation)

    // the method node is needed for building the call graph entry
    val bodyMethod = byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc)
    def bodyMethodIsBeingCompiled = byteCodeRepository.classNodeAndSource(lambdaBodyHandle.getOwner).map(_._2 == CompilationUnit).getOrElse(false)
    val bodyMethodCallsite = Callsite(
      callsiteInstruction = bodyInvocation,
      callsiteMethod = ownerMethod,
      callsiteClass = closureInit.ownerClass,
      callee = bodyMethod.map({
        case (bodyMethodNode, bodyMethodDeclClass) => Callee(
          callee = bodyMethodNode,
          calleeDeclarationClass = classBTypeFromParsedClassfile(bodyMethodDeclClass),
          safeToInline = compilerSettings.YoptInlineGlobal || bodyMethodIsBeingCompiled,
          safeToRewrite = false, // the lambda body method is not a trait interface method
          annotatedInline = false,
          annotatedNoInline = false,
          calleeInfoWarning = None)
      }),
      argInfos = Nil,
      callsiteStackHeight = invocationStackHeight,
      receiverKnownNotNull = true, // see below (*)
      callsitePosition = originalCallsite.map(_.callsitePosition).getOrElse(NoPosition)
    )
    // (*) The documentation in class LambdaMetafactory says:
    //     "if implMethod corresponds to an instance method, the first capture argument
    //     (corresponding to the receiver) must be non-null"
    // Explanation: If the lambda body method is non-static, the receiver is a captured
    // value. It can only be captured within some instance method, so we know it's non-null.
    callGraph.callsites(bodyInvocation) = bodyMethodCallsite
  }

  /**
   * Stores the values captured by a closure creation into fresh local variables, and loads the
   * values back onto the stack. Returns the list of locals holding the captured values.
   */
  private def storeCaptures(closureInit: ClosureInstantiation): LocalsList = {
    val indy = closureInit.lambdaMetaFactoryCall.indy
    val capturedTypes = Type.getArgumentTypes(indy.desc)
    val firstCaptureLocal = closureInit.ownerMethod.maxLocals

    // This could be optimized: in many cases the captured values are produced by LOAD instructions.
    // If the variable is not modified within the method, we could avoid introducing yet another
    // local. On the other hand, further optimizations (copy propagation, remove unused locals) will
    // clean it up.

    // Captured variables don't need to be cast when loaded at the callsite (castLoadTypes are None).
    // This is checked in `isClosureInstantiation`: the types of the captured variables in the indy
    // instruction match exactly the corresponding parameter types in the body method.
    val localsForCaptures = LocalsList.fromTypes(firstCaptureLocal, capturedTypes, castLoadTypes = _ => None)
    closureInit.ownerMethod.maxLocals = firstCaptureLocal + localsForCaptures.size

    insertStoreOps(indy, closureInit.ownerMethod, localsForCaptures)
    insertLoadOps(indy, closureInit.ownerMethod, localsForCaptures)

    localsForCaptures
  }

  /**
   * Insert store operations in front of the `before` instruction to copy stack values into the
   * locals denoted by `localsList`.
   *
   * The lowest stack value is stored in the head of the locals list, so the last local is stored first.
   */
  private def insertStoreOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) =
    insertLocalValueOps(before, methodNode, localsList, store = true)

  /**
   * Insert load operations in front of the `before` instruction to copy the local values denoted
   * by `localsList` onto the stack.
   *
   * The head of the locals list will be the lowest value on the stack, so the first local is loaded first.
   */
  private def insertLoadOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) =
    insertLocalValueOps(before, methodNode, localsList, store = false)

  /**
   * Shared implementation of [[insertStoreOps]] and [[insertLoadOps]]: emits one xSTORE (if
   * `store` is true) or xLOAD instruction per local in front of `before`. Loads are followed by
   * a CHECKCAST if the local defines a `castLoadedValue`.
   */
  private def insertLocalValueOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList, store: Boolean): Unit = {
    // If `store` is true, the first instruction needs to store into the last local of the `localsList`.
    // Load instructions on the other hand are emitted in the order of the list.
    // To avoid reversing the list, we use `insert(previousInstr)` for stores and `insertBefore(before)` for loads.
    lazy val previous = before.getPrevious
    for (l <- localsList.locals) {
      val varOp = new VarInsnNode(if (store) l.storeOpcode else l.loadOpcode, l.local)
      if (store) methodNode.instructions.insert(previous, varOp)
      else methodNode.instructions.insertBefore(before, varOp)
      if (!store) for (castType <- l.castLoadedValue)
        methodNode.instructions.insert(varOp, new TypeInsnNode(CHECKCAST, castType.getInternalName))
    }
  }

  /**
   * A list of local variables. Each local stores information about its type, see class [[Local]].
   */
  case class LocalsList(locals: List[Local]) {
    /** The total number of local variable slots occupied by the locals in this list. */
    val size = locals.iterator.map(_.size).sum
  }

  object LocalsList {
    /**
     * A list of local variables starting at `firstLocal` that can hold values of the types in the
     * `types` parameter. The function `castLoadTypes` maps the index of a type in `types` to an
     * optional type the value should be cast to when it is loaded.
     *
     * For example, `fromTypes(3, Array(Int, Long, String), _ => None)` returns
     *   Local(3, intOpOffset, None)  ::
     *   Local(4, longOpOffset, None) :: // note that this local occupies two slots, the next is at 6
     *   Local(6, refOpOffset, None)  ::
     *   Nil
     */
    def fromTypes(firstLocal: Int, types: Array[Type], castLoadTypes: Int => Option[Type]): LocalsList = {
      // number of additional slots consumed so far by size-two (long / double) locals
      var sizeTwoOffset = 0
      val locals: List[Local] = types.indices.map(i => {
        // The ASM method `type.getOpcode` returns the opcode for operating on a value of `type`.
        val offset = types(i).getOpcode(ILOAD) - ILOAD
        val local = Local(firstLocal + i + sizeTwoOffset, offset, castLoadTypes(i))
        if (local.size == 2) sizeTwoOffset += 1
        local
      })(collection.breakOut)
      LocalsList(locals)
    }
  }

  /**
   * Stores a local variable index and the opcode offset required for operating on that variable.
   *
   * The xLOAD / xSTORE opcodes are in the following sequence: I, L, F, D, A, so the offset for
   * a local variable holding a reference (`A`) is 4. See also method `getOpcode` in [[scala.tools.asm.Type]].
   */
  case class Local(local: Int, opcodeOffset: Int, castLoadedValue: Option[Type]) {
    /** Number of local variable slots occupied: 2 for long and double values, 1 otherwise. */
    def size = if (loadOpcode == LLOAD || loadOpcode == DLOAD) 2 else 1

    def loadOpcode = ILOAD + opcodeOffset
    def storeOpcode = ISTORE + opcodeOffset
  }
}