diff options
Diffstat (limited to 'src/main/scala/scala/async/internal/AnfTransform.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/AnfTransform.scala | 81 |
1 files changed, 70 insertions, 11 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala index f81f5af..4545ca6 100644 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -16,16 +16,18 @@ private[async] trait AnfTransform { import c.internal._ import decorators._ - def anfTransform(tree: Tree): Block = { + def anfTransform(tree: Tree, owner: Symbol): Block = { // Must prepend the () for issue #31. - val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe) + val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe) sealed abstract class AnfMode case object Anf extends AnfMode case object Linearizing extends AnfMode + val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner) + var mode: AnfMode = Anf - typingTransform(block)((tree, api) => { + typingTransform(tree1, owner)((tree, api) => { def blockToList(tree: Tree): List[Tree] = tree match { case Block(stats, expr) => stats :+ expr case t => t :: Nil @@ -34,7 +36,7 @@ private[async] trait AnfTransform { def listToBlock(trees: List[Tree]): Block = trees match { case trees @ (init :+ last) => val pos = trees.map(_.pos).reduceLeft(_ union _) - Block(init, last).setType(last.tpe).setPos(pos) + newBlock(init, last).setType(last.tpe).setPos(pos) } object linearize { @@ -66,6 +68,17 @@ private[async] trait AnfTransform { stats :+ valDef :+ atPos(tree.pos)(ref1) case If(cond, thenp, elsep) => + // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}` + // as though it was typed with `Unit`. + def isPatMatGeneratedJump(t: Tree): Boolean = t match { + case Block(_, expr) => isPatMatGeneratedJump(expr) + case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep) + case _: Apply if isLabel(t.symbol) => true + case _ => false + } + if (isPatMatGeneratedJump(expr)) { + internal.setType(expr, definitions.UnitTpe) + } // if type of if-else is Unit don't introduce assignment, // but add Unit value to bring it into form expected by async transform if (expr.tpe =:= definitions.UnitTpe) { @@ -77,7 +90,7 @@ private[async] trait AnfTransform { def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) { def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol)) orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) + case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) case _ => Assign(Ident(varDef.symbol), cast(orig)) } }) @@ -115,7 +128,7 @@ private[async] trait AnfTransform { } } - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp)) valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos) } @@ -152,8 +165,7 @@ private[async] trait AnfTransform { } def _transformToList(tree: Tree): List[Tree] = trace(tree) { - val containsAwait = tree exists isAwait - if (!containsAwait) { + if (!containsAwait(tree)) { tree match { case Block(stats, expr) => // avoids nested block in `while(await(false)) ...`. @@ -207,10 +219,11 @@ private[async] trait AnfTransform { funStats ++ argStatss.flatten.flatten :+ typedNewApply case Block(stats, expr) => - (stats :+ expr).flatMap(linearize.transformToList) + val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr) + eliminateMatchEndLabelParameter(trees) case ValDef(mods, name, tpt, rhs) => - if (rhs exists isAwait) { + if (containsAwait(rhs)) { val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs)) stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner)) stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) @@ -247,7 +260,7 @@ private[async] trait AnfTransform { scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs) case LabelDef(name, params, rhs) => - List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) case TypeApply(fun, targs) => val funStats :+ simpleFun = linearize.transformToList(fun) @@ -259,6 +272,52 @@ private[async] trait AnfTransform { } } + // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable + // + // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts + // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this). + // + // For our purposes, it is easier to: + // - extract a `matchRes` variable + // - rewrite the terminal label def to take no parameters, and instead read this temp variable + // - change jumps to the terminal label to an assignment and a no-arg label application + def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = { + import internal.{methodType, setInfo} + val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]() + + val matchResults = collection.mutable.Buffer[Tree]() + val statsExpr0 = statsExpr.reverseMap { + case ld @ LabelDef(_, param :: Nil, body) => + val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos) + matchResults += matchResult + caseDefToMatchResult(ld.symbol) = matchResult.symbol + val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)) + setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType)) + ld2 + case t => + if (caseDefToMatchResult.isEmpty) t + else typingTransform(t)((tree, api) => + tree match { + case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) => + api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil)))) + case Block(stats, expr) => + api.default(tree) match { + case Block(stats, Block(stats1, expr)) => + treeCopy.Block(tree, stats ::: stats1, expr) + case t => t + } + case _ => + api.default(tree) + } + ) + } + matchResults.toList match { + case Nil => statsExpr + case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol)) + case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr + } + } + def anfLinearize(tree: Tree): Block = { val trees: List[Tree] = mode match { case Anf => anf._transformToList(tree) |