aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/dotty/tools/dotc/typer/Applications.scala47
-rw-r--r--src/dotty/tools/dotc/typer/Typer.scala10
-rw-r--r--tests/pos/harmonize.scala28
3 files changed, 70 insertions, 15 deletions
diff --git a/src/dotty/tools/dotc/typer/Applications.scala b/src/dotty/tools/dotc/typer/Applications.scala
index e1d3d243d..dadd2afdc 100644
--- a/src/dotty/tools/dotc/typer/Applications.scala
+++ b/src/dotty/tools/dotc/typer/Applications.scala
@@ -127,6 +127,9 @@ trait Applications extends Compatibility { self: Typer =>
*/
protected def makeVarArg(n: Int, elemFormal: Type): Unit
+ /** If all `args` have primitive numeric types, make sure it's the same one */
+ protected def harmonizeArgs(args: List[TypedArg]): List[TypedArg]
+
/** Signal failure with given message at position of given argument */
protected def fail(msg: => String, arg: Arg): Unit
@@ -334,7 +337,14 @@ trait Applications extends Compatibility { self: Typer =>
addTyped(arg, formal)
case _ =>
val elemFormal = formal.widenExpr.argTypesLo.head
- args foreach (addTyped(_, elemFormal))
+ val origConstraint = ctx.typerState.constraint
+ var typedArgs = args.map(typedArg(_, elemFormal))
+ val harmonizedArgs = harmonizeArgs(typedArgs)
+ if (harmonizedArgs ne typedArgs) {
+ ctx.typerState.constraint = origConstraint
+ typedArgs = harmonizedArgs
+ }
+ typedArgs.foreach(addArg(_, elemFormal))
makeVarArg(args.length, elemFormal)
}
else args match {
@@ -389,6 +399,7 @@ trait Applications extends Compatibility { self: Typer =>
def argType(arg: Tree, formal: Type): Type = normalize(arg.tpe, formal)
def treeToArg(arg: Tree): Tree = arg
def isVarArg(arg: Tree): Boolean = tpd.isWildcardStarArg(arg)
+ def harmonizeArgs(args: List[Tree]) = harmonize(args)
}
/** Subclass of Application for applicability tests with type arguments and value
@@ -405,6 +416,7 @@ trait Applications extends Compatibility { self: Typer =>
def argType(arg: Type, formal: Type): Type = arg
def treeToArg(arg: Tree): Type = arg.tpe
def isVarArg(arg: Type): Boolean = arg.isRepeatedParam
+ def harmonizeArgs(args: List[Type]) = harmonizeTypes(args)
}
/** Subclass of Application for type checking an Apply node, where
@@ -430,6 +442,8 @@ trait Applications extends Compatibility { self: Typer =>
typedArgBuf += seqToRepeated(seqLit)
}
+ def harmonizeArgs(args: List[TypedArg]) = harmonize(args)
+
override def appPos = app.pos
def fail(msg: => String, arg: Trees.Tree[T]) = {
@@ -1025,25 +1039,34 @@ trait Applications extends Compatibility { self: Typer =>
}
}
- def harmonize(trees: List[Tree])(implicit ctx: Context): List[Tree] = {
- def numericClasses(trees: List[Tree], acc: Set[Symbol]): Set[Symbol] = trees match {
- case tree :: trees1 =>
- val sym = tree.tpe.typeSymbol
- if (sym.isNumericValueClass && tree.tpe.isRef(sym))
- numericClasses(trees1, acc + sym)
- else
- Set()
+ private def harmonizeWith[T <: AnyRef](ts: List[T])(tpe: T => Type, adapt: (T, Type) => T)(implicit ctx: Context): List[T] = {
+ def numericClasses(ts: List[T], acc: Set[Symbol]): Set[Symbol] = ts match {
+ case t :: ts1 =>
+ val sym = tpe(t).widen.classSymbol
+ if (sym.isNumericValueClass) numericClasses(ts1, acc + sym)
+ else Set()
case Nil =>
acc
}
- val clss = numericClasses(trees, Set())
+ val clss = numericClasses(ts, Set())
if (clss.size > 1) {
val lub = defn.ScalaNumericValueClassList.find(lubCls =>
clss.forall(defn.isValueSubClass(_, lubCls))).get.typeRef
- trees.mapConserve(tree => adaptInterpolated(tree, lub, tree))
+ ts.mapConserve(adapt(_, lub))
}
- else trees
+ else ts
}
+
+ def harmonize(trees: List[Tree])(implicit ctx: Context): List[Tree] = {
+ def adapt(tree: Tree, pt: Type): Tree = tree match {
+ case cdef: CaseDef => tpd.cpy.CaseDef(cdef)(body = adapt(cdef.body, pt))
+ case _ => adaptInterpolated(tree, pt, tree)
+ }
+ harmonizeWith(trees)(_.tpe, adapt)
+ }
+
+ def harmonizeTypes(tpes: List[Type])(implicit ctx: Context): List[Type] =
+ harmonizeWith(tpes)(identity, (tp, pt) => pt)
}
/*
diff --git a/src/dotty/tools/dotc/typer/Typer.scala b/src/dotty/tools/dotc/typer/Typer.scala
index 9fecc9742..ea19dd1c9 100644
--- a/src/dotty/tools/dotc/typer/Typer.scala
+++ b/src/dotty/tools/dotc/typer/Typer.scala
@@ -493,7 +493,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
val cond1 = typed(tree.cond, defn.BooleanType)
val thenp1 = typed(tree.thenp, pt)
val elsep1 = typed(tree.elsep orElse untpd.unitLiteral withPos tree.pos, pt)
- assignType(cpy.If(tree)(cond1, thenp1, elsep1), thenp1, elsep1)
+ val thenp2 :: elsep2 :: Nil = harmonize(thenp1 :: elsep1 :: Nil)
+ assignType(cpy.If(tree)(cond1, thenp2, elsep2), thenp2, elsep2)
}
def typedFunction(tree: untpd.Function, pt: Type)(implicit ctx: Context) = track("typedFunction") {
@@ -629,7 +630,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
fullyDefinedType(sel1.tpe, "pattern selector", tree.pos))
val cases1 = typedCases(tree.cases, selType, pt)
- assignType(cpy.Match(tree)(sel1, cases1), cases1)
+ val cases2 = harmonize(cases1).asInstanceOf[List[CaseDef]]
+ assignType(cpy.Match(tree)(sel1, cases2), cases2)
}
}
@@ -737,7 +739,9 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
val expr1 = typed(tree.expr, pt)
val cases1 = typedCases(tree.cases, defn.ThrowableType, pt)
val finalizer1 = typed(tree.finalizer, defn.UnitType)
- assignType(cpy.Try(tree)(expr1, cases1, finalizer1), expr1, cases1)
+ val expr2 :: cases2x = harmonize(expr1 :: cases1)
+ val cases2 = cases2x.asInstanceOf[List[CaseDef]]
+ assignType(cpy.Try(tree)(expr2, cases2, finalizer1), expr2, cases2)
}
def typedThrow(tree: untpd.Throw)(implicit ctx: Context): Tree = track("typedThrow") {
diff --git a/tests/pos/harmonize.scala b/tests/pos/harmonize.scala
new file mode 100644
index 000000000..267db6134
--- /dev/null
+++ b/tests/pos/harmonize.scala
@@ -0,0 +1,28 @@
+object Test {
+
+ def main(args: Array[String]) = {
+ val x = true
+ val n = 1
+/* val y = if (x) 'A' else n
+ val z: Int = y
+
+ val yy = n match {
+ case 1 => 'A'
+ case 2 => n
+ case 3 => 1.0
+ }
+ val zz: Double = yy
+
+ val a = try {
+ 'A'
+ } catch {
+ case ex: Exception => n
+ case ex: Error => 3L
+ }
+ val b: Long = a
+*/
+ val xs = List(1.0, n, 'c')
+ val ys: List[Double] = xs
+ }
+
+}