summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdriaan Moors <adriaan@lightbend.com>2016-07-27 11:18:08 -0700
committerAdriaan Moors <adriaan@lightbend.com>2016-08-12 16:39:31 -0700
commit228d1b1ceb7643df1313672cd620cddb1a429029 (patch)
tree4fb946f9fe8a03cf759f194b94e17c8544bbd4fb
parent618d42c747955a43557655bdc0c4281fec5a7923 (diff)
downloadscala-228d1b1ceb7643df1313672cd620cddb1a429029.tar.gz
scala-228d1b1ceb7643df1313672cd620cddb1a429029.tar.bz2
scala-228d1b1ceb7643df1313672cd620cddb1a429029.zip
Propagate overloaded function type to expected arg type
Infer missing parameter types for function literals passed to higher-order overloaded methods by deriving the expected argument type from the function types in the overloaded method type's argument types. This eases the pain caused by methods becoming overloaded because SAM types and function types are compatible, which used to disable parameter type inference because for overload resolution arguments are typed without expected type, while typedFunction needs the expected type to infer missing parameter types for function literals. It also aligns us with dotty. The special case for function literals seems reasonable, as it has precedent, and it just enables the special case in typing function literals (derive the param types from the expected type). Since this does change type inference, you can opt out using the Scala 2.11 source level. Fix scala/scala-dev#157
-rw-r--r--spec/06-expressions.md11
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Infer.scala2
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala87
-rw-r--r--src/reflect/scala/reflect/internal/Definitions.scala11
-rw-r--r--src/reflect/scala/reflect/internal/TreeInfo.scala6
-rw-r--r--test/files/neg/sammy_overload.check14
-rw-r--r--test/files/neg/sammy_overload.scala12
-rw-r--r--test/files/neg/t6214.check7
-rw-r--r--test/files/pos/overloaded_ho_fun.scala51
9 files changed, 156 insertions, 45 deletions
diff --git a/spec/06-expressions.md b/spec/06-expressions.md
index 30ad73a3cd..1ecba5adf1 100644
--- a/spec/06-expressions.md
+++ b/spec/06-expressions.md
@@ -1426,9 +1426,14 @@ Let $\mathscr{B}$ be the set of alternatives in $\mathscr{A}$ that are [_applica
to expressions $(e_1 , \ldots , e_n)$ of types $(\mathit{shape}(e_1) , \ldots , \mathit{shape}(e_n))$.
If there is precisely one alternative in $\mathscr{B}$, that alternative is chosen.
-Otherwise, let $S_1 , \ldots , S_m$ be the vector of types obtained by
-typing each argument with an undefined expected type. For every
-member $m$ in $\mathscr{B}$ one determines whether it is applicable
+Otherwise, let $S_1 , \ldots , S_m$ be the list of types obtained by typing each argument as follows.
+An argument `$e_i$` of the shape `($p_1$: $T_1 , \ldots , p_n$: $T_n$) => $b$` where one of the `$T_i$` is missing,
+i.e., a function literal with a missing parameter type, is typed with an expected function type that
+propagates the least upper bound of the fully defined types of the corresponding parameters of
+the ([SAM-converted](#sam-conversion)) function types specified by the `$i$`th argument type found in each alternative.
+All other arguments are typed with an undefined expected type.
+
+For every member $m$ in $\mathscr{B}$ one determines whether it is applicable
to expressions ($e_1 , \ldots , e_m$) of types $S_1, \ldots , S_m$.
It is an error if none of the members in $\mathscr{B}$ is applicable. If there is one
diff --git a/src/compiler/scala/tools/nsc/typechecker/Infer.scala b/src/compiler/scala/tools/nsc/typechecker/Infer.scala
index 7112edd75d..0071d66eb9 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Infer.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Infer.scala
@@ -731,7 +731,7 @@ trait Infer extends Checkable {
// If args eq the incoming arg types, fail; otherwise recurse with these args.
def tryWithArgs(args: List[Type]) = (
(args ne argtpes0)
- && isApplicable(undetparams, mt, args, pt)
+ && isApplicableToMethod(undetparams, mt, args, pt) // used to be isApplicable(undetparams, mt, args, pt), knowing mt: MethodType
)
def tryInstantiating(args: List[Type]) = falseIfNoInstance {
val restpe = mt resultType args
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index d412b5ef33..8ce8dde9ab 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -3271,40 +3271,74 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
}
val fun = preSelectOverloaded(fun0)
+ val argslen = args.length
fun.tpe match {
case OverloadedType(pre, alts) =>
def handleOverloaded = {
val undetparams = context.undetparams
+
+ def funArgTypes(tps: List[Type]) = tps.map { tp =>
+ val relTp = tp.asSeenFrom(pre, fun.symbol.owner)
+ val argTps = functionOrSamArgTypes(relTp)
+ //println(s"funArgTypes $argTps from $relTp")
+ argTps.map(approximateAbstracts)
+ }
+
+ def functionProto(argTps: List[Type]): Type =
+ try functionType(funArgTypes(argTps).transpose.map(lub), WildcardType)
+ catch { case _: IllegalArgumentException => WildcardType }
+
+ // To propagate as much information as possible to typedFunction, which uses the expected type to
+ // infer missing parameter types for Function trees that we're typing as arguments here,
+ // we expand the parameter types for all alternatives to the expected argument length,
+ // then transpose to get a list of alternative argument types (push down the overloading to the arguments).
+ // Thus, for each `arg` in `args`, the corresponding `argPts` in `altArgPts` is a list of expected types
+ // for `arg`. Depending on which overload is picked, only one of those expected types must be met, but
+ // we're in the process of figuring that out, so we'll approximate below by normalizing them to function types
+ // and lubbing the argument types (we treat SAM and FunctionN types equally, but non-function arguments
+ // do not receive special treatment: they are typed under WildcardType.)
+ val altArgPts =
+ if (settings.isScala212 && args.exists(treeInfo.isFunctionMissingParamType))
+ try alts.map(alt => formalTypes(alt.info.paramTypes, argslen)).transpose // do least amount of work up front
+ catch { case _: IllegalArgumentException => args.map(_ => Nil) } // fail safe in case formalTypes fails to align to argslen
+ else args.map(_ => Nil) // will type under argPt == WildcardType
+
val (args1, argTpes) = context.savingUndeterminedTypeParams() {
val amode = forArgMode(fun, mode)
- def typedArg0(tree: Tree) = typedArg(tree, amode, BYVALmode, WildcardType)
- args.map {
- case arg @ AssignOrNamedArg(Ident(name), rhs) =>
- // named args: only type the righthand sides ("unknown identifier" errors otherwise)
- // the assign is untyped; that's ok because we call doTypedApply
- val typedRhs = typedArg0(rhs)
- val argWithTypedRhs = treeCopy.AssignOrNamedArg(arg, arg.lhs, typedRhs)
-
- // TODO: SI-8197/SI-4592: check whether this named argument could be interpreted as an assign
+
+ map2(args, altArgPts) { (arg, argPts) =>
+ def typedArg0(tree: Tree) = {
+ // if we have an overloaded HOF such as `(f: Int => Int)Int <and> (f: Char => Char)Char`,
+ // and we're typing a function like `x => x` for the argument, try to collapse
+ // the overloaded type into a single function type from which `typedFunction`
+ // can derive the argument type for `x` in the function literal above
+ val argPt =
+ if (argPts.nonEmpty && treeInfo.isFunctionMissingParamType(tree)) functionProto(argPts)
+ else WildcardType
+
+ val argTyped = typedArg(tree, amode, BYVALmode, argPt)
+ (argTyped, argTyped.tpe.deconst)
+ }
+
+ arg match {
+ // SI-8197/SI-4592 call for checking whether this named argument could be interpreted as an assign
// infer.checkNames must not use UnitType: it may not be a valid assignment, or the setter may return another type from Unit
- //
- // var typedAsAssign = true
- // val argTyped = silent(_.typedArg(argWithTypedRhs, amode, BYVALmode, WildcardType)) orElse { errors =>
- // typedAsAssign = false
- // argWithTypedRhs
- // }
- //
- // TODO: add an assignmentType field to NamedType, equal to:
- // assignmentType = if (typedAsAssign) argTyped.tpe else NoType
-
- (argWithTypedRhs, NamedType(name, typedRhs.tpe.deconst))
- case arg @ treeInfo.WildcardStarArg(repeated) =>
- val arg1 = typedArg0(arg)
- (arg1, RepeatedType(arg1.tpe.deconst))
- case arg =>
- val arg1 = typedArg0(arg)
- (arg1, arg1.tpe.deconst)
+ // TODO: just make it an error to refer to a non-existent named arg, as it's far more likely to be
+ // a typo than an assignment passed as an argument
+ case AssignOrNamedArg(lhs@Ident(name), rhs) =>
+ // named args: only type the righthand sides ("unknown identifier" errors otherwise)
+ // the assign is untyped; that's ok because we call doTypedApply
+ typedArg0(rhs) match {
+ case (rhsTyped, tp) => (treeCopy.AssignOrNamedArg(arg, lhs, rhsTyped), NamedType(name, tp))
+ }
+ case treeInfo.WildcardStarArg(_) =>
+ typedArg0(arg) match {
+ case (argTyped, tp) => (argTyped, RepeatedType(tp))
+ }
+ case _ =>
+ typedArg0(arg)
+ }
}.unzip
}
if (context.reporter.hasErrors)
@@ -3335,7 +3369,6 @@ trait Typers extends Adaptations with Tags with TypersTracking with PatternTyper
case mt @ MethodType(params, _) =>
val paramTypes = mt.paramTypes
// repeat vararg as often as needed, remove by-name
- val argslen = args.length
val formals = formalTypes(paramTypes, argslen)
/* Try packing all arguments into a Tuple and apply `fun`
diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala
index eca1bbea5a..0e63e81ae1 100644
--- a/src/reflect/scala/reflect/internal/Definitions.scala
+++ b/src/reflect/scala/reflect/internal/Definitions.scala
@@ -687,6 +687,17 @@ trait Definitions extends api.StandardDefinitions {
}
}
+ // the argument types expected by the function described by `tp` (a FunctionN or SAM type),
+ // or `Nil` if `tp` does not represent a function type or SAM (or if it happens to be Function0...)
+ def functionOrSamArgTypes(tp: Type): List[Type] = {
+ val dealiased = tp.dealiasWiden
+ if (isFunctionTypeDirect(dealiased)) dealiased.typeArgs.init
+ else samOf(tp) match {
+ case samSym if samSym.exists => tp.memberInfo(samSym).paramTypes
+ case _ => Nil
+ }
+ }
+
// the SAM's parameters and the Function's formals must have the same length
// (varargs etc don't come into play, as we're comparing signatures, not checking an application)
def samMatchesFunctionBasedOnArity(sam: Symbol, formals: List[Any]): Boolean =
diff --git a/src/reflect/scala/reflect/internal/TreeInfo.scala b/src/reflect/scala/reflect/internal/TreeInfo.scala
index 86b70643c9..b9f3e987ee 100644
--- a/src/reflect/scala/reflect/internal/TreeInfo.scala
+++ b/src/reflect/scala/reflect/internal/TreeInfo.scala
@@ -263,6 +263,12 @@ abstract class TreeInfo {
true
}
+ def isFunctionMissingParamType(tree: Tree): Boolean = tree match {
+ case Function(vparams, _) => vparams.exists(_.tpt.isEmpty)
+ case _ => false
+ }
+
+
/** Is symbol potentially a getter of a variable?
*/
def mayBeVarGetter(sym: Symbol): Boolean = sym.info match {
diff --git a/test/files/neg/sammy_overload.check b/test/files/neg/sammy_overload.check
index 903d7c88f4..87b198f4f0 100644
--- a/test/files/neg/sammy_overload.check
+++ b/test/files/neg/sammy_overload.check
@@ -1,7 +1,7 @@
-sammy_overload.scala:11: error: missing parameter type for expanded function ((x$1: <error>) => x$1.toString)
- O.m(_.toString) // error expected: eta-conversion breaks down due to overloading
- ^
-sammy_overload.scala:12: error: missing parameter type
- O.m(x => x) // error expected: needs param type
- ^
-two errors found
+sammy_overload.scala:14: error: overloaded method value m with alternatives:
+ (x: ToString)Int <and>
+ (x: Int => String)Int
+ cannot be applied to (Int => Int)
+ O.m(x => x) // error expected: m cannot be applied to Int => Int
+ ^
+one error found
diff --git a/test/files/neg/sammy_overload.scala b/test/files/neg/sammy_overload.scala
index 91c52cf96c..548e9d2d2e 100644
--- a/test/files/neg/sammy_overload.scala
+++ b/test/files/neg/sammy_overload.scala
@@ -2,12 +2,14 @@ trait ToString { def convert(x: Int): String }
class ExplicitSamType {
object O {
- def m(x: Int => String): Int = 0
- def m(x: ToString): Int = 1
+ def m(x: Int => String): Int = 0 // (1)
+ def m(x: ToString): Int = 1 // (2)
}
- O.m((x: Int) => x.toString) // ok, function type takes precedence
+ O.m((x: Int) => x.toString) // ok, function type takes precedence, because (1) is more specific than (2),
+ // because (1) is as specific as (2): (2) can be applied to a value of type Int => String (well, assuming it's a function literal)
+ // but (2) is not as specific as (1): (1) cannot be applied to a value of type ToString
- O.m(_.toString) // error expected: eta-conversion breaks down due to overloading
- O.m(x => x) // error expected: needs param type
+ O.m(_.toString) // ok: overloading resolution pushes through `Int` as the argument type, so this type checks
+ O.m(x => x) // error expected: m cannot be applied to Int => Int
}
diff --git a/test/files/neg/t6214.check b/test/files/neg/t6214.check
index 6349a3e71c..9d746351d1 100644
--- a/test/files/neg/t6214.check
+++ b/test/files/neg/t6214.check
@@ -1,4 +1,7 @@
-t6214.scala:5: error: missing parameter type
+t6214.scala:5: error: ambiguous reference to overloaded definition,
+both method m in object Test of type (f: Int => Unit)Int
+and method m in object Test of type (f: String => Unit)Int
+match argument types (Any => Unit)
m { s => case class Foo() }
- ^
+ ^
one error found
diff --git a/test/files/pos/overloaded_ho_fun.scala b/test/files/pos/overloaded_ho_fun.scala
new file mode 100644
index 0000000000..2699ad35f8
--- /dev/null
+++ b/test/files/pos/overloaded_ho_fun.scala
@@ -0,0 +1,51 @@
+import scala.math.Ordering
+import scala.reflect.ClassTag
+
+trait Sam { def apply(x: Int): String }
+trait SamP[U] { def apply(x: Int): U }
+
+class OverloadedFun[T](x: T) {
+ def foo(f: T => String): String = f(x)
+ def foo(f: Any => T): T = f("a")
+
+ def poly[U](f: Int => String): String = f(1)
+ def poly[U](f: Int => U): U = f(1)
+
+ def polySam[U](f: Sam): String = f(1)
+ def polySam[U](f: SamP[U]): U = f(1)
+
+ // check that we properly instantiate java.util.function.Function's type param to String
+ def polyJavaSam(f: String => String) = 1
+ def polyJavaSam(f: java.util.function.Function[String, String]) = 2
+}
+
+class StringLike(xs: String) {
+ def map[A](f: Char => A): Array[A] = ???
+ def map(f: Char => Char): String = ???
+}
+
+object Test {
+ val of = new OverloadedFun[Int](1)
+
+ of.foo(_.toString)
+
+ of.poly(x => x / 2 )
+ of.polySam(x => x / 2 )
+ of.polyJavaSam(x => x)
+
+ val sl = new StringLike("a")
+ sl.map(_ == 'a') // : Array[Boolean]
+ sl.map(x => 'a') // : String
+}
+
+object sorting {
+ def stableSort[K: ClassTag](a: Seq[K], f: (K, K) => Boolean): Array[K] = ???
+ def stableSort[L: ClassTag](a: Array[L], f: (L, L) => Boolean): Unit = ???
+
+ stableSort(??? : Seq[Boolean], (x: Boolean, y: Boolean) => x && !y)
+}
+
+// trait Bijection[A, B] extends (A => B) {
+// def andThen[C](g: Bijection[B, C]): Bijection[A, C] = ???
+// def compose[T](g: Bijection[T, A]) = g andThen this
+// }