diff options
8 files changed, 60 insertions, 13 deletions
diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala index 82aa3c65aa..cff623e2b2 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala @@ -659,7 +659,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { case Apply(fun, args) if app.hasAttachment[delambdafy.LambdaMetaFactoryCapable] => val attachment = app.attachments.get[delambdafy.LambdaMetaFactoryCapable].get genLoadArguments(args, paramTKs(app)) - genInvokeDynamicLambda(attachment.target, attachment.arity, attachment.functionalInterface) + genInvokeDynamicLambda(attachment.target, attachment.arity, attachment.functionalInterface, attachment.sam) generatedType = methodBTypeFromSymbol(fun.symbol).returnType case Apply(fun @ _, List(expr)) if currentRun.runDefinitions.isBox(fun.symbol) => @@ -1360,7 +1360,7 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { def genSynchronized(tree: Apply, expectedType: BType): BType def genLoadTry(tree: Try): BType - def genInvokeDynamicLambda(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol) { + def genInvokeDynamicLambda(lambdaTarget: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol) { val isStaticMethod = lambdaTarget.hasFlag(Flags.STATIC) def asmType(sym: Symbol) = classBTypeFromSymbol(sym).toASMType @@ -1375,7 +1375,6 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { val invokedType = asm.Type.getMethodDescriptor(asmType(functionalInterface), (receiver ::: capturedParams).map(sym => typeToBType(sym.info).toASMType): _*) val constrainedType = new MethodBType(lambdaParams.map(p => typeToBType(p.tpe)), typeToBType(lambdaTarget.tpe.resultType)).toASMType - val sam = functionalInterface.info.decls.find(_.isDeferred).getOrElse(functionalInterface.info.member(nme.apply)) val samName = sam.name.toString val samMethodType = methodBTypeFromSymbol(sam).toASMType diff --git a/src/compiler/scala/tools/nsc/transform/Delambdafy.scala b/src/compiler/scala/tools/nsc/transform/Delambdafy.scala index 0614b138a7..32ab52203e 100644 --- a/src/compiler/scala/tools/nsc/transform/Delambdafy.scala +++ b/src/compiler/scala/tools/nsc/transform/Delambdafy.scala @@ -29,7 +29,7 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre /** the following two members override abstract members in Transform */ val phaseName: String = "delambdafy" - final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol) + final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol, sam: Symbol) /** * Get the symbol of the target lifted lambda body method from a function. I.e. if @@ -60,7 +60,7 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre private[this] lazy val methodReferencesThis: Set[Symbol] = (new ThisReferringMethodsTraverser).methodReferencesThisIn(unit.body) - private def mkLambdaMetaFactoryCall(fun: Function, target: Symbol, functionalInterface: Symbol, isSpecialized: Boolean): Tree = { + private def mkLambdaMetaFactoryCall(fun: Function, target: Symbol, functionalInterface: Symbol, samUserDefined: Symbol, isSpecialized: Boolean): Tree = { val pos = fun.pos val allCapturedArgRefs = { // find which variables are free in the lambda because those are captures that need to be @@ -84,8 +84,14 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre // We then apply this symbol to the captures. val apply = localTyper.typedPos(pos)(Apply(Ident(msym), allCapturedArgRefs)) + // TODO: this is a bit gross + val sam = samUserDefined orElse { + if (isSpecialized) functionalInterface.info.decls.find(_.isDeferred).get + else functionalInterface.info.member(nme.apply) + } + // no need for adaptation when the implemented sam is of a specialized built-in function type - val lambdaTarget = if (isSpecialized) target else createBoxingBridgeMethodIfNeeded(fun, target, functionalInterface) + val lambdaTarget = if (isSpecialized) target else createBoxingBridgeMethodIfNeeded(fun, target, functionalInterface, sam) // The backend needs to know the target of the lambda and the functional interface in order // to emit the invokedynamic instruction. We pass this information as tree attachment. @@ -93,7 +99,7 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre // see https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/LambdaMetafactory.html // instantiatedMethodType is derived from lambdaTarget's signature // samMethodType is derived from samOf(functionalInterface)'s signature - apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface)) + apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, fun.vparams.length, functionalInterface, sam)) apply } @@ -113,7 +119,7 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre } // determine which lambda target to use with java's LMF -- create a new one if scala-specific boxing is required - def createBoxingBridgeMethodIfNeeded(fun: Function, target: Symbol, functionalInterface: Symbol): Symbol = { + def createBoxingBridgeMethodIfNeeded(fun: Function, target: Symbol, functionalInterface: Symbol, sam: Symbol): Symbol = { val oldClass = fun.symbol.enclClass val pos = fun.pos @@ -121,7 +127,6 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre val functionParamTypes = exitingErasure(target.info.paramTypes) val functionResultType = exitingErasure(target.info.resultType) - val sam = samOf(functionalInterface.tpe) orElse functionalInterface.info.member(nme.apply) val samParamTypes = exitingErasure(sam.info.paramTypes) val samResultType = exitingErasure(sam.info.resultType) @@ -241,7 +246,8 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre (functionalInterface, isSpecialized) } - mkLambdaMetaFactoryCall(originalFunction, target, functionalInterface, isSpecialized) + val sam = originalFunction.attachments.get[SAMFunction].map(_.sam).getOrElse(NoSymbol) + mkLambdaMetaFactoryCall(originalFunction, target, functionalInterface, sam, isSpecialized) } // here's the main entry point of the transform diff --git a/src/compiler/scala/tools/nsc/transform/Erasure.scala b/src/compiler/scala/tools/nsc/transform/Erasure.scala index be6316cc57..2df4265573 100644 --- a/src/compiler/scala/tools/nsc/transform/Erasure.scala +++ b/src/compiler/scala/tools/nsc/transform/Erasure.scala @@ -721,6 +721,12 @@ abstract class Erasure extends AddInterfaces if (branch == EmptyTree) branch else adaptToType(branch, tree1.tpe) tree1 match { + case fun: Function => + fun.attachments.get[SAMFunction] match { + case Some(SAMFunction(samTp, _)) => fun setType specialScalaErasure(samTp) + case _ => fun + } + case If(cond, thenp, elsep) => treeCopy.If(tree1, cond, adaptBranch(thenp), adaptBranch(elsep)) case Match(selector, cases) => diff --git a/src/compiler/scala/tools/nsc/transform/TypeAdaptingTransformer.scala b/src/compiler/scala/tools/nsc/transform/TypeAdaptingTransformer.scala index 596091f75d..afafdedce7 100644 --- a/src/compiler/scala/tools/nsc/transform/TypeAdaptingTransformer.scala +++ b/src/compiler/scala/tools/nsc/transform/TypeAdaptingTransformer.scala @@ -118,8 +118,7 @@ trait TypeAdaptingTransformer { self: TreeDSL => val needsExtraCast = isPrimitiveValueType(tree.tpe.typeArgs.head) && !isPrimitiveValueType(pt.typeArgs.head) val tree1 = if (needsExtraCast) gen.mkRuntimeCall(nme.toObjectArray, List(tree)) else tree gen.mkAttributedCast(tree1, pt) - } else if (samMatchingFunction(tree, pt).exists) tree setType pt // SAM <:< FunctionN if sam is convertible to said function - else gen.mkAttributedCast(tree, pt) + } else gen.mkAttributedCast(tree, pt) } /** Adapt `tree` to expected type `pt`. diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index cd5759f40f..a9d5b69e2e 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -1076,7 +1076,8 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper val sam = samMatchingFunction(tree, pt) // this implies tree.isInstanceOf[Function] if (sam.exists && !tree.tpe.isErroneous) { val samTree = adaptToSAM(sam, tree.asInstanceOf[Function], pt, mode) - if (samTree ne EmptyTree) return samTree + if (samTree ne EmptyTree) + return samTree.updateAttachment(SAMFunction(pt, sam)) } } diff --git a/src/reflect/scala/reflect/internal/StdAttachments.scala b/src/reflect/scala/reflect/internal/StdAttachments.scala index 8358c1295c..0243dd48d2 100644 --- a/src/reflect/scala/reflect/internal/StdAttachments.scala +++ b/src/reflect/scala/reflect/internal/StdAttachments.scala @@ -38,6 +38,19 @@ trait StdAttachments { */ case class CompoundTypeTreeOriginalAttachment(parents: List[Tree], stats: List[Tree]) + /** Attached to a Function node during type checking when the expected type is a SAM type (and not a built-in FunctionN). + * + * Ideally, we'd move to Dotty's Closure AST, which tracks the environment, + * the lifted method that has the implementation, and the target type. + * For backwards compatibility, an attachment is the best we can do right now. + * + * @param samTp the expected type that triggered sam conversion (may be a subtype of the type corresponding to sam's owner) + * @param sam the single abstract method implemented by the Function we're attaching this to + * + * @since 2.12.0-M4 + */ + case class SAMFunction(samTp: Type, sam: Symbol) extends PlainAttachment + /** When present, indicates that the host `Ident` has been created from a backquoted identifier. */ case object BackquotedIdentifierAttachment extends PlainAttachment diff --git a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala index 13874916cc..4630597668 100644 --- a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala +++ b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala @@ -37,6 +37,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse => this.FixedMirrorTreeCreator this.FixedMirrorTypeCreator this.CompoundTypeTreeOriginalAttachment + this.SAMFunction this.BackquotedIdentifierAttachment this.ForAttachment this.SyntheticUnitAttachment diff --git a/test/files/run/sammy_erasure_cce.scala b/test/files/run/sammy_erasure_cce.scala new file mode 100644 index 0000000000..fb973befe4 --- /dev/null +++ b/test/files/run/sammy_erasure_cce.scala @@ -0,0 +1,22 @@ +trait F1 { + def apply(a: List[String]): String + def f1 = "f1" +} + +object Test extends App { + // Wrap the sam-targeting function in a context where the expected type is erased (identity's argument type erases to Object), + // so that Erasure can't tell that the types actually conform by looking only + // at an un-adorned Function tree and the expected type + // (because a function type needs no cast it the expected type is a SAM type), + // + // A correct implementation of Typers/Erasure tracks a Function's SAM target type directly + // (currently using an attachment for backwards compat), + // and not in the expected type (which was the case in my first attempt), + // as the expected type may lose its SAM status due to erasure. + // (In a sense, this need not be so, but erasure drops type parameters, + // so that identity's F1 type argument cannot be propagated to its argument type.) + def foo = identity[F1]((as: List[String]) => as.head) + + // check that this doesn't CCE's + foo.f1 +} |