/* NSC -- new Scala compiler
* Copyright 2005-2015 LAMP/EPFL
* @author Martin Odersky
*/
package scala.tools.nsc
package backend.jvm
package opt
import scala.annotation.switch
import scala.collection.mutable
import scala.reflect.internal.util.NoPosition
import scala.tools.asm.{Handle, Type, Opcodes}
import scala.tools.asm.tree._
import scala.tools.nsc.backend.jvm.BTypes.InternalName
import scala.tools.nsc.backend.jvm.analysis.ProdConsAnalyzer
import BytecodeUtils._
import BackendReporting._
import Opcodes._
import scala.tools.nsc.backend.jvm.opt.ByteCodeRepository.CompilationUnit
import scala.collection.convert.decorateAsScala._
/**
 * Rewrites invocations of the SAM method on closures created by a LambdaMetaFactory `indy`
 * instruction (a [[ClosureInstantiation]]) into direct invocations of the lambda body method.
 * Captured values and callsite arguments are passed through fresh local variables, and the call
 * graph is updated with a callsite for each new body invocation.
 */
class ClosureOptimizer[BT <: BTypes](val btypes: BT) {
  import btypes._
  import callGraph._

  /**
   * Rewrites the SAM invocations of every closure instantiation known to the call graph and
   * reports an inliner warning for each invocation that could not be rewritten.
   *
   * Closure instantiations are processed in a deterministic order: by owner class name, owner
   * method name, owner method descriptor, and finally by the position of the indy instruction.
   */
  def rewriteClosureApplyInvocations(): Unit = {
    implicit object closureInitOrdering extends Ordering[ClosureInstantiation] {
      override def compare(x: ClosureInstantiation, y: ClosureInstantiation): Int = {
        def pos(inst: ClosureInstantiation) = inst.ownerMethod.instructions.indexOf(inst.lambdaMetaFactoryCall.indy)
        val byClass = x.ownerClass.internalName compareTo y.ownerClass.internalName
        if (byClass != 0) byClass
        else {
          val byName = x.ownerMethod.name compareTo y.ownerMethod.name
          if (byName != 0) byName
          else {
            val byDesc = x.ownerMethod.desc compareTo y.ownerMethod.desc
            if (byDesc != 0) byDesc
            else Integer.compare(pos(x), pos(y))
          }
        }
      }
    }

    val sorted = mutable.TreeSet.empty[ClosureInstantiation]
    sorted ++= closureInstantiations.values
    for (closureInst <- sorted) {
      val warnings = rewriteClosureApplyInvocations(closureInst)
      warnings.foreach(w => backendReporting.inlinerWarning(w.pos, w.toString))
    }
  }

  /**
   * Tests whether `invocation` invokes the SAM method on the closure created by `indy`.
   * All of the following must hold:
   *  - the invocation is not static;
   *  - its name equals the indy name and its descriptor equals the indy's SAM method type;
   *  - its receiver value is produced only by `indy`.
   *
   * `prodCons` is by-name: the (expensive) producer-consumer analysis is only forced when the
   * cheaper name / descriptor checks succeed.
   */
  def isSamInvocation(invocation: MethodInsnNode, indy: InvokeDynamicInsnNode, prodCons: => ProdConsAnalyzer): Boolean = {
    if (invocation.getOpcode == INVOKESTATIC) false
    else {
      def closureIsReceiver = {
        val invocationFrame = prodCons.frameAt(invocation)
        // The receiver is the stack value directly below the invocation's arguments.
        val receiverSlot = {
          val numArgs = Type.getArgumentTypes(invocation.desc).length
          invocationFrame.stackTop - numArgs
        }
        val receiverProducers = prodCons.initialProducersForValueAt(invocation, receiverSlot)
        receiverProducers.size == 1 && receiverProducers.head == indy
      }

      invocation.name == indy.name && {
        val indySamMethodDesc = indy.bsmArgs(0).asInstanceOf[Type].getDescriptor // safe, checked in isClosureInstantiation
        indySamMethodDesc == invocation.desc
      } &&
      closureIsReceiver // most expensive check last
    }
  }

  /**
   * Stores the values captured by a closure creation into fresh local variables.
   * Returns the list of locals holding the captured values.
   *
   * The captured values are stored right before the indy instruction and immediately loaded
   * again, so the indy still receives them; the locals keep a copy for rewritten callsites.
   * `maxLocals` of the owner method is bumped to make room for the new locals.
   */
  private def storeCaptures(closureInit: ClosureInstantiation): LocalsList = {
    val indy = closureInit.lambdaMetaFactoryCall.indy
    val capturedTypes = Type.getArgumentTypes(indy.desc)
    val firstCaptureLocal = closureInit.ownerMethod.maxLocals

    // This could be optimized: in many cases the captured values are produced by LOAD instructions.
    // If the variable is not modified within the method, we could avoid introducing yet another
    // local. On the other hand, further optimizations (copy propagation, remove unused locals) will
    // clean it up.

    // Captured variables don't need to be cast when loaded at the callsite (castLoadTypes are None).
    // This is checked in `isClosureInstantiation`: the types of the captured variables in the indy
    // instruction match exactly the corresponding parameter types in the body method.
    val localsForCaptures = LocalsList.fromTypes(firstCaptureLocal, capturedTypes, castLoadTypes = _ => None)
    closureInit.ownerMethod.maxLocals = firstCaptureLocal + localsForCaptures.size

    insertStoreOps(indy, closureInit.ownerMethod, localsForCaptures)
    insertLoadOps(indy, closureInit.ownerMethod, localsForCaptures)

    localsForCaptures
  }

  /**
   * Insert store operations in front of the `before` instruction to copy stack values into the
   * locals denoted by `localsList`.
   *
   * The lowest stack value is stored in the head of the locals list, so the last local is stored first.
   */
  private def insertStoreOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) =
    insertLocalValueOps(before, methodNode, localsList, store = true)

  /**
   * Insert load operations in front of the `before` instruction to copy the local values denoted
   * by `localsList` onto the stack.
   *
   * The head of the locals list will be the lowest value on the stack, so the first local is loaded first.
   * Locals with a `castLoadedValue` are followed by a CHECKCAST to that type.
   */
  private def insertLoadOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) =
    insertLocalValueOps(before, methodNode, localsList, store = false)

  /**
   * Emits one xSTORE (if `store`) or xLOAD per local of `localsList`, all inserted right in front
   * of `before`.
   */
  private def insertLocalValueOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList, store: Boolean): Unit = {
    val insns = methodNode.instructions
    // Stores pop values off the stack, so the topmost value (the last local) is stored first.
    // Loads push values in list order, so the head of the list ends up lowest on the stack.
    // Inserting everything with `insertBefore(before)` does not depend on `before` having a
    // predecessor instruction.
    val ordered = if (store) localsList.locals.reverse else localsList.locals
    for (l <- ordered) {
      insns.insertBefore(before, new VarInsnNode(if (store) l.storeOpcode else l.loadOpcode, l.local))
      if (!store) for (castType <- l.castLoadedValue)
        insns.insertBefore(before, new TypeInsnNode(CHECKCAST, castType.getInternalName))
    }
  }

  /**
   * Rewrites every invocation of the SAM method on the closure created by `closureInit` into a
   * direct invocation of the lambda body method, and updates the call graph accordingly.
   *
   * @return one warning per SAM invocation that was not rewritten because the lambda body method
   *         could not be looked up or is not accessible from the owner class.
   */
  def rewriteClosureApplyInvocations(closureInit: ClosureInstantiation): List[RewriteClosureApplyToClosureBodyFailed] = {
    val lambdaBodyHandle = closureInit.lambdaMetaFactoryCall.implMethod
    val ownerMethod = closureInit.ownerMethod
    val ownerClass = closureInit.ownerClass

    // Kept as a lazy val to make sure the analysis is only computed if it's actually needed.
    // ProdCons is used to identify closure body invocations (see isSamInvocation), but only if the
    // callsite has the right name and signature. If the method has no invocation instruction with
    // the right name and signature, the analysis is not executed.
    lazy val prodCons = new ProdConsAnalyzer(ownerMethod, ownerClass.internalName)

    // The lambda body method node and its declaring class. The lookup does not depend on the
    // callsite, so it is performed at most once, and only if a SAM invocation is found.
    lazy val bodyMethod = byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc)

    // Whether the lambda body method is accessible from `ownerClass`; also callsite-independent.
    lazy val bodyAccessible: Either[OptimizerWarning, Boolean] = for {
      (bodyMethodNode, declClass) <- bodyMethod: Either[OptimizerWarning, (MethodNode, InternalName)]
      isAccessible <- inliner.memberIsAccessible(bodyMethodNode.access, classBTypeFromParsedClassfile(declClass), classBTypeFromParsedClassfile(lambdaBodyHandle.getOwner), ownerClass)
    } yield isAccessible

    // First collect all callsites without modifying the instructions list yet.
    // Once we start modifying the instruction list, prodCons becomes unusable.

    // A list of callsites and stack heights. If the invocation cannot be rewritten, a warning
    // is stored instead of the stack height.
    val invocationsToRewrite: List[(MethodInsnNode, Either[RewriteClosureApplyToClosureBodyFailed, Int])] = ownerMethod.instructions.iterator.asScala.collect({
      case invocation: MethodInsnNode if isSamInvocation(invocation, closureInit.lambdaMetaFactoryCall.indy, prodCons) =>
        def pos = callGraph.callsites.get(invocation).map(_.callsitePosition).getOrElse(NoPosition)
        val stackSize: Either[RewriteClosureApplyToClosureBodyFailed, Int] = bodyAccessible match {
          case Left(w)      => Left(RewriteClosureAccessCheckFailed(pos, w))
          case Right(false) => Left(RewriteClosureIllegalAccess(pos, ownerClass.internalName))
          case _            => Right(prodCons.frameAt(invocation).getStackSize)
        }
        (invocation, stackSize)
    }).toList

    if (invocationsToRewrite.isEmpty) Nil
    else {
      // lazy val to make sure locals for captures and arguments are only allocated if there's
      // effectively a callsite to rewrite.
      lazy val (localsForCapturedValues, argumentLocalsList) = {
        val captureLocals = storeCaptures(closureInit)

        // allocate locals for storing the arguments of the closure apply callsites.
        // if there are multiple callsites, the same locals are re-used.
        val argTypes = closureInit.lambdaMetaFactoryCall.samMethodType.getArgumentTypes
        val firstArgLocal = ownerMethod.maxLocals

        // The comment in `isClosureInstantiation` explains why we have to introduce casts for
        // arguments that have different types in samMethodType and instantiatedMethodType.
        val castLoadTypes = {
          val instantiatedMethodType = closureInit.lambdaMetaFactoryCall.instantiatedMethodType
          (argTypes, instantiatedMethodType.getArgumentTypes).zipped map {
            case (samArgType, instantiatedArgType) if samArgType != instantiatedArgType =>
              // isClosureInstantiation ensures that the two types are reference types, so we don't
              // end up casting primitive values.
              Some(instantiatedArgType)
            case _ =>
              None
          }
        }
        val argLocals = LocalsList.fromTypes(firstArgLocal, argTypes, castLoadTypes)
        ownerMethod.maxLocals = firstArgLocal + argLocals.size

        (captureLocals, argLocals)
      }

      val warnings = invocationsToRewrite flatMap {
        case (_, Left(warning)) => Some(warning)

        case (invocation, Right(stackHeight)) =>
          val insns = ownerMethod.instructions
          val bodyIsConstructor = lambdaBodyHandle.getTag == H_NEWINVOKESPECIAL

          // store arguments
          insertStoreOps(invocation, ownerMethod, argumentLocalsList)

          // drop the closure from the stack
          insns.insertBefore(invocation, new InsnNode(POP))

          // A constructor body is invoked on a fresh, uninitialized instance that has to be below
          // the constructor arguments on the stack: emit NEW / DUP before loading any of them.
          if (bodyIsConstructor) {
            insns.insertBefore(invocation, new TypeInsnNode(NEW, lambdaBodyHandle.getOwner))
            insns.insertBefore(invocation, new InsnNode(DUP))
          }

          // load captured values and arguments
          insertLoadOps(invocation, ownerMethod, localsForCapturedValues)
          insertLoadOps(invocation, ownerMethod, argumentLocalsList)

          // update maxStack
          // `stackHeight` is the number of values in an analyzer frame, where a long / double counts
          // as a single value. Captured values are therefore counted by number, not by slot size.
          val numCapturedValues = localsForCapturedValues.locals.length
          val numNewInstanceValues = if (bodyIsConstructor) 2 else 0 // NEW + DUP
          val invocationStackHeight = stackHeight - 1 + numNewInstanceValues + numCapturedValues // -1 because the closure is gone
          if (invocationStackHeight > ownerMethod.maxStack)
            ownerMethod.maxStack = invocationStackHeight

          // replace the callsite with a new call to the body method
          val bodyOpcode = (lambdaBodyHandle.getTag: @switch) match {
            case H_INVOKEVIRTUAL    => INVOKEVIRTUAL
            case H_INVOKESTATIC     => INVOKESTATIC
            case H_INVOKESPECIAL    => INVOKESPECIAL
            case H_INVOKEINTERFACE  => INVOKEINTERFACE
            case H_NEWINVOKESPECIAL => INVOKESPECIAL // NEW / DUP were emitted above
          }
          val isInterface = bodyOpcode == INVOKEINTERFACE
          val bodyInvocation = new MethodInsnNode(bodyOpcode, lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc, isInterface)
          insns.insertBefore(invocation, bodyInvocation)

          val returnType = Type.getReturnType(lambdaBodyHandle.getDesc)
          fixLoadedNothingOrNullValue(returnType, bodyInvocation, ownerMethod, btypes) // see comment of that method

          insns.remove(invocation)

          // update the call graph
          val originalCallsite = callGraph.callsites.remove(invocation)

          // the method node is needed for building the call graph entry
          def bodyMethodIsBeingCompiled = byteCodeRepository.classNodeAndSource(lambdaBodyHandle.getOwner).map(_._2 == CompilationUnit).getOrElse(false)
          val bodyMethodCallsite = Callsite(
            callsiteInstruction = bodyInvocation,
            callsiteMethod = ownerMethod,
            callsiteClass = ownerClass,
            callee = bodyMethod.map({
              case (bodyMethodNode, bodyMethodDeclClass) => Callee(
                callee = bodyMethodNode,
                calleeDeclarationClass = classBTypeFromParsedClassfile(bodyMethodDeclClass),
                safeToInline = compilerSettings.YoptInlineGlobal || bodyMethodIsBeingCompiled,
                safeToRewrite = false, // the lambda body method is not a trait interface method
                annotatedInline = false,
                annotatedNoInline = false,
                calleeInfoWarning = None)
            }),
            argInfos = Nil,
            callsiteStackHeight = invocationStackHeight,
            receiverKnownNotNull = true, // see below (*)
            callsitePosition = originalCallsite.map(_.callsitePosition).getOrElse(NoPosition)
          )
          // (*) The documentation in class LambdaMetafactory says:
          //     "if implMethod corresponds to an instance method, the first capture argument
          //     (corresponding to the receiver) must be non-null"
          //     Explanation: If the lambda body method is non-static, the receiver is a captured
          //     value. It can only be captured within some instance method, so we know it's non-null.
          callGraph.callsites(bodyInvocation) = bodyMethodCallsite
          None
      }

      warnings
    }
  }

  /**
   * A list of local variables. Each local stores information about its type, see class [[Local]].
   * `size` is the number of local variable slots the list occupies (long / double take two).
   */
  case class LocalsList(locals: List[Local]) {
    val size = locals.iterator.map(_.size).sum
  }

  object LocalsList {
    /**
     * A list of local variables starting at `firstLocal` that can hold values of the types in the
     * `types` parameter. The local for `types(i)` gets the load cast `castLoadTypes(i)`.
     *
     * For example, `fromTypes(3, Array(Int, Long, String), castLoadTypes)` returns
     *   Local(3, intOpOffset, castLoadTypes(0))  ::
     *   Local(4, longOpOffset, castLoadTypes(1)) ::  // note that this local occupies two slots, the next is at 6
     *   Local(6, refOpOffset, castLoadTypes(2))  ::
     *   Nil
     */
    def fromTypes(firstLocal: Int, types: Array[Type], castLoadTypes: Int => Option[Type]): LocalsList = {
      // Thread the index of the next free local through the fold; the list is built in reverse.
      val (_, reversedLocals) = types.indices.foldLeft((firstLocal, List.empty[Local])) {
        case ((nextLocal, acc), i) =>
          // The ASM method `type.getOpcode` returns the opcode for operating on a value of `type`.
          val offset = types(i).getOpcode(ILOAD) - ILOAD
          val local = Local(nextLocal, offset, castLoadTypes(i))
          (nextLocal + local.size, local :: acc)
      }
      LocalsList(reversedLocals.reverse)
    }
  }

  /**
   * Stores a local variable index and the opcode offset required for operating on that variable,
   * plus an optional type to CHECKCAST the value to after it is loaded.
   *
   * The xLOAD / xSTORE opcodes are in the following sequence: I, L, F, D, A, so the offset for
   * a local variable holding a reference (`A`) is 4. See also method `getOpcode` in [[scala.tools.asm.Type]].
   */
  case class Local(local: Int, opcodeOffset: Int, castLoadedValue: Option[Type]) {
    def size = if (loadOpcode == LLOAD || loadOpcode == DLOAD) 2 else 1

    def loadOpcode = ILOAD + opcodeOffset
    def storeOpcode = ISTORE + opcodeOffset
  }
}