From f2cf8e76fd184c95f2ad2f81659b60ee2ce75ec3 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 29 Sep 2017 10:34:38 +1000 Subject: Fix ANF transform for corner case in late transforms Unfortunately I wasn't able to extract a test case, but the patch has been tested to fix a problem on a real world code base. --- .../scala/scala/async/internal/AnfTransform.scala | 36 ++++++++++++---------- .../scala/async/internal/TransformUtils.scala | 21 +++++++++++++ src/test/scala/scala/async/run/WarningsSpec.scala | 2 +- 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala index 32c993a..93297f7 100644 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -110,13 +110,19 @@ private[async] trait AnfTransform { statsExprThrow } else { val varDef = defineVar(name.ifRes, expr.tpe, tree.pos) - def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) { - def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol)) - orig match { - case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) - case _ => Assign(Ident(varDef.symbol), cast(orig)) + def typedAssign(lhs: Tree) = + api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol))))) + + def branchWithAssign(t: Tree): Tree = { + t match { + case MatchEnd(ld) => + deriveLabelDef(ld, branchWithAssign) + case blk @ Block(thenStats, thenExpr) => + treeCopy.Block(blk, thenStats, typedAssign(thenExpr)).setType(definitions.UnitTpe) + case _ => + typedAssign(t) } - }) + } val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe) stats :+ varDef :+ ifWithAssign :+ atPos(tree.pos)(gen.mkAttributedStableRef(varDef.symbol)).setType(tree.tpe) } @@ -139,11 +145,14 @@ private[async] trait AnfTransform { api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol))))) val casesWithAssign = cases map { case cd@CaseDef(pat, guard, body) => - val newBody = body match { - case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, typedAssign(caseExpr)).setType(definitions.UnitTpe) - case _ => typedAssign(body) + def bodyWithAssign(t: Tree): Tree = { + t match { + case MatchEnd(ld) => deriveLabelDef(ld, bodyWithAssign) + case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, bodyWithAssign(caseExpr)).setType(definitions.UnitTpe) + case _ => typedAssign(t) + } } - treeCopy.CaseDef(cd, pat, guard, newBody).setType(definitions.UnitTpe) + treeCopy.CaseDef(cd, pat, guard, bodyWithAssign(body)).setType(definitions.UnitTpe) } val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign).setType(definitions.UnitTpe) require(matchWithAssign.tpe != null, matchWithAssign) @@ -228,11 +237,6 @@ private[async] trait AnfTransform { val stats1 = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) val exprs1 = linearize.transformToList(expr) val trees = stats1 ::: exprs1 - def isMatchEndLabel(t: Tree): Boolean = t match { - case ValDef(_, _, _, t) if isMatchEndLabel(t) => true - case ld: LabelDef if ld.name.toString.startsWith("matchEnd") => true - case _ => false - } def groupsEndingWith[T](ts: List[T])(f: T => Boolean): List[List[T]] = if (ts.isEmpty) Nil else { ts.indexWhere(f) match { case -1 => List(ts) @@ -241,7 +245,7 @@ private[async] trait AnfTransform { ts1 :: groupsEndingWith(ts2)(f) } } - val matchGroups = groupsEndingWith(trees)(isMatchEndLabel) + val matchGroups = groupsEndingWith(trees){ case MatchEnd(_) => true; case _ => false } val trees1 = matchGroups.flatMap(eliminateMatchEndLabelParameter) val result = trees1 flatMap { case Block(stats, expr) => stats :+ expr diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 1720815..848861c 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -544,6 +544,27 @@ private[async] trait TransformUtils { result :: Nil } } + + def deriveLabelDef(ld: LabelDef, applyToRhs: Tree => Tree): LabelDef = { + val rhs2 = applyToRhs(ld.rhs) + val ld2 = treeCopy.LabelDef(ld, ld.name, ld.params, rhs2) + if (ld eq ld2) ld + else { + val info2 = ld2.symbol.info match { + case MethodType(params, p) => internal.methodType(params, rhs2.tpe) + case t => t + } + internal.setInfo(ld2.symbol, info2) + ld2 + } + } + object MatchEnd { + def unapply(t: Tree): Option[LabelDef] = t match { + case ValDef(_, _, _, t) => unapply(t) + case ld: LabelDef if ld.name.toString.startsWith("matchEnd") => Some(ld) + case _ => None + } + } } case object ContainsAwait diff --git a/src/test/scala/scala/async/run/WarningsSpec.scala b/src/test/scala/scala/async/run/WarningsSpec.scala index c80bf9e..c76168e 100644 --- a/src/test/scala/scala/async/run/WarningsSpec.scala +++ b/src/test/scala/scala/async/run/WarningsSpec.scala @@ -15,7 +15,7 @@ class WarningsSpec { @Test // https://github.com/scala/async/issues/74 - def noPureExpressionInStatementPositionWarning_t74() { + def noPureExpressionInStatementPositionWarning_t74(): Unit = { val tb = mkToolbox(s"-cp ${toolboxClasspath} -Xfatal-warnings") // was: "a pure expression does nothing in statement position; you may be omitting necessary parentheses" tb.eval(tb.parse { -- cgit v1.2.3