From 9c36a76171d497fe277030d2f682f927ae2657f4 Mon Sep 17 00:00:00 2001 From: Adriaan Moors Date: Thu, 6 Nov 2014 11:50:13 +0100 Subject: [sammy] eta-expansion, overloading (SI-8310) Playing with Java 8 Streams from the repl showed we weren't eta-expanding, nor resolving overloading for SAMs. Also, the way Java uses wildcards to represent use-site variance stresses type inference past its bendiness point (due to excessive existentials). I introduce `wildcardExtrapolation` to simplify the resulting types (without losing precision): `wildcardExtrapolation(tp) =:= tp`. For example, the `MethodType` given by `def bla(x: (_ >: String)): (_ <: Int)` is both a subtype and a supertype of `def bla(x: String): Int`. Translating http://winterbe.com/posts/2014/07/31/java8-stream-tutorial-examples/ into Scala shows most of this works, though we have some more work to do (see near the end). ``` scala> import java.util.Arrays scala> import java.util.stream.Stream scala> import java.util.stream.IntStream scala> val myList = Arrays.asList("a1", "a2", "b1", "c2", "c1") myList: java.util.List[String] = [a1, a2, b1, c2, c1] scala> myList.stream.filter(_.startsWith("c")).map(_.toUpperCase).sorted.forEach(println) C1 C2 scala> myList.stream.filter(_.startsWith("c")).map(_.toUpperCase).sorted res8: java.util.stream.Stream[?0] = java.util.stream.SortedOps$OfRef@133e7789 scala> Arrays.asList("a1", "a2", "a3").stream.findFirst.ifPresent(println) a1 scala> Stream.of("a1", "a2", "a3").findFirst.ifPresent(println) a1 scala> IntStream.range(1, 4).forEach(println) :37: error: object creation impossible, since method accept in trait IntConsumer of type (x$1: Int)Unit is not defined (Note that Int does not match Any: class Int in package scala is a subclass of class Any in package scala, but method parameter types must match exactly.) IntStream.range(1, 4).forEach(println) ^ scala> IntStream.range(1, 4).forEach(println(_: Int)) // TODO: can we avoid this annotation? 1 2 3 scala> Arrays.stream(Array(1, 2, 3)).map(n => 2 * n + 1).average.ifPresent(println(_: Double)) 5.0 scala> Stream.of("a1", "a2", "a3").map(_.substring(1)).mapToInt(_.parseInt).max.ifPresent(println(_: Int)) // whoops! ReplGlobal.abort: Unknown type: , [class scala.reflect.internal.Types$ErrorType$, class scala.reflect.internal.Types$ErrorType$] TypeRef? false error: Unknown type: , [class scala.reflect.internal.Types$ErrorType$, class scala.reflect.internal.Types$ErrorType$] TypeRef? false scala.reflect.internal.FatalError: Unknown type: , [class scala.reflect.internal.Types$ErrorType$, class scala.reflect.internal.Types$ErrorType$] TypeRef? false at scala.reflect.internal.Reporting$class.abort(Reporting.scala:59) scala> IntStream.range(1, 4).mapToObj(i => "a" + i).forEach(println) a1 a2 a3 ``` --- .../scala/tools/nsc/typechecker/Infer.scala | 6 ++++ .../scala/tools/nsc/typechecker/Typers.scala | 36 ++++++++++++++++------ .../scala/reflect/internal/Definitions.scala | 2 +- .../scala/reflect/internal/tpe/TypeMaps.scala | 16 ++++++++++ .../scala/reflect/runtime/JavaUniverseForce.scala | 1 + 5 files changed, 50 insertions(+), 11 deletions(-) (limited to 'src') diff --git a/src/compiler/scala/tools/nsc/typechecker/Infer.scala b/src/compiler/scala/tools/nsc/typechecker/Infer.scala index ee2775ee26..8979b26719 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Infer.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Infer.scala @@ -295,11 +295,17 @@ trait Infer extends Checkable { && !isByNameParamType(tp) && isCompatible(tp, dropByName(pt)) ) + def isCompatibleSam(tp: Type, pt: Type): Boolean = { + val samFun = typer.samToFunctionType(pt) + (samFun ne NoType) && isCompatible(tp, samFun) + } + val tp1 = normalize(tp) ( (tp1 weak_<:< pt) || isCoercible(tp1, pt) || isCompatibleByName(tp, pt) + || isCompatibleSam(tp, pt) ) } def isCompatibleArgs(tps: List[Type], pts: List[Type]) = (tps corresponds pts)(isCompatible) diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 70acb03584..c41b81aaf5 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -741,6 +741,26 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper case _ => } + /** + * Convert a SAM type to the corresponding FunctionType, + * extrapolating BoundedWildcardTypes in the process + * (no type precision is lost by the extrapolation, + * but this facilitates dealing with the types arising from Java's use-site variance). + */ + def samToFunctionType(tp: Type, sam: Symbol = NoSymbol): Type = { + val samSym = sam orElse samOf(tp) + + def correspondingFunctionSymbol = { + val numVparams = samSym.info.params.length + if (numVparams > definitions.MaxFunctionArity) NoSymbol + else FunctionClass(numVparams) + } + + if (samSym.exists && samSym.owner != correspondingFunctionSymbol) // don't treat Functions as SAMs + wildcardExtrapolation(normalize(tp memberInfo samSym)) + else NoType + } + /** Perform the following adaptations of expression, pattern or type `tree` wrt to * given mode `mode` and given prototype `pt`: * (-1) For expressions with annotated types, let AnnotationCheckers decide what to do @@ -824,7 +844,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper case Block(_, tree1) => tree1.symbol case _ => tree.symbol } - if (!meth.isConstructor && isFunctionType(pt)) { // (4.2) + if (!meth.isConstructor && (isFunctionType(pt) || samOf(pt).exists)) { // (4.2) debuglog(s"eta-expanding $tree: ${tree.tpe} to $pt") checkParamsConvertible(tree, tree.tpe) val tree0 = etaExpand(context.unit, tree, this) @@ -2838,7 +2858,7 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper * as `(a => a): Int => Int` should not (yet) get the sam treatment. */ val sam = - if (!settings.Xexperimental || pt.typeSymbol == FunctionSymbol) NoSymbol + if (pt.typeSymbol == FunctionSymbol) NoSymbol else samOf(pt) /* The SAM case comes first so that this works: @@ -2848,15 +2868,11 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper * Note that the arity of the sam must correspond to the arity of the function. */ val samViable = sam.exists && sameLength(sam.info.params, fun.vparams) + val ptNorm = if (samViable) samToFunctionType(pt, sam) else pt val (argpts, respt) = - if (samViable) { - val samInfo = pt memberInfo sam - (samInfo.paramTypes, samInfo.resultType) - } else { - pt baseType FunctionSymbol match { - case TypeRef(_, FunctionSymbol, args :+ res) => (args, res) - case _ => (fun.vparams map (_ => if (pt == ErrorType) ErrorType else NoType), WildcardType) - } + ptNorm baseType FunctionSymbol match { + case TypeRef(_, FunctionSymbol, args :+ res) => (args, res) + case _ => (fun.vparams map (_ => if (pt == ErrorType) ErrorType else NoType), WildcardType) } if (!FunctionSymbol.exists) diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index e2ee6a9076..85b080f29b 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -790,7 +790,7 @@ trait Definitions extends api.StandardDefinitions { * The class defining the method is a supertype of `tp` that * has a public no-arg primary constructor. */ - def samOf(tp: Type): Symbol = { + def samOf(tp: Type): Symbol = if (!settings.Xexperimental) NoSymbol else { // if tp has a constructor, it must be public and must not take any arguments // (not even an implicit argument list -- to keep it simple for now) val tpSym = tp.typeSymbol diff --git a/src/reflect/scala/reflect/internal/tpe/TypeMaps.scala b/src/reflect/scala/reflect/internal/tpe/TypeMaps.scala index f06420de96..662306d5ab 100644 --- a/src/reflect/scala/reflect/internal/tpe/TypeMaps.scala +++ b/src/reflect/scala/reflect/internal/tpe/TypeMaps.scala @@ -422,6 +422,22 @@ private[internal] trait TypeMaps { } } + /** + * Get rid of BoundedWildcardType where variance allows us to do so. + * Invariant: `wildcardExtrapolation(tp) =:= tp` + * + * For example, the MethodType given by `def bla(x: (_ >: String)): (_ <: Int)` + * is both a subtype and a supertype of `def bla(x: String): Int`. + */ + object wildcardExtrapolation extends TypeMap(trackVariance = true) { + def apply(tp: Type): Type = + tp match { + case BoundedWildcardType(TypeBounds(lo, AnyTpe)) if variance.isContravariant =>lo + case BoundedWildcardType(TypeBounds(NothingTpe, hi)) if variance.isCovariant => hi + case tp => mapOver(tp) + } + } + /** Might the given symbol be important when calculating the prefix * of a type? When tp.asSeenFrom(pre, clazz) is called on `tp`, * the result will be `tp` unchanged if `pre` is trivial and `clazz` diff --git a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala index 18a3c5d63f..c87b810bdd 100644 --- a/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala +++ b/src/reflect/scala/reflect/runtime/JavaUniverseForce.scala @@ -170,6 +170,7 @@ trait JavaUniverseForce { self: runtime.JavaUniverse => this.dropSingletonType this.abstractTypesToBounds this.dropIllegalStarTypes + this.wildcardExtrapolation this.IsDependentCollector this.ApproximateDependentMap this.wildcardToTypeVarMap -- cgit v1.2.3