summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala
blob: b0dc6ead1bafb6cb0d155c3d165b1b85f211693c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
/* NSC -- new Scala compiler
 * Copyright 2005-2015 LAMP/EPFL
 * @author  Martin Odersky
 */

package scala.tools.nsc
package backend.jvm
package opt

import scala.annotation.switch
import scala.collection.immutable
import scala.reflect.internal.util.NoPosition
import scala.tools.asm.{Type, Opcodes}
import scala.tools.asm.tree._
import scala.tools.nsc.backend.jvm.BTypes.InternalName
import scala.tools.nsc.backend.jvm.analysis.ProdConsAnalyzer
import BytecodeUtils._
import BackendReporting._
import Opcodes._
import scala.tools.nsc.backend.jvm.opt.ByteCodeRepository.CompilationUnit
import scala.collection.convert.decorateAsScala._

class ClosureOptimizer[BT <: BTypes](val btypes: BT) {
  import btypes._
  import callGraph._

  /**
   * If a closure is allocated and invoked within the same method, re-write the invocation to the
   * closure body method.
   *
   * Note that the closure body method (generated by delambdafy:method) takes additional parameters
   * for the values captured by the closure. The bytecode is transformed from
   *
   *   [generate captured values]
   *   [closure init, capturing values]
   *   [...]
   *   [load closure object]
   *   [generate closure invocation arguments]
   *   [invoke closure.apply]
   *
   * to
   *
   *   [generate captured values]
   *   [store captured values into new locals]
   *   [load the captured values from locals]    // a future optimization will eliminate the closure
   *   [closure init, capturing values]          // instantiation if the closure object becomes unused
   *   [...]
   *   [load closure object]
   *   [generate closure invocation arguments]
   *   [store argument values into new locals]
   *   [drop the closure object]
   *   [load captured values from locals]
   *   [load argument values from locals]
   *   [invoke the closure body method]
   */
  def rewriteClosureApplyInvocations(): Unit = {
    // Total order on closure instantiations: owner class name first, then name and descriptor of
    // the enclosing method, and finally the index of the indy instruction in that method.
    implicit object closureInitOrdering extends Ordering[ClosureInstantiation] {
      override def compare(x: ClosureInstantiation, y: ClosureInstantiation): Int = {
        def indyIndex(ci: ClosureInstantiation) = ci.ownerMethod.instructions.indexOf(ci.lambdaMetaFactoryCall.indy)

        val byClass = x.ownerClass.internalName compareTo y.ownerClass.internalName
        if (byClass != 0) byClass
        else {
          val byMethodName = x.ownerMethod.name compareTo y.ownerMethod.name
          if (byMethodName != 0) byMethodName
          else {
            val byMethodDesc = x.ownerMethod.desc compareTo y.ownerMethod.desc
            if (byMethodDesc != 0) byMethodDesc
            else indyIndex(x) - indyIndex(y)
          }
        }
      }
    }

    // Closure instantiations are grouped by their enclosing method so that the ProdConsAnalyzer
    // runs at most once per method. Within a method the instantiations are kept sorted: the local
    // variable slots allocated for captured values depend on the order in which closures are
    // rewritten, so a consistent order is required for stable bytecode.
    val closureInstantiationsByMethod: Map[MethodNode, immutable.TreeSet[ClosureInstantiation]] = {
      val grouped = closureInstantiations.values.groupBy(_.ownerMethod)
      grouped.mapValues(inits => immutable.TreeSet.empty[ClosureInstantiation] ++ inits)
    }

    // Pairs every closure instantiation with the callsites found for it. Each callsite is either
    // rewritable (`Right`) or comes with a warning explaining why it is not (`Left`), for example
    // because the lambda body method is not accessible.
    val callsitesToRewrite: List[(ClosureInstantiation, List[Either[RewriteClosureApplyToClosureBodyFailed, (MethodInsnNode, Int)]])] = {
      val perInstantiation = closureInstantiationsByMethod.iterator flatMap {
        case (methodNode, closureInits) =>
          // Lazy, and passed by name to `closureCallsites`: the analysis only runs when needed.
          lazy val prodCons = new ProdConsAnalyzer(methodNode, closureInits.head.ownerClass.internalName)
          closureInits.iterator.map(init => (init, closureCallsites(init, prodCons)))
      }
      // A list (not a map) preserves the iteration order of closureInstantiationsByMethod.
      perInstantiation.toList
    }

    // Perform the rewrites, or report an inliner warning where a rewrite is not possible.
    callsitesToRewrite foreach {
      case (closureInit, callsites) =>
        // Locals for the captured values and for the closure invocation arguments. Lazy, so that
        // they are only allocated if there is actually a callsite to rewrite (and not only
        // warnings to be issued).
        lazy val (localsForCapturedValues, argumentLocalsList) = localsForClosureRewrite(closureInit)
        callsites foreach {
          case Left(warning) =>
            backendReporting.inlinerWarning(warning.pos, warning.toString)

          case Right((invocation, stackHeight)) =>
            rewriteClosureApplyInvocation(closureInit, invocation, stackHeight, localsForCapturedValues, argumentLocalsList)
        }
    }
  }

  /**
   * Insert instructions to store the values captured by a closure instantiation into local variables,
   * and load the values back to the stack.
   *
   * Returns the list of locals holding those captured values, and a list of locals that should be
   * used at the closure invocation callsite to store the arguments passed to the closure invocation.
   */
  private def localsForClosureRewrite(closureInit: ClosureInstantiation): (LocalsList, LocalsList) = {
    val method = closureInit.ownerMethod
    val lmfCall = closureInit.lambdaMetaFactoryCall

    // Must run first: storing the captures bumps `maxLocals`, which is read right below.
    val captureLocals = storeCaptures(closureInit)

    // Locals for the arguments of the closure apply callsites. They start at the current
    // `maxLocals` and are shared by all callsites of this closure.
    val samArgTypes = lmfCall.samMethodType.getArgumentTypes
    val argSlotBase = method.maxLocals

    // An argument whose type differs between samMethodType and instantiatedMethodType is cast to
    // the instantiated type when it is loaded; the comment in the unapply method of
    // `LambdaMetaFactoryCall` explains why. That extractor ensures both types are reference types,
    // so primitive values are never cast.
    val instantiatedArgTypes = lmfCall.instantiatedMethodType.getArgumentTypes
    val castLoadTypes = (samArgTypes, instantiatedArgTypes).zipped map { (samArgType, instantiatedArgType) =>
      if (samArgType != instantiatedArgType) Some(instantiatedArgType)
      else None
    }

    val argLocals = LocalsList.fromTypes(argSlotBase, samArgTypes, castLoadTypes)
    method.maxLocals = argSlotBase + argLocals.size

    (captureLocals, argLocals)
  }

  /**
   * Find all callsites of a closure within the method where the closure is allocated.
   */
  private def closureCallsites(closureInit: ClosureInstantiation, prodCons: => ProdConsAnalyzer): List[Either[RewriteClosureApplyToClosureBodyFailed, (MethodInsnNode, Int)]] = {
    // Scans every instruction of the method that allocates the closure. For each invocation
    // accepted by `isSamInvocation`, the result contains either a `Right` holding the invocation
    // and the stack size of the frame at that invocation, or a `Left` holding the reason why the
    // callsite cannot be rewritten. `prodCons` is by-name and only forced where referenced below.
    val ownerMethod = closureInit.ownerMethod
    val ownerClass = closureInit.ownerClass
    val lambdaBodyHandle = closureInit.lambdaMetaFactoryCall.implMethod

    ownerMethod.instructions.iterator.asScala.collect({
      case invocation: MethodInsnNode if isSamInvocation(invocation, closureInit, prodCons) =>
        // TODO: This is maybe over-cautious.
        // We are checking if the closure body method is accessible at the closure callsite.
        // If the closure allocation has access to the body method, then the callsite (in the same
        // method as the allocation) should have access too.
        // The lookup of the body method and the access check can each produce a warning (`Left`).
        val bodyAccessible: Either[OptimizerWarning, Boolean] = for {
          (bodyMethodNode, declClass) <- byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc): Either[OptimizerWarning, (MethodNode, InternalName)]
          isAccessible                <- inliner.memberIsAccessible(bodyMethodNode.access, classBTypeFromParsedClassfile(declClass), classBTypeFromParsedClassfile(lambdaBodyHandle.getOwner), ownerClass)
        } yield {
          isAccessible
        }

        // Position of the callsite as recorded in the call graph, or NoPosition if there is no
        // entry. A `def`, so it is only computed when a warning is actually created.
        def pos = callGraph.callsites.get(invocation).map(_.callsitePosition).getOrElse(NoPosition)
        // A failed check or an inaccessible body method becomes a rewrite failure; otherwise the
        // stack size at the invocation is taken from the producers / consumers analysis.
        val stackSize: Either[RewriteClosureApplyToClosureBodyFailed, Int] = bodyAccessible match {
          case Left(w)      => Left(RewriteClosureAccessCheckFailed(pos, w))
          case Right(false) => Left(RewriteClosureIllegalAccess(pos, ownerClass.internalName))
          case _            => Right(prodCons.frameAt(invocation).getStackSize)
        }

        // Pair a successful stack size with the invocation it belongs to.
        stackSize.right.map((invocation, _))
    }).toList
  }

  private def isSamInvocation(invocation: MethodInsnNode, closureInit: ClosureInstantiation, prodCons: => ProdConsAnalyzer): Boolean = {
    val lmfCall = closureInit.lambdaMetaFactoryCall
    val indy = lmfCall.indy

    // The invoked method has the same name as the indy instruction.
    def nameMatches = invocation.name == indy.name

    // The invoked descriptor is exactly the descriptor of the SAM method type.
    def descMatches = lmfCall.samMethodType.getDescriptor == invocation.desc

    // The receiver sits below the arguments on the stack; its only initial producer has to be
    // the indy instruction that creates the closure.
    def receiverIsClosure = {
      val frame = prodCons.frameAt(invocation)
      val argCount = Type.getArgumentTypes(invocation.desc).length
      val producers = prodCons.initialProducersForValueAt(invocation, frame.stackTop - argCount)
      producers.size == 1 && producers.head == indy
    }

    // Static invocations have no receiver. The producer analysis is the most expensive check,
    // so it comes last.
    invocation.getOpcode != INVOKESTATIC && nameMatches && descMatches && receiverIsClosure
  }

  private def rewriteClosureApplyInvocation(closureInit: ClosureInstantiation, invocation: MethodInsnNode, stackHeight: Int, localsForCapturedValues: LocalsList, argumentLocalsList: LocalsList): Unit = {
    // Replaces `invocation` (a call to the closure's SAM method) by a direct call to the lambda
    // body method, and updates the call graph accordingly. All new instructions are inserted in
    // front of `invocation`, which is removed from the method at the end.
    //
    // `stackHeight` is the stack size of the frame at `invocation`. `localsForCapturedValues` and
    // `argumentLocalsList` are the locals holding the captured values and the call arguments.
    val ownerMethod = closureInit.ownerMethod
    val lambdaBodyHandle = closureInit.lambdaMetaFactoryCall.implMethod

    // store arguments
    insertStoreOps(invocation, ownerMethod, argumentLocalsList)

    // drop the closure from the stack
    ownerMethod.instructions.insertBefore(invocation, new InsnNode(POP))

    // load captured values and arguments
    insertLoadOps(invocation, ownerMethod, localsForCapturedValues)
    insertLoadOps(invocation, ownerMethod, argumentLocalsList)

    // update maxStack
    // `LocalsList.size` is the sum of the sizes of the locals (2 for long / double, 1 otherwise).
    val capturesStackSize = localsForCapturedValues.size
    val invocationStackHeight = stackHeight + capturesStackSize - 1 // -1 because the closure is gone
    if (invocationStackHeight > ownerMethod.maxStack)
      ownerMethod.maxStack = invocationStackHeight

    // replace the callsite with a new call to the body method
    // The invocation opcode is derived from the kind (tag) of the body method handle.
    val bodyOpcode = (lambdaBodyHandle.getTag: @switch) match {
      case H_INVOKEVIRTUAL    => INVOKEVIRTUAL
      case H_INVOKESTATIC     => INVOKESTATIC
      case H_INVOKESPECIAL    => INVOKESPECIAL
      case H_INVOKEINTERFACE  => INVOKEINTERFACE
      case H_NEWINVOKESPECIAL =>
        // NOTE(review): NEW / DUP are inserted in front of `invocation`, i.e. after the captured
        // values and arguments loaded above - confirm this is the intended operand order for
        // constructor handles, or that such handles never reach this method.
        val insns = ownerMethod.instructions
        insns.insertBefore(invocation, new TypeInsnNode(NEW, lambdaBodyHandle.getOwner))
        insns.insertBefore(invocation, new InsnNode(DUP))
        INVOKESPECIAL
    }
    val isInterface = bodyOpcode == INVOKEINTERFACE
    val bodyInvocation = new MethodInsnNode(bodyOpcode, lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc, isInterface)
    ownerMethod.instructions.insertBefore(invocation, bodyInvocation)

    val returnType = Type.getReturnType(lambdaBodyHandle.getDesc)
    fixLoadedNothingOrNullValue(returnType, bodyInvocation, ownerMethod, btypes) // see comment of that method

    // The original SAM invocation is no longer needed.
    ownerMethod.instructions.remove(invocation)

    // update the call graph
    // The entry of the removed invocation is dropped; its position (if there was an entry) is
    // re-used for the new callsite below.
    val originalCallsite = callGraph.callsites.remove(invocation)

    // the method node is needed for building the call graph entry
    val bodyMethod = byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc)
    // True if the class declaring the body method comes from the current compilation unit;
    // false also when the class cannot be found.
    def bodyMethodIsBeingCompiled = byteCodeRepository.classNodeAndSource(lambdaBodyHandle.getOwner).map(_._2 == CompilationUnit).getOrElse(false)
    val bodyMethodCallsite = Callsite(
      callsiteInstruction = bodyInvocation,
      callsiteMethod = ownerMethod,
      callsiteClass = closureInit.ownerClass,
      callee = bodyMethod.map({
        case (bodyMethodNode, bodyMethodDeclClass) => Callee(
          callee = bodyMethodNode,
          calleeDeclarationClass = classBTypeFromParsedClassfile(bodyMethodDeclClass),
          safeToInline = compilerSettings.YoptInlineGlobal || bodyMethodIsBeingCompiled,
          safeToRewrite = false, // the lambda body method is not a trait interface method
          annotatedInline = false,
          annotatedNoInline = false,
          calleeInfoWarning = None)
      }),
      argInfos = Nil,
      callsiteStackHeight = invocationStackHeight,
      receiverKnownNotNull = true, // see below (*)
      callsitePosition = originalCallsite.map(_.callsitePosition).getOrElse(NoPosition)
    )
    // (*) The documentation in class LambdaMetafactory says:
    //     "if implMethod corresponds to an instance method, the first capture argument
    //     (corresponding to the receiver) must be non-null"
    // Explanation: If the lambda body method is non-static, the receiver is a captured
    // value. It can only be captured within some instance method, so we know it's non-null.
    callGraph.callsites(bodyInvocation) = bodyMethodCallsite
  }

  /**
   * Stores the values captured by a closure creation into fresh local variables, and loads the
   * values back onto the stack. Returns the list of locals holding the captured values.
   */
  private def storeCaptures(closureInit: ClosureInstantiation): LocalsList = {
    val method = closureInit.ownerMethod
    val indy = closureInit.lambdaMetaFactoryCall.indy

    // Possible optimization: captured values are often produced by LOAD instructions. If the
    // loaded variable is never modified in the method, no additional local would be needed. On
    // the other hand, later optimizations (copy propagation, removal of unused locals) clean
    // this up anyway.

    // The captured values are the arguments of the indy instruction. They need no cast when
    // loaded at the callsite (castLoadTypes are None): `isClosureInstantiation` checks that the
    // captured types in the indy instruction match the body method's parameter types exactly.
    val slotBase = method.maxLocals
    val captureLocals = LocalsList.fromTypes(slotBase, Type.getArgumentTypes(indy.desc), castLoadTypes = _ => None)
    method.maxLocals = slotBase + captureLocals.size

    // Spill the values into the fresh locals in front of the indy, then load them again so the
    // indy still finds them on the stack. The stores have to be inserted before the loads.
    insertStoreOps(indy, method, captureLocals)
    insertLoadOps(indy, method, captureLocals)

    captureLocals
  }

  /**
   * Insert store operations in front of the `before` instruction to copy stack values into the
   * locals denoted by `localsList`.
   *
   * The lowest stack value is stored in the head of the locals list, so the last local is stored first.
   */
  private def insertStoreOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) = {
    // Store variant of `insertLocalValueOps`.
    insertLocalValueOps(before, methodNode, localsList, store = true)
  }

  /**
   * Insert load operations in front of the `before` instruction to copy the local values denoted
   * by `localsList` onto the stack.
   *
   * The head of the locals list will be the lowest value on the stack, so the first local is loaded first.
   */
  private def insertLoadOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) = {
    // Load variant of `insertLocalValueOps`.
    insertLocalValueOps(before, methodNode, localsList, store = false)
  }

  private def insertLocalValueOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList, store: Boolean): Unit = {
    def insns = methodNode.instructions

    if (store) {
      // The first emitted store has to write the LAST local of the list. Instead of reversing the
      // list, every store is inserted directly after the instruction that originally preceded
      // `before`, so each new store ends up in front of the ones inserted earlier.
      lazy val anchor = before.getPrevious
      localsList.locals foreach { l =>
        insns.insert(anchor, new VarInsnNode(l.storeOpcode, l.local))
      }
    } else {
      // Loads are emitted in list order, each one directly in front of `before`. A load whose
      // local requires a cast is immediately followed by a CHECKCAST.
      localsList.locals foreach { l =>
        val loadOp = new VarInsnNode(l.loadOpcode, l.local)
        insns.insertBefore(before, loadOp)
        l.castLoadedValue foreach { castType =>
          insns.insert(loadOp, new TypeInsnNode(CHECKCAST, castType.getInternalName))
        }
      }
    }
  }

  /**
   * A list of local variables. Each local stores information about its type, see class [[Local]].
   */
  case class LocalsList(locals: List[Local]) {
    // Total number of variable slots: the sum of the `size` of every local in the list.
    val size = locals.foldLeft(0)((slots, l) => slots + l.size)
  }

  object LocalsList {
    /**
     * Builds a list of consecutive local variables, the first one at index `firstLocal`, with one
     * local per entry of `types`. `castLoadTypes(i)` becomes the `castLoadedValue` of the local
     * created for `types(i)`.
     *
     * For example, `fromTypes(3, Array(Int, Long, String), _ => None)` returns
     *   Local(3, intOpOffset, None)  ::
     *   Local(4, longOpOffset, None) ::  // this local occupies two slots, so the next one is at 6
     *   Local(6, refOpOffset, None)  ::
     *   Nil
     */
    def fromTypes(firstLocal: Int, types: Array[Type], castLoadTypes: Int => Option[Type]): LocalsList = {
      // Thread the next free slot through the fold; the locals are accumulated in reverse order.
      val (_, reversedLocals) = types.indices.foldLeft((firstLocal, List.empty[Local])) {
        case ((nextSlot, acc), i) =>
          // The ASM method `type.getOpcode` returns the opcode for operating on a value of `type`.
          val offset = types(i).getOpcode(ILOAD) - ILOAD
          val local = Local(nextSlot, offset, castLoadTypes(i))
          (nextSlot + local.size, local :: acc)
      }
      LocalsList(reversedLocals.reverse)
    }
  }

  /**
   * Stores a local variable index and the opcode offset required for operating on that variable.
   *
   * The xLOAD / xSTORE opcodes are in the following sequence: I, L, F, D, A, so the offset for
   * a local variable holding a reference (`A`) is 4. See also method `getOpcode` in [[scala.tools.asm.Type]].
   */
  case class Local(local: Int, opcodeOffset: Int, castLoadedValue: Option[Type]) {
    // The xLOAD / xSTORE opcode for this variable: the int opcode shifted by `opcodeOffset`.
    def loadOpcode = ILOAD + opcodeOffset
    def storeOpcode = ISTORE + opcodeOffset

    // Number of variable slots: 2 if the variable is loaded with LLOAD or DLOAD, otherwise 1.
    def size = loadOpcode match {
      case LLOAD | DLOAD => 2
      case _             => 1
    }
  }
}