summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/backend/jvm/opt/ClosureOptimizer.scala
blob: 8da209b269dbf61f23e5dd2305b9dbb335d71afd (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
/* NSC -- new Scala compiler
 * Copyright 2005-2015 LAMP/EPFL
 * @author  Martin Odersky
 */

package scala.tools.nsc
package backend.jvm
package opt

import scala.annotation.switch
import scala.reflect.internal.util.NoPosition
import scala.tools.asm.{Handle, Type, Opcodes}
import scala.tools.asm.tree._
import scala.tools.nsc.backend.jvm.BTypes.InternalName
import scala.tools.nsc.backend.jvm.analysis.ProdConsAnalyzer
import BytecodeUtils._
import BackendReporting._
import Opcodes._
import scala.tools.nsc.backend.jvm.opt.ByteCodeRepository.CompilationUnit
import scala.collection.convert.decorateAsScala._

class ClosureOptimizer[BT <: BTypes](val btypes: BT) {
  import btypes._
  import callGraph._

  /**
   * Rewrites closure `apply` invocations into direct calls of the closure body method, for every
   * closure instantiation recorded in `closureInstantiations`.
   *
   * Instantiations are processed in a deterministic order rather than in the iteration order of
   * `closureInstantiations`. The order is: owner class name, then owner method name and
   * descriptor, then position of the indy instruction in its method.
   *
   * Warnings for callsites that could not be rewritten are reported through
   * `backendReporting.inlinerWarning`.
   */
  def rewriteClosureApplyInvocations(): Unit = {
    implicit object closureInitOrdering extends Ordering[(InvokeDynamicInsnNode, MethodNode, ClassBType)] {
      // Note: this code is cleaned up in a future commit, no more tuples.
      override def compare(x: (InvokeDynamicInsnNode, MethodNode, ClassBType), y: (InvokeDynamicInsnNode, MethodNode, ClassBType)): Int = {
        val cls = x._3.internalName compareTo y._3.internalName
        if (cls != 0) return cls

        val mName = x._2.name compareTo y._2.name
        if (mName != 0) return mName

        val mDesc = x._2.desc compareTo y._2.desc
        if (mDesc != 0) return mDesc

        // Look up each indy in the instruction list of its own owner method.
        def pos(init: (InvokeDynamicInsnNode, MethodNode, ClassBType)) = init._2.instructions.indexOf(init._1)
        Integer.compare(pos(x), pos(y))
      }
    }

    val sorted = closureInstantiations.iterator.map({
      case (indy, (methodNode, ownerClass)) => (indy, methodNode, ownerClass)
    }).to[collection.immutable.TreeSet]

    sorted foreach {
      case (indy, methodNode, ownerClass) =>
        val warnings = rewriteClosureApplyInvocations(indy, methodNode, ownerClass)
        warnings.foreach(w => backendReporting.inlinerWarning(w.pos, w.toString))
    }
  }

  private val lambdaMetaFactoryInternalName: InternalName = "java/lang/invoke/LambdaMetafactory"

  // Bootstrap method handle of `LambdaMetafactory.metafactory`.
  private val metafactoryHandle = {
    val metafactoryMethodName: String = "metafactory"
    val metafactoryDesc: String       = "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodType;Ljava/lang/invoke/MethodHandle;Ljava/lang/invoke/MethodType;)Ljava/lang/invoke/CallSite;"
    new Handle(H_INVOKESTATIC, lambdaMetaFactoryInternalName, metafactoryMethodName, metafactoryDesc)
  }

  // Bootstrap method handle of `LambdaMetafactory.altMetafactory`.
  private val altMetafactoryHandle = {
    val altMetafactoryMethodName: String = "altMetafactory"
    val altMetafactoryDesc: String       = "(Ljava/lang/invoke/MethodHandles$Lookup;Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)Ljava/lang/invoke/CallSite;"
    new Handle(H_INVOKESTATIC, lambdaMetaFactoryInternalName, altMetafactoryMethodName, altMetafactoryDesc)
  }

  /**
   * Returns true if `indy` is a `LambdaMetafactory` closure instantiation whose shape is supported
   * by the closure optimizer. The checks are explained in the comment inside the method.
   */
  def isClosureInstantiation(indy: InvokeDynamicInsnNode): Boolean = {
    (indy.bsm == metafactoryHandle || indy.bsm == altMetafactoryHandle) &&
    {
      indy.bsmArgs match {
        case Array(samMethodType: Type, implMethod: Handle, instantiatedMethodType: Type, xs @ _*) =>
          // LambdaMetaFactory performs a number of automatic adaptations when invoking the lambda
          // implementation method (casting, boxing, unboxing, and primitive widening, see Javadoc).
          //
          // The closure optimizer supports only one of those adaptations: it will cast arguments
          // to the correct type when re-writing a closure call to the body method. Example:
          //
          //   val fun: String => String = l => l
          //   val l = List("")
          //   fun(l.head)
          //
          // The samMethodType of Function1 is `(Object)Object`, while the instantiatedMethodType
          // is `(String)String`. The return type of `List.head` is `Object`.
          //
          // The implMethod has the signature `C$anonfun(String)String`.
          //
          // At the closure callsite, we have an `INVOKEINTERFACE Function1.apply (Object)Object`,
          // so the object returned by `List.head` can be directly passed into the call (no cast).
          //
          // The closure object will cast the object to String before passing it to the implMethod.
          //
          // When re-writing the closure callsite to the implMethod, we have to insert a cast.
          //
          // The check below ensures that
          //   (1) the implMethod type has the expected signature (captured types plus argument types
          //       from instantiatedMethodType)
          //   (2) the receiver of the implMethod matches the first captured type. Instance methods
          //       whose receiver is not a captured value (for example a Java unbound method reference
          //       such as `String::length`) are not supported.
          //   (3) samMethodType and instantiatedMethodType have the same number of parameters, and
          //       all parameters that are not the same in both are reference types, so that we can
          //       insert casts to perform the same adaptation that the closure object would.

          val isStatic = implMethod.getTag == H_INVOKESTATIC
          val indyParamTypes = Type.getArgumentTypes(indy.desc)
          val instantiatedMethodArgTypes = instantiatedMethodType.getArgumentTypes

          // A non-static implMethod needs its receiver as the first captured value. Without any
          // captured value the shape is unsupported; bail out here instead of calling `tail` (and
          // indexing) on an empty array, which would throw.
          if (!isStatic && indyParamTypes.isEmpty) false
          else {
            val expectedImplMethodType = {
              val paramTypes = (if (isStatic) indyParamTypes else indyParamTypes.tail) ++ instantiatedMethodArgTypes
              Type.getMethodType(instantiatedMethodType.getReturnType, paramTypes: _*)
            }

            {
              Type.getType(implMethod.getDesc) == expectedImplMethodType // (1)
            } && {
              isStatic || implMethod.getOwner == indyParamTypes(0).getInternalName // (2)
            } && {
              def isReference(t: Type) = t.getSort == Type.OBJECT || t.getSort == Type.ARRAY
              samMethodType.getArgumentTypes.corresponds(instantiatedMethodArgTypes)((samArgType, instArgType) =>
                samArgType == instArgType || isReference(samArgType) && isReference(instArgType)) // (3)
            }
          }

        case _ =>
          false
      }
    }
  }

  /**
   * Returns true if `invocation` calls the SAM method on the closure object created by `indy`.
   *
   * That is: the call is not static, it has the same name as the indy (the SAM method name) and
   * the descriptor of the indy's samMethodType, and the receiver value on the stack is produced
   * only by `indy`.
   *
   * `prodCons` is by-name and is only evaluated if the cheaper name / descriptor checks pass.
   */
  def isSamInvocation(invocation: MethodInsnNode, indy: InvokeDynamicInsnNode, prodCons: => ProdConsAnalyzer): Boolean = {
    if (invocation.getOpcode == INVOKESTATIC) false
    else {
      def closureIsReceiver = {
        val invocationFrame = prodCons.frameAt(invocation)
        val receiverSlot = {
          val numArgs = Type.getArgumentTypes(invocation.desc).length
          invocationFrame.stackTop - numArgs
        }
        val receiverProducers = prodCons.initialProducersForValueAt(invocation, receiverSlot)
        receiverProducers.size == 1 && receiverProducers.head == indy
      }

      invocation.name == indy.name && {
        val indySamMethodDesc = indy.bsmArgs(0).asInstanceOf[Type].getDescriptor // safe, checked in isClosureInstantiation
        indySamMethodDesc == invocation.desc
      } &&
      closureIsReceiver // most expensive check last
    }
  }

  /**
   * Stores the values captured by a closure creation into fresh local variables.
   * Returns the list of locals holding the captured values.
   */
  private def storeCaptures(indy: InvokeDynamicInsnNode, methodNode: MethodNode): LocalsList = {
    val capturedTypes = Type.getArgumentTypes(indy.desc)
    val firstCaptureLocal = methodNode.maxLocals

    // This could be optimized: in many cases the captured values are produced by LOAD instructions.
    // If the variable is not modified within the method, we could avoid introducing yet another
    // local. On the other hand, further optimizations (copy propagation, remove unused locals) will
    // clean it up.

    // Captured variables don't need to be cast when loaded at the callsite (castLoadTypes are None).
    // This is checked in `isClosureInstantiation`: the types of the captured variables in the indy
    // instruction match exactly the corresponding parameter types in the body method.
    val localsForCaptures = LocalsList.fromTypes(firstCaptureLocal, capturedTypes, castLoadTypes = _ => None)
    methodNode.maxLocals = firstCaptureLocal + localsForCaptures.size

    insertStoreOps(indy, methodNode, localsForCaptures)
    insertLoadOps(indy, methodNode, localsForCaptures)

    localsForCaptures
  }

  /**
   * Insert store operations in front of the `before` instruction to copy stack values into the
   * locals denoted by `localsList`.
   *
   * The lowest stack value is stored in the head of the locals list, so the last local is stored first.
   */
  private def insertStoreOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) =
    insertLocalValueOps(before, methodNode, localsList, store = true)

  /**
   * Insert load operations in front of the `before` instruction to copy the local values denoted
   * by `localsList` onto the stack.
   *
   * The head of the locals list will be the lowest value on the stack, so the first local is loaded first.
   */
  private def insertLoadOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList) =
    insertLocalValueOps(before, methodNode, localsList, store = false)

  private def insertLocalValueOps(before: AbstractInsnNode, methodNode: MethodNode, localsList: LocalsList, store: Boolean): Unit = {
    // If `store` is true, the first instruction needs to store into the last local of the `localsList`.
    // Load instructions on the other hand are emitted in the order of the list.
    // To avoid reversing the list, stores are inserted directly after `previous` (so each one ends up
    // in front of the stores emitted before it), while loads are inserted in front of `before`.
    val insns = methodNode.instructions
    lazy val previous = before.getPrevious
    for (l <- localsList.locals) {
      val varOp = new VarInsnNode(if (store) l.storeOpcode else l.loadOpcode, l.local)
      if (store) {
        // If `before` is the first instruction there is no `previous`; `insert(insn)` prepends to
        // the list, which yields the same (reversed) order of stores.
        if (previous == null) insns.insert(varOp)
        else insns.insert(previous, varOp)
      } else {
        insns.insertBefore(before, varOp)
        for (castType <- l.castLoadedValue)
          insns.insert(varOp, new TypeInsnNode(CHECKCAST, castType.getInternalName))
      }
    }
  }

  /**
   * Rewrites every invocation in `methodNode` of the SAM method on the closure created by `indy`
   * into a direct invocation of the closure body method (the indy's implMethod handle).
   *
   * Captured values and the arguments of each rewritten callsite are spilled to fresh locals. The
   * closure object is popped, and the body method is called with the captured values followed by
   * the arguments. The call graph entry of each rewritten callsite is replaced.
   *
   * @return warnings for callsites that could not be rewritten: the body method could not be
   *         looked up or checked for accessibility, or it is not accessible from `ownerClass`
   */
  def rewriteClosureApplyInvocations(indy: InvokeDynamicInsnNode, methodNode: MethodNode, ownerClass: ClassBType): List[RewriteClosureApplyToClosureBodyFailed] = {
    val lambdaBodyHandle = indy.bsmArgs(1).asInstanceOf[Handle] // safe, checked in isClosureInstantiation

    // Kept as a lazy val to make sure the analysis is only computed if it's actually needed.
    // ProdCons is used to identify closure body invocations (see isSamInvocation), but only if the
    // callsite has the right name and signature. If the method has no invocation instruction with
    // the right name and signature, the analysis is not executed.
    lazy val prodCons = new ProdConsAnalyzer(methodNode, ownerClass.internalName)

    // First collect all callsites without modifying the instructions list yet.
    // Once we start modifying the instruction list, prodCons becomes unusable.

    // A list of callsites and stack heights. If the invocation cannot be rewritten, a warning
    // message is stored in the stack height value.
    val invocationsToRewrite: List[(MethodInsnNode, Either[RewriteClosureApplyToClosureBodyFailed, Int])] = methodNode.instructions.iterator.asScala.collect({
      case invocation: MethodInsnNode if isSamInvocation(invocation, indy, prodCons) =>
        val bodyAccessible: Either[OptimizerWarning, Boolean] = for {
          (bodyMethodNode, declClass) <- byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc): Either[OptimizerWarning, (MethodNode, InternalName)]
          isAccessible                <- inliner.memberIsAccessible(bodyMethodNode.access, classBTypeFromParsedClassfile(declClass), classBTypeFromParsedClassfile(lambdaBodyHandle.getOwner), ownerClass)
        } yield {
          isAccessible
        }

        def pos = callGraph.callsites.get(invocation).map(_.callsitePosition).getOrElse(NoPosition)
        val stackSize: Either[RewriteClosureApplyToClosureBodyFailed, Int] = bodyAccessible match {
          case Left(w)      => Left(RewriteClosureAccessCheckFailed(pos, w))
          case Right(false) => Left(RewriteClosureIllegalAccess(pos, ownerClass.internalName))
          case _            => Right(prodCons.frameAt(invocation).getStackSize)
        }

        (invocation, stackSize)
    }).toList

    if (invocationsToRewrite.isEmpty) Nil
    else {
      // lazy val to make sure locals for captures and arguments are only allocated if there's
      // effectively a callsite to rewrite.
      lazy val (localsForCapturedValues, argumentLocalsList) = {
        val captureLocals = storeCaptures(indy, methodNode)

        // allocate locals for storing the arguments of the closure apply callsites.
        // if there are multiple callsites, the same locals are re-used.
        val argTypes = indy.bsmArgs(0).asInstanceOf[Type].getArgumentTypes // safe, checked in isClosureInstantiation
        val firstArgLocal = methodNode.maxLocals

        // The comment in `isClosureInstantiation` explains why we have to introduce casts for
        // arguments that have different types in samMethodType and instantiatedMethodType.
        val castLoadTypes = {
          val instantiatedMethodType = indy.bsmArgs(2).asInstanceOf[Type]
          (argTypes, instantiatedMethodType.getArgumentTypes).zipped map {
            case (samArgType, instantiatedArgType) if samArgType != instantiatedArgType =>
              // isClosureInstantiation ensures that the two types are reference types, so we don't
              // end up casting primitive values.
              Some(instantiatedArgType)
            case _ =>
              None
          }
        }
        val argLocals = LocalsList.fromTypes(firstArgLocal, argTypes, castLoadTypes)
        methodNode.maxLocals = firstArgLocal + argLocals.size

        (captureLocals, argLocals)
      }

      val warnings = invocationsToRewrite flatMap {
        case (_, Left(warning)) => Some(warning)

        case (invocation, Right(stackHeight)) =>
          val insns = methodNode.instructions

          // store arguments
          insertStoreOps(invocation, methodNode, argumentLocalsList)

          // drop the closure from the stack
          insns.insertBefore(invocation, new InsnNode(POP))

          // A constructor body (H_NEWINVOKESPECIAL) needs the new object and its duplicate *below*
          // the constructor arguments: `NEW C; DUP; args...; INVOKESPECIAL C.<init>`. So they are
          // emitted before loading the captured values and arguments.
          val isConstructorBody = lambdaBodyHandle.getTag == H_NEWINVOKESPECIAL
          if (isConstructorBody) {
            insns.insertBefore(invocation, new TypeInsnNode(NEW, lambdaBodyHandle.getOwner))
            insns.insertBefore(invocation, new InsnNode(DUP))
          }

          // load captured values and arguments
          insertLoadOps(invocation, methodNode, localsForCapturedValues)
          insertLoadOps(invocation, methodNode, argumentLocalsList)

          // update maxStack
          val capturesStackSize = localsForCapturedValues.size
          val newObjectStackSize = if (isConstructorBody) 2 else 0 // NEW + DUP
          val invocationStackHeight = stackHeight + capturesStackSize + newObjectStackSize - 1 // -1 because the closure is gone
          if (invocationStackHeight > methodNode.maxStack)
            methodNode.maxStack = invocationStackHeight

          // replace the callsite with a new call to the body method
          val bodyOpcode = (lambdaBodyHandle.getTag: @switch) match {
            case H_INVOKEVIRTUAL    => INVOKEVIRTUAL
            case H_INVOKESTATIC     => INVOKESTATIC
            case H_INVOKESPECIAL    => INVOKESPECIAL
            case H_INVOKEINTERFACE  => INVOKEINTERFACE
            case H_NEWINVOKESPECIAL => INVOKESPECIAL
          }
          val isInterface = bodyOpcode == INVOKEINTERFACE
          val bodyInvocation = new MethodInsnNode(bodyOpcode, lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc, isInterface)
          insns.insertBefore(invocation, bodyInvocation)

          val returnType = Type.getReturnType(lambdaBodyHandle.getDesc)
          fixLoadedNothingOrNullValue(returnType, bodyInvocation, methodNode, btypes) // see comment of that method

          insns.remove(invocation)

          // update the call graph
          val originalCallsite = callGraph.callsites.remove(invocation)

          // the method node is needed for building the call graph entry
          val bodyMethod = byteCodeRepository.methodNode(lambdaBodyHandle.getOwner, lambdaBodyHandle.getName, lambdaBodyHandle.getDesc)
          def bodyMethodIsBeingCompiled = byteCodeRepository.classNodeAndSource(lambdaBodyHandle.getOwner).map(_._2 == CompilationUnit).getOrElse(false)
          val bodyMethodCallsite = Callsite(
            callsiteInstruction = bodyInvocation,
            callsiteMethod = methodNode,
            callsiteClass = ownerClass,
            callee = bodyMethod.map({
              case (bodyMethodNode, bodyMethodDeclClass) => Callee(
                callee = bodyMethodNode,
                calleeDeclarationClass = classBTypeFromParsedClassfile(bodyMethodDeclClass),
                safeToInline = compilerSettings.YoptInlineGlobal || bodyMethodIsBeingCompiled,
                safeToRewrite = false, // the lambda body method is not a trait interface method
                annotatedInline = false,
                annotatedNoInline = false,
                calleeInfoWarning = None)
            }),
            argInfos = Nil,
            callsiteStackHeight = invocationStackHeight,
            receiverKnownNotNull = true, // see below (*)
            callsitePosition = originalCallsite.map(_.callsitePosition).getOrElse(NoPosition)
          )
          // (*) The documentation in class LambdaMetafactory says:
          //     "if implMethod corresponds to an instance method, the first capture argument
          //     (corresponding to the receiver) must be non-null"
          // Explanation: If the lambda body method is non-static, the receiver is a captured
          // value. It can only be captured within some instance method, so we know it's non-null.
          callGraph.callsites(bodyInvocation) = bodyMethodCallsite
          None
      }

      warnings.toList
    }
  }

  /**
   * A list of local variables. Each local stores information about its type, see class [[Local]].
   * `size` is the number of local variable slots the list occupies (long / double take two).
   */
  case class LocalsList(locals: List[Local]) {
    val size = locals.iterator.map(_.size).sum
  }

  object LocalsList {
    /**
     * A list of local variables starting at `firstLocal` that can hold values of the types in the
     * `types` parameter. `castLoadTypes(i)` is the type to cast the i-th value to when it is loaded.
     *
     * For example, `fromTypes(3, Array(Int, Long, String), _ => None)` returns
     *   Local(3, intOpOffset, None)  ::
     *   Local(4, longOpOffset, None) ::  // note that this local occupies two slots, the next is at 6
     *   Local(6, refOpOffset, None)  ::
     *   Nil
     */
    def fromTypes(firstLocal: Int, types: Array[Type], castLoadTypes: Int => Option[Type]): LocalsList = {
      var sizeTwoOffset = 0
      val locals: List[Local] = types.indices.map(i => {
        // The ASM method `type.getOpcode` returns the opcode for operating on a value of `type`.
        val offset = types(i).getOpcode(ILOAD) - ILOAD
        val local = Local(firstLocal + i + sizeTwoOffset, offset, castLoadTypes(i))
        if (local.size == 2) sizeTwoOffset += 1
        local
      })(collection.breakOut)
      LocalsList(locals)
    }
  }

  /**
   * Stores a local variable index and the opcode offset required for operating on that variable.
   * `castLoadedValue` is the type to cast the value to after loading it, if any.
   *
   * The xLOAD / xSTORE opcodes are in the following sequence: I, L, F, D, A, so the offset for
   * a local variable holding a reference (`A`) is 4. See also method `getOpcode` in [[scala.tools.asm.Type]].
   */
  case class Local(local: Int, opcodeOffset: Int, castLoadedValue: Option[Type]) {
    def size = if (loadOpcode == LLOAD || loadOpcode == DLOAD) 2  else 1

    def loadOpcode = ILOAD + opcodeOffset
    def storeOpcode = ISTORE + opcodeOffset
  }
}