From 6c164a5d906c657baa045c1d564c63273eb65f31 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 21 Feb 2017 17:50:05 +0100 Subject: Extend argument pretyping to case-closures --- .../src/dotty/tools/dotc/typer/Applications.scala | 57 ++++++++++------------ 1 file changed, 25 insertions(+), 32 deletions(-) (limited to 'compiler/src/dotty/tools/dotc/typer/Applications.scala') diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 2cfe01616..b2c4e2f45 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1272,13 +1272,8 @@ trait Applications extends Compatibility { self: Typer with Dynamic => def narrowBySize(alts: List[TermRef]): List[TermRef] = alts filter (alt => sizeFits(alt, alt.widen)) - def isFunArg(arg: untpd.Tree) = arg match { - case untpd.Function(_, _) | Match(EmptyTree, _) => true - case _ => false - } - def narrowByShapes(alts: List[TermRef]): List[TermRef] = { - if (normArgs exists isFunArg) + if (normArgs exists untpd.isFunctionWithUnknownParamType) if (hasNamedArg(args)) narrowByTrees(alts, args map treeShape, resultType) else narrowByTypes(alts, normArgs map typeShape, resultType) else @@ -1358,33 +1353,31 @@ trait Applications extends Compatibility { self: Typer with Dynamic => case ValDef(_, tpt, _) => tpt.isEmpty case _ => false } - arg match { - case arg: untpd.Function if arg.args.exists(isUnknownParamType) => - def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head)) - val formalsForArg: List[Type] = altFormals.map(_.head) - // For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form - // (p_1_1, ..., p_m_1) => r_1 - // ... - // (p_1_n, ..., p_m_n) => r_n - val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] = - formalsForArg.map(defn.FunctionOf.unapply) - if (decomposedFormalsForArg.forall(_.isDefined)) { - val formalParamTypessForArg: List[List[Type]] = - decomposedFormalsForArg.map(_.get._1) - if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) { - val commonParamTypes = formalParamTypessForArg.transpose.map(ps => - // Given definitions above, for i = 1,...,m, - // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column - // If all p_i_k's are the same, assume the type as formal parameter - // type of the i'th parameter of the closure. - if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head - else WildcardType) - val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType) - overload.println(i"pretype arg $arg with expected type $commonFormal") - pt.typedArg(arg, commonFormal) - } + if (untpd.isFunctionWithUnknownParamType(arg)) { + def isUniform[T](xs: List[T])(p: (T, T) => Boolean) = xs.forall(p(_, xs.head)) + val formalsForArg: List[Type] = altFormals.map(_.head) + // For alternatives alt_1, ..., alt_n, test whether formal types for current argument are of the form + // (p_1_1, ..., p_m_1) => r_1 + // ... + // (p_1_n, ..., p_m_n) => r_n + val decomposedFormalsForArg: List[Option[(List[Type], Type, Boolean)]] = + formalsForArg.map(defn.FunctionOf.unapply) + if (decomposedFormalsForArg.forall(_.isDefined)) { + val formalParamTypessForArg: List[List[Type]] = + decomposedFormalsForArg.map(_.get._1) + if (isUniform(formalParamTypessForArg)((x, y) => x.length == y.length)) { + val commonParamTypes = formalParamTypessForArg.transpose.map(ps => + // Given definitions above, for i = 1,...,m, + // ps(i) = List(p_i_1, ..., p_i_n) -- i.e. a column + // If all p_i_k's are the same, assume the type as formal parameter + // type of the i'th parameter of the closure. + if (isUniform(ps)(ctx.typeComparer.isSameTypeWhenFrozen(_, _))) ps.head + else WildcardType) + val commonFormal = defn.FunctionOf(commonParamTypes, WildcardType) + println(i"pretype arg $arg with expected type $commonFormal") + pt.typedArg(arg, commonFormal) } - case _ => + } } recur(altFormals.map(_.tail), args1) case _ => -- cgit v1.2.3