aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/internal/AnfTransform.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/internal/AnfTransform.scala')
-rw-r--r--src/main/scala/scala/async/internal/AnfTransform.scala81
1 files changed, 70 insertions, 11 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
index f81f5af..4545ca6 100644
--- a/src/main/scala/scala/async/internal/AnfTransform.scala
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -16,16 +16,18 @@ private[async] trait AnfTransform {
import c.internal._
import decorators._
- def anfTransform(tree: Tree): Block = {
+ def anfTransform(tree: Tree, owner: Symbol): Block = {
// Must prepend the () for issue #31.
- val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe)
+ val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe)
sealed abstract class AnfMode
case object Anf extends AnfMode
case object Linearizing extends AnfMode
+ val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
+
var mode: AnfMode = Anf
- typingTransform(block)((tree, api) => {
+ typingTransform(tree1, owner)((tree, api) => {
def blockToList(tree: Tree): List[Tree] = tree match {
case Block(stats, expr) => stats :+ expr
case t => t :: Nil
@@ -34,7 +36,7 @@ private[async] trait AnfTransform {
def listToBlock(trees: List[Tree]): Block = trees match {
case trees @ (init :+ last) =>
val pos = trees.map(_.pos).reduceLeft(_ union _)
- Block(init, last).setType(last.tpe).setPos(pos)
+ newBlock(init, last).setType(last.tpe).setPos(pos)
}
object linearize {
@@ -66,6 +68,17 @@ private[async] trait AnfTransform {
stats :+ valDef :+ atPos(tree.pos)(ref1)
case If(cond, thenp, elsep) =>
+ // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}`
+ // as though it was typed with `Unit`.
+ def isPatMatGeneratedJump(t: Tree): Boolean = t match {
+ case Block(_, expr) => isPatMatGeneratedJump(expr)
+ case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep)
+ case _: Apply if isLabel(t.symbol) => true
+ case _ => false
+ }
+ if (isPatMatGeneratedJump(expr)) {
+ internal.setType(expr, definitions.UnitTpe)
+ }
// if type of if-else is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
@@ -77,7 +90,7 @@ private[async] trait AnfTransform {
def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol))
orig match {
- case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
+ case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
case _ => Assign(Ident(varDef.symbol), cast(orig))
}
})
@@ -115,7 +128,7 @@ private[async] trait AnfTransform {
}
}
- private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
+ def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
}
@@ -152,8 +165,7 @@ private[async] trait AnfTransform {
}
def _transformToList(tree: Tree): List[Tree] = trace(tree) {
- val containsAwait = tree exists isAwait
- if (!containsAwait) {
+ if (!containsAwait(tree)) {
tree match {
case Block(stats, expr) =>
// avoids nested block in `while(await(false)) ...`.
@@ -207,10 +219,11 @@ private[async] trait AnfTransform {
funStats ++ argStatss.flatten.flatten :+ typedNewApply
case Block(stats, expr) =>
- (stats :+ expr).flatMap(linearize.transformToList)
+ val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
+ eliminateMatchEndLabelParameter(trees)
case ValDef(mods, name, tpt, rhs) =>
- if (rhs exists isAwait) {
+ if (containsAwait(rhs)) {
val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs))
stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner))
stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
@@ -247,7 +260,7 @@ private[async] trait AnfTransform {
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)
case LabelDef(name, params, rhs) =>
- List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
+ List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
case TypeApply(fun, targs) =>
val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -259,6 +272,52 @@ private[async] trait AnfTransform {
}
}
+ // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
+ //
+ // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts
+ // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this).
+ //
+ // For our purposes, it is easier to:
+ // - extract a `matchRes` variable
+ // - rewrite the terminal label def to take no parameters, and instead read this temp variable
+ // - change jumps to the terminal label to an assignment and a no-arg label application
+ def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = {
+ import internal.{methodType, setInfo}
+ val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()
+
+ val matchResults = collection.mutable.Buffer[Tree]()
+ val statsExpr0 = statsExpr.reverseMap {
+ case ld @ LabelDef(_, param :: Nil, body) =>
+ val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
+ matchResults += matchResult
+ caseDefToMatchResult(ld.symbol) = matchResult.symbol
+ val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
+ setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
+ ld2
+ case t =>
+ if (caseDefToMatchResult.isEmpty) t
+ else typingTransform(t)((tree, api) =>
+ tree match {
+ case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
+ api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
+ case Block(stats, expr) =>
+ api.default(tree) match {
+ case Block(stats, Block(stats1, expr)) =>
+ treeCopy.Block(tree, stats ::: stats1, expr)
+ case t => t
+ }
+ case _ =>
+ api.default(tree)
+ }
+ )
+ }
+ matchResults.toList match {
+ case Nil => statsExpr
+ case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
+ case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
+ }
+ }
+
def anfLinearize(tree: Tree): Block = {
val trees: List[Tree] = mode match {
case Anf => anf._transformToList(tree)