diff options
author | odersky <odersky@gmail.com> | 2017-02-23 13:33:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-02-23 13:33:48 +0100 |
commit | 2699715684e9ea474001e9fc85bb3c7d25ed1ca5 (patch) | |
tree | c0218ed915ac88dade3370ad0212c58d13c656b2 | |
parent | a0f47a00131935d85f957a80d0c4472eaa7b5baa (diff) | |
parent | 47b6c6b204ca6b577e18fd82ca8af0e5710771b1 (diff) | |
download | dotty-2699715684e9ea474001e9fc85bb3c7d25ed1ca5.tar.gz dotty-2699715684e9ea474001e9fc85bb3c7d25ed1ca5.tar.bz2 dotty-2699715684e9ea474001e9fc85bb3c7d25ed1ca5.zip |
Merge pull request #2015 from dotty-staging/add-pf-overloading
Add overloading support for case-closures
-rw-r--r-- | compiler/src/dotty/tools/dotc/ast/TreeInfo.scala | 2 | ||||
-rw-r--r-- | compiler/src/dotty/tools/dotc/typer/Applications.scala | 54 | ||||
-rw-r--r-- | compiler/test/dotty/tools/dotc/EntryPointsTest.scala.disabled (renamed from compiler/test/dotty/tools/dotc/EntryPointsTest.scala) | 0 | ||||
-rw-r--r-- | tests/pos/inferOverloaded.scala | 41 |
4 files changed, 70 insertions, 27 deletions
diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index bcda4b92f..e48b1039b 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -287,6 +287,8 @@ trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped] case ValDef(_, tpt, _) => tpt.isEmpty case _ => false } + case Match(EmptyTree, _) => + true case _ => false } diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 5e092871d..de017961a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1227,6 +1227,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic => def typeShape(tree: untpd.Tree): Type = tree match { case untpd.Function(args, body) => defn.FunctionOf(args map Function.const(defn.AnyType), typeShape(body)) + case Match(EmptyTree, _) => + defn.PartialFunctionType.appliedTo(defn.AnyType :: defn.NothingType :: Nil) case _ => defn.NothingType } @@ -1271,7 +1273,7 @@ trait Applications extends Compatibility { self: Typer with Dynamic => alts filter (alt => sizeFits(alt, alt.widen)) def narrowByShapes(alts: List[TermRef]): List[TermRef] = { - if (normArgs exists (_.isInstanceOf[untpd.Function])) + if (normArgs exists untpd.isFunctionWithUnknownParamType) if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType) else narrowByTypes(alts, normArgs map typeShape, resultType) else @@ -1351,33 +1353,31 @@ trait Applications extends Compatibility { self: Typer with Dynamic => case ValDef(_, tpt, _) => tpt.isEmpty case _ => false } - arg match { - case arg: untpd.Function if arg.args.exists(isUnknownParamType) => - def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head)) - val formalsForArg: List[Type] = altFormals.map(_.head) - // For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form - // (p_1_1, ..., p_m_1) => r_1 - // ... - // (p_1_n, ..., p_m_n) => r_n - val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] = - formalsForArg.map(defn.FunctionOf.unapply) - if (decomposedFormalsForArg.forall(_.isDefined)) { - val formalParamTypessForArg: List[List[Type]] = - decomposedFormalsForArg.map(_.get._1) - if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) { - val commonParamTypes = formalParamTypessForArg.transpose.map(ps => - // Given definitions above, for i = 1,...,m, - // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column - // If all p_i_k's are the same, assume the type as formal parameter - // type of the i'th parameter of the closure. - if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head - else WildcardType) - val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType) - overload.println(i"pretype arg $arg with expected type $commonFormal") - pt.typedArg(arg, commonFormal) - } + if (untpd.isFunctionWithUnknownParamType(arg)) { + def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head)) + val formalsForArg: List[Type] = altFormals.map(_.head) + // For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form + // (p_1_1, ..., p_m_1) => r_1 + // ... + // (p_1_n, ..., p_m_n) => r_n + val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] = + formalsForArg.map(defn.FunctionOf.unapply) + if (decomposedFormalsForArg.forall(_.isDefined)) { + val formalParamTypessForArg: List[List[Type]] = + decomposedFormalsForArg.map(_.get._1) + if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) { + val commonParamTypes = formalParamTypessForArg.transpose.map(ps => + // Given definitions above, for i = 1,...,m, + // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column + // If all p_i_k's are the same, assume the type as formal parameter + // type of the i'th parameter of the closure. + if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head + else WildcardType) + val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType) + overload.println(i"pretype arg $arg with expected type $commonFormal") + pt.typedArg(arg, commonFormal) } - case _ => + } } recur(altFormals.map(_.tail), args1) case _ => diff --git a/compiler/test/dotty/tools/dotc/EntryPointsTest.scala b/compiler/test/dotty/tools/dotc/EntryPointsTest.scala.disabled index 00918a282..00918a282 100644 --- a/compiler/test/dotty/tools/dotc/EntryPointsTest.scala +++ b/compiler/test/dotty/tools/dotc/EntryPointsTest.scala.disabled diff --git a/tests/pos/inferOverloaded.scala b/tests/pos/inferOverloaded.scala new file mode 100644 index 000000000..e7179a04a --- /dev/null +++ b/tests/pos/inferOverloaded.scala @@ -0,0 +1,41 @@ +class MySeq[T] { + def map1[U](f: T => U): MySeq[U] = new MySeq[U] + def map2[U](f: T => U): MySeq[U] = new MySeq[U] +} + +class MyMap[A, B] extends MySeq[(A, B)] { + def map1[C](f: (A, B) => C): MySeq[C] = new MySeq[C] + def map1[C, D](f: (A, B) => (C, D)): MyMap[C, D] = new MyMap[C, D] + def map1[C, D](f: ((A, B)) => (C, D)): MyMap[C, D] = new MyMap[C, D] + + def foo(f: Function2[Int, Int, Int]): Unit = () + def foo[R](pf: PartialFunction[(A, B), R]): MySeq[R] = new MySeq[R] +} + +object Test { + val m = new MyMap[Int, String] + + // This one already worked because it is not overloaded: + m.map2 { case (k, v) => k - 1 } + + // These already worked because preSelectOverloaded eliminated the non-applicable overload: + m.map1(t => t._1) + m.map1((kInFunction, vInFunction) => kInFunction - 1) + val r1 = m.map1(t => (t._1, 42.0)) + val r1t: MyMap[Int, Double] = r1 + + // These worked because the argument types are known for overload resolution: + m.map1({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int]) + m.map2({ case (k, v) => k - 1 }: PartialFunction[(Int, String), Int]) + + // These ones did not work before: + m.map1 { case (k, v) => k } + val r = m.map1 { case (k, v) => (k, k*10) } + val rt: MyMap[Int, Int] = r + m.foo { case (k, v) => k - 1 } + + // Used to be ambiguous but overload resolution now favors PartialFunction + def h[R](pf: Function2[Int, String, R]): Unit = () + def h[R](pf: PartialFunction[(Double, Double), R]): Unit = () + h { case (a: Double, b: Double) => 42: Int } +} |