diff options
-rw-r--r-- | src/dotty/tools/dotc/typer/Applications.scala | 29 | ||||
-rw-r--r-- | tests/neg/overloaded.scala | 17 | ||||
-rw-r--r-- | tests/pos/overloaded.scala | 24 |
3 files changed, 65 insertions, 5 deletions
diff --git a/src/dotty/tools/dotc/typer/Applications.scala b/src/dotty/tools/dotc/typer/Applications.scala index b3a71408b..d655b25f6 100644 --- a/src/dotty/tools/dotc/typer/Applications.scala +++ b/src/dotty/tools/dotc/typer/Applications.scala @@ -1151,9 +1151,30 @@ trait Applications extends Compatibility { self: Typer => } arg match { case arg: untpd.Function if arg.args.exists(isUnknownParamType) => - val commonFormal = altFormals.map(_.head).reduceLeft(_ | _) - overload.println(i"pretype arg $arg with expected type $commonFormal") - pt.typedArg(arg, commonFormal) + def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head)) + val formalsForArg: List[Type] = altFormals.map(_.head) + // For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form + // (p_1_1, ..., p_m_1) => r_1 + // ... + // (p_1_n, ..., p_m_n) => r_n + val decomposedFormalsForArg: List[Option[(List[Type], Type)]] = + formalsForArg.map(defn.FunctionOf.unapply) + if (decomposedFormalsForArg.forall(_.isDefined)) { + val formalParamTypessForArg: List[List[Type]] = + decomposedFormalsForArg.map(_.get._1) + if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) { + val commonParamTypes = formalParamTypessForArg.transpose.map(ps => + // Given definitions above, for i = 1,...,m, + // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column + // If all p_i_k's are the same, assume the type as formal parameter + // type of the i'th parameter of the closure. + if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head + else WildcardType) + val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType) + overload.println(i"pretype arg $arg with expected type $commonFormal") + pt.typedArg(arg, commonFormal) + } + } case _ => } recur(altFormals.map(_.tail), args1) @@ -1161,7 +1182,7 @@ trait Applications extends Compatibility { self: Typer => } def paramTypes(alt: Type): List[Type] = alt match { case mt: MethodType => mt.paramTypes - case mt: PolyType => paramTypes(mt.resultType).map(wildApprox(_)) + case mt: PolyType => paramTypes(mt.resultType) case _ => Nil } recur(alts.map(alt => paramTypes(alt.widen)), pt.args) diff --git a/tests/neg/overloaded.scala b/tests/neg/overloaded.scala new file mode 100644 index 000000000..ce971ebcf --- /dev/null +++ b/tests/neg/overloaded.scala @@ -0,0 +1,17 @@ +// testing the limits of parameter type inference + +object Test { + def mapX(f: Char => Char): String = ??? + def mapX[U](f: U => U): U = ??? + mapX(x => x) // error: missing parameter type + + def foo(f: Char => Char): Unit = ??? + def foo(f: Int => Int): String = ??? + foo(x => x) // error: missing parameter type + + def bar(f: (Char, Char) => Unit): Unit = ??? + def bar(f: Char => Unit) = ??? + bar((x, y) => ()) + bar (x => ()) + +} diff --git a/tests/pos/overloaded.scala b/tests/pos/overloaded.scala index 9e2260c1c..6a8e72714 100644 --- a/tests/pos/overloaded.scala +++ b/tests/pos/overloaded.scala @@ -24,5 +24,27 @@ object overloaded { def map(f: Char => Char): String = ??? def map[U](f: Char => U): Seq[U] = ??? - map(x => x.toUpper) + val r1 = map(x => x.toUpper) + val t1: String = r1 + val r2 = map(x => x.toInt) + val t2: Seq[Int] = r2 + + def flatMap(f: Char => String): String = ??? + def flatMap[U](f: Char => Seq[U]): Seq[U] = ??? + val r3 = flatMap(x => x.toString) + val t3: String = r3 + val r4 = flatMap(x => List(x)) + val t4: Seq[Char] = r4 + + def bar(f: (Char, Char) => Unit): Unit = ??? + def bar(f: Char => Unit) = ??? + bar((x, y) => ()) + bar (x => ()) + + def combine(f: (Char, Int) => Int): Int = ??? + def combine(f: (String, Int) => String): String = ??? + val r5 = combine((x: Char, y) => x + y) + val t5: Int = r5 + val r6 = combine((x: String, y) => x ++ y.toString) + val t6: String = r6 } |