aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/dotty/tools/dotc/typer/Applications.scala29
-rw-r--r--tests/neg/overloaded.scala17
-rw-r--r--tests/pos/overloaded.scala24
3 files changed, 65 insertions, 5 deletions
diff --git a/src/dotty/tools/dotc/typer/Applications.scala b/src/dotty/tools/dotc/typer/Applications.scala
index b3a71408b..d655b25f6 100644
--- a/src/dotty/tools/dotc/typer/Applications.scala
+++ b/src/dotty/tools/dotc/typer/Applications.scala
@@ -1151,9 +1151,30 @@ trait Applications extends Compatibility { self: Typer =>
}
arg match {
case arg: untpd.Function if arg.args.exists(isUnknownParamType) =>
- val commonFormal = altFormals.map(_.head).reduceLeft(_ | _)
- overload.println(i"pretype arg $arg with expected type $commonFormal")
- pt.typedArg(arg, commonFormal)
+ def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head))
+ val formalsForArg: List[Type] = altFormals.map(_.head)
+ // For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form
+ // (p_1_1, ..., p_m_1) => r_1
+ // ...
+ // (p_1_n, ..., p_m_n) => r_n
+ val decomposedFormalsForArg: List[Option[(List[Type], Type)]] =
+ formalsForArg.map(defn.FunctionOf.unapply)
+ if (decomposedFormalsForArg.forall(_.isDefined)) {
+ val formalParamTypessForArg: List[List[Type]] =
+ decomposedFormalsForArg.map(_.get._1)
+ if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) {
+ val commonParamTypes = formalParamTypessForArg.transpose.map(ps =>
+ // Given definitions above, for i = 1,...,m,
+ // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column
+ // If all p_i_k's are the same, assume the type as formal parameter
+ // type of the i'th parameter of the closure.
+ if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head
+ else WildcardType)
+ val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType)
+ overload.println(i"pretype arg $arg with expected type $commonFormal")
+ pt.typedArg(arg, commonFormal)
+ }
+ }
case _ =>
}
recur(altFormals.map(_.tail), args1)
@@ -1161,7 +1182,7 @@ trait Applications extends Compatibility { self: Typer =>
}
def paramTypes(alt: Type): List[Type] = alt match {
case mt: MethodType => mt.paramTypes
- case mt: PolyType => paramTypes(mt.resultType).map(wildApprox(_))
+ case mt: PolyType => paramTypes(mt.resultType)
case _ => Nil
}
recur(alts.map(alt => paramTypes(alt.widen)), pt.args)
diff --git a/tests/neg/overloaded.scala b/tests/neg/overloaded.scala
new file mode 100644
index 000000000..ce971ebcf
--- /dev/null
+++ b/tests/neg/overloaded.scala
@@ -0,0 +1,17 @@
+// testing the limits of parameter type inference
+
+object Test {
+ def mapX(f: Char => Char): String = ???
+ def mapX[U](f: U => U): U = ???
+ mapX(x => x) // error: missing parameter type
+
+ def foo(f: Char => Char): Unit = ???
+ def foo(f: Int => Int): String = ???
+ foo(x => x) // error: missing parameter type
+
+ def bar(f: (Char, Char) => Unit): Unit = ???
+ def bar(f: Char => Unit) = ???
+ bar((x, y) => ())
+ bar (x => ())
+
+}
diff --git a/tests/pos/overloaded.scala b/tests/pos/overloaded.scala
index 9e2260c1c..6a8e72714 100644
--- a/tests/pos/overloaded.scala
+++ b/tests/pos/overloaded.scala
@@ -24,5 +24,27 @@ object overloaded {
def map(f: Char => Char): String = ???
def map[U](f: Char => U): Seq[U] = ???
- map(x => x.toUpper)
+ val r1 = map(x => x.toUpper)
+ val t1: String = r1
+ val r2 = map(x => x.toInt)
+ val t2: Seq[Int] = r2
+
+ def flatMap(f: Char => String): String = ???
+ def flatMap[U](f: Char => Seq[U]): Seq[U] = ???
+ val r3 = flatMap(x => x.toString)
+ val t3: String = r3
+ val r4 = flatMap(x => List(x))
+ val t4: Seq[Char] = r4
+
+ def bar(f: (Char, Char) => Unit): Unit = ???
+ def bar(f: Char => Unit) = ???
+ bar((x, y) => ())
+ bar (x => ())
+
+ def combine(f: (Char, Int) => Int): Int = ???
+ def combine(f: (String, Int) => String): String = ???
+ val r5 = combine((x: Char, y) => x + y)
+ val t5: Int = r5
+ val r6 = combine((x: String, y) => x ++ y.toString)
+ val t6: String = r6
}