package scala.tools.nsc
package transform

import symtab._
import Flags._
import scala.collection._

/**
 * This transformer is responsible for preparing Function nodes for runtime,
 * by translating to a tree that will be converted to an invokedynamic by the backend.
 *
 * The main assumption it makes is that a Function {args => body} has been turned into
 * {args => liftedBody()} where lifted body is a top level method that implements the body of the function.
 * Currently Uncurry is responsible for that transformation.
 *
 * From this shape of Function, Delambdafy will create:
 *
 * An application of the captured arguments to a fictional symbol representing the lambda factory.
 * This will be translated by the backend into an invokedynamic using a bootstrap method in JDK8's `LambdaMetaFactory`.
 * The captured arguments include `this` if `liftedBody` is unable to be made STATIC.
 */
abstract class Delambdafy extends Transform with TypingTransformers with ast.TreeDSL with TypeAdaptingTransformer {
  import global._
  import definitions._

  val analyzer: global.analyzer.type = global.analyzer

  /** the following two members override abstract members in Transform */
  val phaseName: String = "delambdafy"

  /**
   * Tree attachment placed on the lambda factory call built by `mkLambdaMetaFactoryCall`.
   * It carries the lambda target, the function arity, the functional interface and its sam,
   * plus the two serialization-related flags computed there.
   */
  final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol, isSerializable: Boolean, addScalaSerializableMarker: Boolean)

  /**
   * Get the symbol of the target lifted lambda body method from a function. I.e. if
   * the function is {args => anonfun(args)} then this method returns anonfun's symbol
   */
  private def targetMethod(fun: Function): Symbol = fun match {
    case Function(_, Apply(target, _)) => target.symbol
    case _ =>
      // any other shape of Function is unexpected at this point
      abort(s"could not understand function with tree $fun")
  }

  /** Runs the real phase only when `settings.Ydelambdafy` is "method"; otherwise installs a no-op phase. */
  override def newPhase(prev: scala.tools.nsc.Phase): StdPhase = {
    if (settings.Ydelambdafy.value == "method") new Phase(prev)
    else new SkipPhase(prev)
  }

  /** A phase that leaves every compilation unit untouched. */
  class SkipPhase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) {
    def apply(unit: global.CompilationUnit): Unit = ()
  }

  protected def newTransformer(unit: CompilationUnit): Transformer =
    new DelambdafyTransformer(unit)

  class DelambdafyTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
    // we need to know which methods refer to the 'this' reference so that we can determine which lambdas need access to it
    // TODO: this looks expensive, so I made it a lazy val. Can we make it more pay-as-you-go / optimize for common shapes?
    private[this] lazy val methodReferencesThis: Set[Symbol] =
      (new ThisReferringMethodsTraverser).methodReferencesThisIn(unit.body)

    /**
     * Builds the application of a fictional lambda factory method to the captured arguments of `fun`,
     * and attaches a `LambdaMetaFactoryCapable` describing the target and functional interface to it.
     *
     * @param fun                 the function being translated
     * @param target              the lifted method implementing the body of `fun`
     * @param functionalInterface the interface the lambda should implement
     * @param samUserDefined      the user-defined sam, or `NoSymbol` for built-in function types
     * @param isSpecialized       whether `functionalInterface` is a specialized built-in function interface
     * @return the typed factory application carrying the attachment
     */
    private def mkLambdaMetaFactoryCall(fun: Function, target: Symbol, functionalInterface: Symbol, samUserDefined: Symbol, isSpecialized: Boolean): Tree = {
      val pos = fun.pos
      def isSelfParam(p: Symbol) = p.isSynthetic && p.name == nme.SELF
      val hasSelfParam = isSelfParam(target.firstParam)

      val allCapturedArgRefs = {
        // find which variables are free in the lambda because those are captures that need to be
        // passed into the constructor of the anonymous function class
        val captureArgs = FreeVarTraverser.freeVarsOf(fun).iterator.map(capture =>
          gen.mkAttributedRef(capture) setPos pos
        ).toList

        if (!hasSelfParam) captureArgs.filterNot(arg => isSelfParam(arg.symbol))
        else if (currentMethod.hasFlag(Flags.STATIC)) captureArgs
        else (gen.mkAttributedThis(fun.symbol.enclClass) setPos pos) :: captureArgs
      }

      // Create a symbol representing a fictional lambda factory method that accepts the captured
      // arguments and returns the SAM type.
      val msym = {
        val meth = currentOwner.newMethod(nme.ANON_FUN_NAME, pos, ARTIFACT)
        val capturedParams = meth.newSyntheticValueParams(allCapturedArgRefs.map(_.tpe))
        meth.setInfo(MethodType(capturedParams, fun.tpe))
      }

      // We then apply this symbol to the captures.
      val apply = localTyper.typedPos(pos)(Apply(Ident(msym), allCapturedArgRefs))

      // TODO: this is a bit gross
      val sam = samUserDefined orElse {
        if (isSpecialized)
          // fail with a descriptive compiler error rather than a bare NoSuchElementException
          functionalInterface.info.decls.find(_.isDeferred).getOrElse(
            abort(s"no deferred member found in specialized functional interface $functionalInterface"))
        else functionalInterface.info.member(nme.apply)
      }

      // no need for adaptation when the implemented sam is of a specialized built-in function type
      val lambdaTarget = if (isSpecialized) target else createBoxingBridgeMethodIfNeeded(fun, target, functionalInterface, sam)
      val isSerializable = samUserDefined == NoSymbol || samUserDefined.owner.isNonBottomSubClass(definitions.JavaSerializableClass)
      val addScalaSerializableMarker = samUserDefined == NoSymbol

      // The backend needs to know the target of the lambda and the functional interface in order
      // to emit the invokedynamic instruction. We pass this information as tree attachment.
      //
      // see https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/LambdaMetafactory.html
      // instantiatedMethodType is derived from lambdaTarget's signature
      // samMethodType is derived from samOf(functionalInterface)'s signature
      apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface, sam, isSerializable, addScalaSerializableMarker))

      apply
    }

    // bridge methods created while transforming the Functions of the current Template;
    // appended to the template body and cleared again in `transform`
    private val boxingBridgeMethods = mutable.ArrayBuffer[Tree]()

    /** Maps an erased value type back to a reference to its value class; any other type is returned as-is. */
    private def reboxValueClass(tp: Type) = tp match {
      case ErasedValueType(valueClazz, _) => TypeRef(NoPrefix, valueClazz, Nil)
      case _ => tp
    }

    // exclude primitives and value classes, which need special boxing
    private def isReferenceType(tp: Type) = !tp.isInstanceOf[ErasedValueType] && {
      val sym = tp.typeSymbol
      !(isPrimitiveValueClass(sym) || sym.isDerivedValueClass)
    }

    // determine which lambda target to use with java's LMF -- create a new one if scala-specific boxing is required
    def createBoxingBridgeMethodIfNeeded(fun: Function, target: Symbol, functionalInterface: Symbol, sam: Symbol): Symbol = {
      val oldClass = fun.symbol.enclClass
      val pos = fun.pos

      // At erasure, there won't be any captured arguments (they are added in constructors)
      val functionParamTypes = exitingErasure(target.info.paramTypes)
      val functionResultType = exitingErasure(target.info.resultType)

      val samParamTypes = exitingErasure(sam.info.paramTypes)
      val samResultType = exitingErasure(sam.info.resultType)

      /** How to satisfy the linking invariants of https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/LambdaMetafactory.html
       *
       * Given samMethodType: (U1..Un)Ru and function type T1,..., Tn => Rt (the target method created by uncurry)
       *
       * Do we need a bridge, or can we use the original lambda target for implMethod: ( A1..An)Ra
       * (We can ignore capture here.)
       *
       * If, for i=1..N:
       *   Ai =:= Ui || (Ai <:< Ui <:< AnyRef)
       *   Ru =:= void || (Ra =:= Ru || (Ra <:< AnyRef, Ru <:< AnyRef))
       *
       * We can use the target method as-is -- if not, we create a bridging one that uses the types closest
       * to the target method that still meet the above requirements.
       */
      val resTpOk = (
           samResultType =:= UnitTpe
        || functionResultType =:= samResultType
        || (isReferenceType(samResultType) && isReferenceType(functionResultType))) // yes, this is what the spec says -- no further correspondence required
      if (resTpOk && (samParamTypes corresponds functionParamTypes){ (samParamTp, funParamTp) =>
          funParamTp =:= samParamTp || (isReferenceType(funParamTp) && isReferenceType(samParamTp) && funParamTp <:< samParamTp) }) target
      else {
        // We have to construct a new lambda target that bridges to the one created by uncurry.
        // The bridge must satisfy the above invariants, while also minimizing adaptation on our end.
        // LMF will insert runtime casts according to the spec at the above link.

        // we use the more precise type between samParamTp and funParamTp to minimize boxing in the bridge method
        // we are constructing a method whose signature matches the sam's signature (because the original target did not)
        // whenever a type in the sam's signature is (erases to) a primitive type, we must pick the sam's version,
        // as we don't implement the logic regarding widening that's performed by LMF -- we require =:= for primitives
        //
        // We use the sam's type for the check whether we're dealing with a reference type, as it could be a generic type,
        // which means the function's parameter -- even if it expects a value class -- will need to be
        // boxed on the generic call to the sam method.
        val bridgeParamTypes = map2(samParamTypes, functionParamTypes){ (samParamTp, funParamTp) =>
          if (isReferenceType(samParamTp) && funParamTp <:< samParamTp) funParamTp
          else samParamTp
        }

        val bridgeResultType =
          if (resTpOk && isReferenceType(samResultType) && functionResultType <:< samResultType) functionResultType
          else samResultType

        val typeAdapter = new TypeAdapter { def typedPos(pos: Position)(tree: Tree): Tree = localTyper.typedPos(pos)(tree) }
        import typeAdapter.{adaptToType, unboxValueClass}

        val targetParams = target.paramss.head
        val numCaptures  = targetParams.length - functionParamTypes.length
        val (targetCapturedParams, targetFunctionParams) = targetParams.splitAt(numCaptures)

        val methSym = oldClass.newMethod(target.name.append("$adapted").toTermName, target.pos, target.flags | FINAL | ARTIFACT | STATIC)
        val bridgeCapturedParams = targetCapturedParams.map(param => methSym.newSyntheticValueParam(param.tpe, param.name.toTermName))
        val bridgeFunctionParams =
          map2(targetFunctionParams, bridgeParamTypes)((param, tp) => methSym.newSyntheticValueParam(tp, param.name.toTermName))

        val bridgeParams = bridgeCapturedParams ::: bridgeFunctionParams

        methSym setInfo MethodType(bridgeParams, bridgeResultType)
        oldClass.info.decls enter methSym

        val forwarderCall = localTyper.typedPos(pos) {
          val capturedArgRefs = bridgeCapturedParams map gen.mkAttributedRef
          val functionArgRefs =
            map3(bridgeFunctionParams, functionParamTypes, targetParams.drop(numCaptures)) { (bridgeParam, functionParamTp, targetParam) =>
              val bridgeParamRef = gen.mkAttributedRef(bridgeParam)
              val targetParamTp  = targetParam.tpe

              // TODO: can we simplify this to something like `adaptToType(adaptToType(bridgeParamRef, functionParamTp), targetParamTp)`?
              val unboxed =
                functionParamTp match {
                  case ErasedValueType(clazz, underlying) =>
                    // when the original function expected an argument of value class type,
                    // the original target will expect the unboxed underlying value,
                    // whereas the bridge will receive the boxed value (since the sam's argument type did not match and we had to adapt)
                    localTyper.typed(unboxValueClass(bridgeParamRef, clazz, underlying), targetParamTp)
                  case _ => bridgeParamRef
                }

              adaptToType(unboxed, targetParamTp)
            }

          gen.mkMethodCall(Select(gen.mkAttributedThis(oldClass), target), capturedArgRefs ::: functionArgRefs)
        }

        val bridge = postErasure.newTransformer(unit).transform(DefDef(methSym, List(bridgeParams.map(ValDef(_))),
          adaptToType(forwarderCall setType functionResultType, bridgeResultType))).asInstanceOf[DefDef]

        boxingBridgeMethods += bridge
        bridge.symbol
      }
    }

    /**
     * Translates a Function of the shape {args => liftedBody(args)} into a lambda factory call.
     * The lifted target is asserted to be STATIC and is made non-private before the
     * functional interface is selected.
     */
    private def transformFunction(originalFunction: Function): Tree = {
      val target = targetMethod(originalFunction)
      assert(target.hasFlag(Flags.STATIC))
      target.setFlag(notPRIVATE)

      val funSym = originalFunction.tpe.typeSymbolDirect
      // The functional interface that can be used to adapt the lambda target method `target` to the given function type.
      val (functionalInterface, isSpecialized) =
        if (!isFunctionSymbol(funSym)) (funSym, false)
        else {
          val specializedName =
            specializeTypes.specializedFunctionName(funSym,
              exitingErasure(target.info.paramTypes).map(reboxValueClass) :+ reboxValueClass(exitingErasure(target.info.resultType))).toTypeName

          val isSpecialized = specializedName != funSym.name
          val functionalInterface =
            if (isSpecialized) {
              // Unfortunately we still need to use custom functional interfaces for specialized functions so that the
              // unboxed apply method is left abstract for us to implement.
              currentRun.runDefinitions.Scala_Java8_CompatPackage.info.decl(specializedName.prepend("J"))
            }
            else FunctionClass(originalFunction.vparams.length)

          (functionalInterface, isSpecialized)
        }

      val sam = originalFunction.attachments.get[SAMFunction].map(_.sam).getOrElse(NoSymbol)
      mkLambdaMetaFactoryCall(originalFunction, target, functionalInterface, sam, isSpecialized)
    }

    // here's the main entry point of the transform
    override def transform(tree: Tree): Tree = tree match {
      // the main thing we care about is lambdas
      case fun: Function =>
        super.transform(transformFunction(fun))
      case Template(_, _, _) =>
        def pretransform(tree: Tree): Tree = tree match {
          case dd: DefDef if dd.symbol.isDelambdafyTarget =>
            if (!dd.symbol.hasFlag(STATIC) && methodReferencesThis(dd.symbol)) {
              gen.mkStatic(dd, dd.symbol.name, sym => sym)
            } else {
              dd.symbol.setFlag(STATIC)
              dd
            }
          case t => t
        }
        try {
          // during this call boxingBridgeMethods will be populated from the Function case
          val Template(parents, self, body) = super.transform(deriveTemplate(tree)(_.mapConserve(pretransform)))
          Template(parents, self, body ++ boxingBridgeMethods)
        } finally boxingBridgeMethods.clear()
      case dd: DefDef if dd.symbol.isLiftedMethod && !dd.symbol.isDelambdafyTarget =>
        // SI-9390 emit lifted methods that don't require a `this` reference as STATIC
        // delambdafy targets are excluded as they are made static by `transformFunction`.
        if (!dd.symbol.hasFlag(STATIC) && !methodReferencesThis(dd.symbol)) {
          dd.symbol.setFlag(STATIC)
          dd.symbol.removeAttachment[mixer.NeedStaticImpl.type]
        }
        super.transform(tree)
      case Apply(fun, outer :: rest) if shouldElideOuterArg(fun.symbol, outer) =>
        val nullOuter = gen.mkZero(outer.tpe)
        treeCopy.Apply(tree, transform(fun), nullOuter :: transformTrees(rest))
      case _ => super.transform(tree)
    }
  } // DelambdafyTransformer

  /**
   * True when `fun` is a constructor carrying the `OuterArgCanBeElided` attachment and
   * `outerArg` is a qualifier that `treeInfo` considers safe to elide.
   */
  private def shouldElideOuterArg(fun: Symbol, outerArg: Tree): Boolean =
    fun.isConstructor && treeInfo.isQualifierSafeToElide(outerArg) && fun.hasAttachment[OuterArgCanBeElided.type]

  // A traverser that finds symbols used but not defined in the given Tree
  // TODO freeVarTraverser in LambdaLift does a very similar task. With some
  // analysis this could probably be unified with it
  class FreeVarTraverser extends Traverser {
    val freeVars = mutable.LinkedHashSet[Symbol]()
    val declared = mutable.LinkedHashSet[Symbol]()

    override def traverse(tree: Tree) = {
      tree match {
        case Function(args, _) =>
          args foreach {arg => declared += arg.symbol}
        case ValDef(_, _, _, _) =>
          declared += tree.symbol
        case _: Bind =>
          declared += tree.symbol
        case Ident(_) =>
          val sym = tree.symbol
          if ((sym != NoSymbol) && sym.isLocalToBlock && sym.isTerm && !sym.isMethod && !declared.contains(sym)) freeVars += sym
        case _ =>
      }
      super.traverse(tree)
    }
  }

  object FreeVarTraverser {
    /** Traverses `function` with a fresh `FreeVarTraverser` and returns the free variables it recorded. */
    def freeVarsOf(function: Function): mutable.LinkedHashSet[Symbol] = {
      val freeVarsTraverser = new FreeVarTraverser
      freeVarsTraverser.traverse(function)
      freeVarsTraverser.freeVars
    }
  }

  // finds all methods that reference 'this'
  class ThisReferringMethodsTraverser extends Traverser {
    // the set of methods that refer to this
    private val thisReferringMethods = mutable.Set[Symbol]()

    // the set of lifted lambda body methods that each method refers to
    private val liftedMethodReferences = mutable.Map[Symbol, Set[Symbol]]().withDefault(_ => mutable.Set())

    /**
     * Traverses `tree`, then propagates 'this' references through the recorded
     * method-to-lifted-method references, and returns the resulting set of methods.
     */
    def methodReferencesThisIn(tree: Tree): mutable.Set[Symbol] = {
      traverse(tree)
      liftedMethodReferences.keys foreach refersToThis

      thisReferringMethods
    }

    // recursively find methods that refer to 'this' directly or indirectly via references to other methods
    // for each method found add it to the referrers set
    private def refersToThis(symbol: Symbol): Boolean = {
      // symbols already visited during this query; guards against cycles between methods
      val seen = mutable.Set[Symbol]()
      def loop(symbol: Symbol): Boolean = {
        if (seen(symbol)) false
        else {
          seen += symbol
          (thisReferringMethods contains symbol) ||
            (liftedMethodReferences(symbol) exists loop) && {
              // add it early to memoize
              debuglog(s"$symbol indirectly refers to 'this'")
              thisReferringMethods += symbol
              true
            }
        }
      }
      loop(symbol)
    }

    // the lifted method / delambdafy target whose body is currently being traversed, if any
    private var currentMethod: Symbol = NoSymbol

    override def traverse(tree: Tree) = tree match {
      case DefDef(_, _, _, _, _, _) if tree.symbol.isDelambdafyTarget || tree.symbol.isLiftedMethod =>
        // we don't expect defs within defs. At this phase trees should be very flat
        if (currentMethod.exists) devWarning("Found a def within a def at a phase where defs are expected to be flattened out.")
        currentMethod = tree.symbol
        super.traverse(tree)
        currentMethod = NoSymbol
      case fun@Function(_, _) =>
        // we don't drill into functions because at the beginning of this phase they will always refer to 'this'.
        // They'll be of the form {(args...) => this.anonfun(args...)}
        // but we do need to make note of the lifted body method in case it refers to 'this'
        if (currentMethod.exists) liftedMethodReferences(currentMethod) += targetMethod(fun)
      case Apply(sel @ Select(This(_), _), args) if sel.symbol.isLiftedMethod =>
        if (currentMethod.exists) liftedMethodReferences(currentMethod) += sel.symbol
        super.traverseTrees(args)
      case Apply(fun, outer :: rest) if shouldElideOuterArg(fun.symbol, outer) =>
        // the outer argument is skipped here; `transform` replaces it with a zero value
        super.traverse(fun)
        super.traverseTrees(rest)
      case This(_) =>
        if (currentMethod.exists && tree.symbol == currentMethod.enclClass) {
          debuglog(s"$currentMethod directly refers to 'this'")
          thisReferringMethods add currentMethod
        }
      case _: ClassDef if !tree.symbol.isTopLevel =>
      case _: DefDef =>
      case _ =>
        super.traverse(tree)
    }
  }
}