summaryrefslogtreecommitdiff
path: root/src/compiler/scala/tools/nsc/transform/Delambdafy.scala
blob: 034cf118d7aedd8ae69d166fee7d5a0c263e7466 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
package scala.tools.nsc
package transform

import symtab._
import Flags._
import scala.collection._

/**
  * This transformer is responsible for preparing Function nodes for runtime,
  * by translating to a tree that will be converted to an invokedynamic by the backend.
  *
  * The main assumption it makes is that a Function {args => body} has been turned into
  * {args => liftedBody()} where lifted body is a top level method that implements the body of the function.
  * Currently Uncurry is responsible for that transformation.
  *
  * From this shape of Function, Delambdafy will create:
  *
  * An application of the captured arguments to a fictional symbol representing the lambda factory.
  * This will be translated by the backend into an invokedynamic using a bootstrap method in JDK8's `LambdaMetafactory`.
  * The captured arguments include `this` if `liftedBody` is unable to be made STATIC.
  */
abstract class Delambdafy extends Transform with TypingTransformers with ast.TreeDSL with TypeAdaptingTransformer {
  import global._
  import definitions._

  val analyzer: global.analyzer.type = global.analyzer

  /** the following two members override abstract members in Transform */
  val phaseName: String = "delambdafy"

  /**
    * Tree attachment that tells the backend how to emit the invokedynamic instruction for a lambda.
    * It is attached to the factory application built by `mkLambdaMetaFactoryCall`.
    *
    * @param target                     the method the lambda forwards to (the lifted body, or a boxing bridge to it)
    * @param arity                      the number of value parameters of the original function
    * @param functionalInterface        the interface the lambda implements
    * @param sam                        the abstract method of the functional interface that is implemented
    * @param isSerializable             true when no user-defined SAM was given, or when the SAM's owner
    *                                   is a subclass of `JavaSerializableClass`
    * @param addScalaSerializableMarker true when no user-defined SAM was given
    *                                   (NOTE(review): how the backend uses this flag is not visible here -- confirm)
    */
  final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol, isSerializable: Boolean, addScalaSerializableMarker: Boolean)

  /**
    * Extracts the symbol of the lifted body method that a function forwards to: for a function
    * of the shape `{args => anonfun(args)}` the result is the symbol of `anonfun`.
    *
    * Compilation is aborted when the function's body is not a plain application, since no
    * other shape is expected by the time this phase runs.
    */
  private def targetMethod(fun: Function): Symbol = fun.body match {
    case Apply(target, _) => target.symbol
    case _                => abort(s"could not understand function with tree $fun")
  }

  /**
    * Creates the phase object: the real transforming phase when `-Ydelambdafy` is set to
    * `method`, and a no-op phase for every other setting.
    */
  override def newPhase(prev: scala.tools.nsc.Phase): StdPhase =
    settings.Ydelambdafy.value match {
      case "method" => new Phase(prev)
      case _        => new SkipPhase(prev)
    }

  /** A phase that leaves every compilation unit untouched. */
  class SkipPhase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) {
    def apply(unit: global.CompilationUnit): Unit = {}
  }

  /** Creates the tree transformer that rewrites the lambdas of the given compilation unit. */
  protected def newTransformer(unit: CompilationUnit): Transformer = {
    val transformer = new DelambdafyTransformer(unit)
    transformer
  }

  class DelambdafyTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
    // The methods of this unit that refer to `this`; needed to decide which lambdas must be given access to it.
    // TODO: computing this looks expensive, hence the lazy val. Can we make it more pay-as-you-go / optimize for common shapes?
    private[this] lazy val methodReferencesThis: Set[Symbol] = {
      val thisTraverser = new ThisReferringMethodsTraverser
      thisTraverser.methodReferencesThisIn(unit.body)
    }

    /**
      * Builds the tree that replaces a lambda: an application of a fictional factory method to the
      * captured arguments. The result carries a `LambdaMetaFactoryCapable` attachment from which the
      * backend emits the invokedynamic instruction.
      *
      * @param fun                 the function being translated, of the shape `{args => target(args)}`
      * @param target              the lifted method that implements the body of `fun`
      * @param functionalInterface the interface the lambda has to implement
      * @param samUserDefined      the abstract method to implement, or `NoSymbol` to look it up in `functionalInterface`
      * @param isSpecialized       whether `functionalInterface` is a specialized function interface; if so the
      *                            deferred member of the interface is the sam and no boxing bridge is created
      * @return the typed factory application, with the attachment set
      */
    private def mkLambdaMetaFactoryCall(fun: Function, target: Symbol, functionalInterface: Symbol, samUserDefined: Symbol, isSpecialized: Boolean): Tree = {
      val pos = fun.pos
      def isSelfParam(p: Symbol) = p.isSynthetic && p.name == nme.SELF
      val hasSelfParam = isSelfParam(target.firstParam)

      val allCapturedArgRefs = {
        // find which variables are free in the lambda because those are captures that need to be
        // passed into the constructor of the anonymous function class
        val captureArgs = FreeVarTraverser.freeVarsOf(fun).iterator.map(capture =>
          gen.mkAttributedRef(capture) setPos pos
        ).toList

        // without a self parameter on the target, a captured self reference must not be passed along;
        // with one, `this` is prepended unless the enclosing method is itself static
        if (!hasSelfParam) captureArgs.filterNot(arg => isSelfParam(arg.symbol))
        else if (currentMethod.hasFlag(Flags.STATIC)) captureArgs
        else (gen.mkAttributedThis(fun.symbol.enclClass) setPos pos) :: captureArgs
      }

      // Create a symbol representing a fictional lambda factory method that accepts the captured
      // arguments and returns the SAM type.
      val msym = {
        val meth = currentOwner.newMethod(nme.ANON_FUN_NAME, pos, ARTIFACT)
        val capturedParams = meth.newSyntheticValueParams(allCapturedArgRefs.map(_.tpe))
        meth.setInfo(MethodType(capturedParams, fun.tpe))
      }

      // We then apply this symbol to the captures.
      val apply = localTyper.typedPos(pos)(Apply(Ident(msym), allCapturedArgRefs))

      // TODO: this is a bit gross
      val sam = samUserDefined orElse {
        if (isSpecialized)
          // fail with a diagnostic instead of a bare `None.get` when the interface has no deferred member
          functionalInterface.info.decls.find(_.isDeferred).getOrElse(
            abort(s"no deferred method found in specialized functional interface $functionalInterface"))
        else functionalInterface.info.member(nme.apply)
      }

      // no need for adaptation when the implemented sam is of a specialized built-in function type
      val lambdaTarget = if (isSpecialized) target else createBoxingBridgeMethodIfNeeded(fun, target, functionalInterface, sam)
      val isSerializable = samUserDefined == NoSymbol || samUserDefined.owner.isNonBottomSubClass(definitions.JavaSerializableClass)
      val addScalaSerializableMarker = samUserDefined == NoSymbol

      // The backend needs to know the target of the lambda and the functional interface in order
      // to emit the invokedynamic instruction. We pass this information as tree attachment.
      //
      // see https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/LambdaMetafactory.html
      //   instantiatedMethodType is derived from lambdaTarget's signature
      //   samMethodType is derived from samOf(functionalInterface)'s signature
      apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface, sam, isSerializable, addScalaSerializableMarker))

      apply
    }


    // Bridge methods created by `createBoxingBridgeMethodIfNeeded` while a template is being transformed.
    // The `Template` case of `transform` appends them to the template's body and then clears this buffer.
    private val boxingBridgeMethods = mutable.ArrayBuffer[Tree]()

    // Maps an erased value type back to a reference to its value class; any other type is returned unchanged.
    private def reboxValueClass(tp: Type) = tp match {
      case erased: ErasedValueType => TypeRef(NoPrefix, erased.valueClazz, Nil)
      case other                   => other
    }

    // A type counts as a reference type unless it is an erased value type, a primitive value class
    // or a derived value class -- those need special boxing.
    private def isReferenceType(tp: Type) = tp match {
      case _: ErasedValueType => false
      case _ =>
        val sym = tp.typeSymbol
        !isPrimitiveValueClass(sym) && !sym.isDerivedValueClass
    }

    /**
      * Determines the method that LambdaMetafactory should use as the implementation of the lambda.
      *
      * When the post-erasure signature of `target` already satisfies the linking invariants with
      * respect to `sam` (see the comment in the body), `target` itself is returned. Otherwise a new
      * static forwarder named `<target>$adapted` is entered into the enclosing class of `fun`, its
      * definition is added to `boxingBridgeMethods`, and its symbol is returned.
      *
      * @param fun                 the function being translated; supplies the position and the enclosing class
      * @param target              the lifted method implementing the function body
      * @param functionalInterface the interface being implemented (not read by this method's body)
      * @param sam                 the abstract method being implemented
      * @return `target`, or the symbol of the newly created bridge method
      */
    // determine which lambda target to use with java's LMF -- create a new one if scala-specific boxing is required
    def createBoxingBridgeMethodIfNeeded(fun: Function, target: Symbol, functionalInterface: Symbol, sam: Symbol): Symbol = {
      val oldClass = fun.symbol.enclClass
      val pos = fun.pos

      // At erasure, there won't be any captured arguments (they are added in constructors)
      val functionParamTypes = exitingErasure(target.info.paramTypes)
      val functionResultType = exitingErasure(target.info.resultType)

      val samParamTypes = exitingErasure(sam.info.paramTypes)
      val samResultType = exitingErasure(sam.info.resultType)

      /** How to satisfy the linking invariants of https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/LambdaMetafactory.html
        *
        * Given samMethodType: (U1..Un)Ru and function type T1,..., Tn => Rt (the target method created by uncurry)
        *
        * Do we need a bridge, or can we use the original lambda target for implMethod: (<captured args> A1..An)Ra
        * (We can ignore capture here.)
        *
        * If, for i=1..N:
        *  Ai =:= Ui || (Ai <:< Ui <:< AnyRef)
        *  Ru =:= void || (Ra =:= Ru || (Ra <:< AnyRef, Ru <:< AnyRef))
        *
        * We can use the target method as-is -- if not, we create a bridging one that uses the types closest
        * to the target method that still meet the above requirements.
        */
      val resTpOk = (
           samResultType =:= UnitTpe
        || functionResultType =:= samResultType
        || (isReferenceType(samResultType) && isReferenceType(functionResultType))) // yes, this is what the spec says -- no further correspondence required
      // the target can be used directly when the result type is fine and every parameter pair is
      // either identical or a pair of reference types related by subtyping
      if (resTpOk && (samParamTypes corresponds functionParamTypes){ (samParamTp, funParamTp) =>
          funParamTp =:= samParamTp || (isReferenceType(funParamTp) && isReferenceType(samParamTp) && funParamTp <:< samParamTp) }) target
      else {
        // We have to construct a new lambda target that bridges to the one created by uncurry.
        // The bridge must satisfy the above invariants, while also minimizing adaptation on our end.
        // LMF will insert runtime casts according to the spec at the above link.

        // we use the more precise type between samParamTp and funParamTp to minimize boxing in the bridge method
        // we are constructing a method whose signature matches the sam's signature (because the original target did not)
        // whenever a type in the sam's signature is (erases to) a primitive type, we must pick the sam's version,
        // as we don't implement the logic regarding widening that's performed by LMF -- we require =:= for primitives
        //
        // We use the sam's type for the check whether we're dealing with a reference type, as it could be a generic type,
        // which means the function's parameter -- even if it expects a value class -- will need to be
        // boxed on the generic call to the sam method.

        val bridgeParamTypes = map2(samParamTypes, functionParamTypes){ (samParamTp, funParamTp) =>
          if (isReferenceType(samParamTp) && funParamTp <:< samParamTp) funParamTp
          else samParamTp
        }

        val bridgeResultType =
          if (resTpOk && isReferenceType(samResultType) && functionResultType <:< samResultType) functionResultType
          else samResultType

        val typeAdapter = new TypeAdapter { def typedPos(pos: Position)(tree: Tree): Tree = localTyper.typedPos(pos)(tree) }
        import typeAdapter.{adaptToType, unboxValueClass}

        // the leading parameters of the target that have no counterpart in the post-erasure signature are the captures
        val targetParams = target.paramss.head
        val numCaptures  = targetParams.length - functionParamTypes.length
        val (targetCapturedParams, targetFunctionParams) = targetParams.splitAt(numCaptures)

        // the bridge is a static, final, synthetic sibling of the target in the same class
        val methSym = oldClass.newMethod(target.name.append("$adapted").toTermName, target.pos, target.flags | FINAL | ARTIFACT | STATIC)
        val bridgeCapturedParams = targetCapturedParams.map(param => methSym.newSyntheticValueParam(param.tpe, param.name.toTermName))
        val bridgeFunctionParams =
          map2(targetFunctionParams, bridgeParamTypes)((param, tp) => methSym.newSyntheticValueParam(tp, param.name.toTermName))

        val bridgeParams = bridgeCapturedParams ::: bridgeFunctionParams

        methSym setInfo MethodType(bridgeParams, bridgeResultType)
        oldClass.info.decls enter methSym

        // body of the bridge: call the original target, passing captures through unchanged and
        // adapting each function argument to the type the target expects
        val forwarderCall = localTyper.typedPos(pos) {
          val capturedArgRefs = bridgeCapturedParams map gen.mkAttributedRef
          val functionArgRefs =
            // `targetParams.drop(numCaptures)` is the same list as `targetFunctionParams` above
            map3(bridgeFunctionParams, functionParamTypes, targetParams.drop(numCaptures)) { (bridgeParam, functionParamTp, targetParam) =>
              val bridgeParamRef = gen.mkAttributedRef(bridgeParam)
              val targetParamTp  = targetParam.tpe

              // TODO: can we simplify this to something like `adaptToType(adaptToType(bridgeParamRef, functionParamTp), targetParamTp)`?
              val unboxed =
                functionParamTp match {
                  case ErasedValueType(clazz, underlying) =>
                    // when the original function expected an argument of value class type,
                    // the original target will expect the unboxed underlying value,
                    // whereas the bridge will receive the boxed value (since the sam's argument type did not match and we had to adapt)
                    localTyper.typed(unboxValueClass(bridgeParamRef, clazz, underlying), targetParamTp)
                  case _ => bridgeParamRef
                }

              adaptToType(unboxed, targetParamTp)
            }

          gen.mkMethodCall(Select(gen.mkAttributedThis(oldClass), target), capturedArgRefs ::: functionArgRefs)
        }

        // the forwarder's result is adapted to the bridge's result type, and the finished definition
        // is run through the post-erasure transformer
        val bridge = postErasure.newTransformer(unit).transform(DefDef(methSym, List(bridgeParams.map(ValDef(_))),
          adaptToType(forwarderCall setType functionResultType, bridgeResultType))).asInstanceOf[DefDef]

        // remembered here so that the `Template` case of `transform` can add it to the class body
        boxingBridgeMethods += bridge
        bridge.symbol
      }
    }


    /**
      * Translates a function of the shape `{args => target(args)}` into the factory application
      * built by `mkLambdaMetaFactoryCall`.
      *
      * The lifted target must already be STATIC when this runs (the `Template` case of `transform`
      * takes care of that before descending into the template body); it is additionally flagged
      * `notPRIVATE`. The functional interface is chosen as follows:
      *  - the function's own type symbol when that is not one of the built-in function classes;
      *  - the `J`-prefixed interface from the Java 8 compat package when the erased signature of
      *    the target has a specialized variant;
      *  - the generic `FunctionN` class otherwise.
      *
      * @param originalFunction the function node to translate
      * @return the tree that replaces the function
      */
    private def transformFunction(originalFunction: Function): Tree = {
      val target = targetMethod(originalFunction)
      // supply a diagnostic: a bare assertion failure gives no hint about which lambda body is at fault
      assert(target.hasFlag(Flags.STATIC), s"delambdafy target $target is expected to be STATIC at this point")
      target.setFlag(notPRIVATE)

      val funSym = originalFunction.tpe.typeSymbolDirect
      // The functional interface that can be used to adapt the lambda target method `target` to the given function type.
      val (functionalInterface, isSpecialized) =
        if (!isFunctionSymbol(funSym)) (funSym, false)
        else {
          // the specialized name is computed from the post-erasure signature, with value classes re-boxed
          val specializedName =
            specializeTypes.specializedFunctionName(funSym,
              exitingErasure(target.info.paramTypes).map(reboxValueClass) :+ reboxValueClass(exitingErasure(target.info.resultType))).toTypeName

          val isSpecialized = specializedName != funSym.name
          val functionalInterface =
            if (isSpecialized) {
              // Unfortunately we still need to use custom functional interfaces for specialized functions so that the
              // unboxed apply method is left abstract for us to implement.
              currentRun.runDefinitions.Scala_Java8_CompatPackage.info.decl(specializedName.prepend("J"))
            }
            else FunctionClass(originalFunction.vparams.length)

          (functionalInterface, isSpecialized)
        }

      // a user-defined sam is recorded as an attachment on the function; NoSymbol means "derive it from the interface"
      val sam = originalFunction.attachments.get[SAMFunction].map(_.sam).getOrElse(NoSymbol)
      mkLambdaMetaFactoryCall(originalFunction, target, functionalInterface, sam, isSpecialized)
    }

    // here's the main entry point of the transform
    //
    // Handles four kinds of trees:
    //  - Function nodes are replaced by the tree produced by `transformFunction`;
    //  - Templates first get their delambdafy targets made static, are then transformed, and finally
    //    receive the boxing bridge methods created along the way;
    //  - other lifted methods are flagged STATIC when they do not refer to `this`;
    //  - constructor calls whose outer argument can be elided get a zero value in its place.
    override def transform(tree: Tree): Tree = tree match {
      // the main thing we care about is lambdas
      case fun: Function =>
        super.transform(transformFunction(fun))
      case Template(_, _, _) =>
        // makes every delambdafy target static before the template body is transformed:
        // through `gen.mkStatic` when the method refers to `this`, by just setting the flag otherwise
        def pretransform(tree: Tree): Tree = tree match {
          case dd: DefDef if dd.symbol.isDelambdafyTarget =>
            if (!dd.symbol.hasFlag(STATIC) && methodReferencesThis(dd.symbol)) {
              gen.mkStatic(dd, dd.symbol.name, sym => sym)
            } else {
              dd.symbol.setFlag(STATIC)
              dd
            }
          case t => t
        }
        try {
          // during this call boxingBridgeMethods will be populated from the Function case
          val Template(parents, self, body) = super.transform(deriveTemplate(tree)(_.mapConserve(pretransform)))
          Template(parents, self, body ++ boxingBridgeMethods)
        } finally boxingBridgeMethods.clear() // the buffer is shared, so it is emptied even when the transform fails
      case dd: DefDef if dd.symbol.isLiftedMethod && !dd.symbol.isDelambdafyTarget =>
        // SI-9390 emit lifted methods that don't require a `this` reference as STATIC
        // delambdafy targets are excluded as they are made static by `pretransform` in the `Template` case above.
        if (!dd.symbol.hasFlag(STATIC) && !methodReferencesThis(dd.symbol)) {
          dd.symbol.setFlag(STATIC)
          dd.symbol.removeAttachment[mixer.NeedStaticImpl.type]
        }
        super.transform(tree)
      case Apply(fun, outer :: rest) if shouldElideOuterArg(fun.symbol, outer) =>
        // the outer argument is replaced by the zero value of its type; only the remaining arguments are transformed
        val nullOuter = gen.mkZero(outer.tpe)
        treeCopy.Apply(tree, transform(fun), nullOuter :: transformTrees(rest))
      case _ => super.transform(tree)
    }
  } // DelambdafyTransformer

  /**
    * Tells whether the outer argument of an application can be dropped: the applied symbol must be
    * a constructor, the argument must be a qualifier that is safe to elide, and the constructor
    * must carry the `OuterArgCanBeElided` attachment.
    */
  private def shouldElideOuterArg(fun: Symbol, outerArg: Tree): Boolean = {
    if (!fun.isConstructor) false
    else if (!treeInfo.isQualifierSafeToElide(outerArg)) false
    else fun.hasAttachment[OuterArgCanBeElided.type]
  }

  // Collects the symbols that a tree uses without defining them itself.
  // TODO freeVarTraverser in LambdaLift does a very similar task. With some
  // analysis this could probably be unified with it
  class FreeVarTraverser extends Traverser {
    // symbols referenced but not declared within the traversed tree, in order of first occurrence
    val freeVars = mutable.LinkedHashSet[Symbol]()
    // symbols introduced by the traversed tree: function parameters, vals and pattern binders
    val declared = mutable.LinkedHashSet[Symbol]()

    override def traverse(tree: Tree) = {
      tree match {
        case Function(params, _) =>
          params.foreach(param => declared += param.symbol)
        case _: ValDef | _: Bind =>
          declared += tree.symbol
        case _: Ident =>
          val sym = tree.symbol
          // only block-local, non-method terms that were not declared earlier in the traversal are free
          val isFree = sym != NoSymbol && sym.isLocalToBlock && sym.isTerm && !sym.isMethod && !declared.contains(sym)
          if (isFree) freeVars += sym
        case _ =>
      }
      super.traverse(tree)
    }
  }

  object FreeVarTraverser {
    /** Runs a fresh `FreeVarTraverser` over the function and returns the free variables it collected. */
    def freeVarsOf(function: Function) = {
      val traverser = new FreeVarTraverser
      traverser traverse function
      traverser.freeVars
    }
  }

  // finds all methods that reference 'this'
  //
  // Works in two steps: the traversal records, per lifted method / delambdafy target, whether it
  // mentions `this` directly and which lifted methods it refers to; `methodReferencesThisIn` then
  // propagates "refers to this" along those recorded references.
  class ThisReferringMethodsTraverser extends Traverser {
    // the set of methods that refer to this
    private val thisReferringMethods = mutable.Set[Symbol]()

    // the set of lifted lambda body methods that each method refers to
    // (methods without an entry are treated as referring to nothing, thanks to the default)
    private val liftedMethodReferences = mutable.Map[Symbol, Set[Symbol]]().withDefault(_ => mutable.Set())

    // Traverses the tree, then closes the direct results over the recorded method references.
    // Returns the set of methods that refer to `this` directly or indirectly.
    def methodReferencesThisIn(tree: Tree) = {
      traverse(tree)
      liftedMethodReferences.keys foreach refersToThis

      thisReferringMethods
    }

    // recursively find methods that refer to 'this' directly or indirectly via references to other methods
    // for each method found add it to the referrers set
    private def refersToThis(symbol: Symbol): Boolean = {
      // guards against cycles between methods; a method already visited in this search contributes `false`
      val seen = mutable.Set[Symbol]()
      def loop(symbol: Symbol): Boolean = {
        if (seen(symbol)) false
        else {
          seen += symbol
          (thisReferringMethods contains symbol) ||
            (liftedMethodReferences(symbol) exists loop) && {
              // add it early to memoize
              debuglog(s"$symbol indirectly refers to 'this'")
              thisReferringMethods += symbol
              true
            }
        }
      }
      loop(symbol)
    }

    // the lifted method or delambdafy target whose body is currently being traversed, if any
    private var currentMethod: Symbol = NoSymbol

    override def traverse(tree: Tree) = tree match {
      case DefDef(_, _, _, _, _, _) if tree.symbol.isDelambdafyTarget || tree.symbol.isLiftedMethod =>
        // we don't expect defs within defs. At this phase trees should be very flat
        if (currentMethod.exists) devWarning("Found a def within a def at a phase where defs are expected to be flattened out.")
        currentMethod = tree.symbol
        super.traverse(tree)
        // reset to "no method" (not to the previous value) once the body has been traversed
        currentMethod = NoSymbol
      case fun@Function(_, _) =>
        // we don't drill into functions because at the beginning of this phase they will always refer to 'this'.
        // They'll be of the form {(args...) => this.anonfun(args...)}
        // but we do need to make note of the lifted body method in case it refers to 'this'
        if (currentMethod.exists) liftedMethodReferences(currentMethod) += targetMethod(fun)
      case Apply(sel @ Select(This(_), _), args) if sel.symbol.isLiftedMethod =>
        // a call to a lifted method on `this`: record the callee and look at the arguments only,
        // so the `this` qualifier itself is not counted as a direct reference
        if (currentMethod.exists) liftedMethodReferences(currentMethod) += sel.symbol
        super.traverseTrees(args)
      case Apply(fun, outer :: rest) if shouldElideOuterArg(fun.symbol, outer) =>
        // the outer argument is skipped; only the function and the remaining arguments are traversed
        super.traverse(fun)
        super.traverseTrees(rest)
      case This(_) =>
        // only a reference to the class enclosing the current method counts
        if (currentMethod.exists && tree.symbol == currentMethod.enclClass) {
          debuglog(s"$currentMethod directly refers to 'this'")
          thisReferringMethods add currentMethod
        }
      // class definitions that are not top-level are not descended into
      case _: ClassDef if !tree.symbol.isTopLevel =>
      // methods that are neither lifted nor delambdafy targets are not descended into
      case _: DefDef =>
      case _ =>
        super.traverse(tree)
    }
  }
}