diff options
Diffstat (limited to 'src/main/scala/scala/async/AnfTransform.scala')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 150 |
1 files changed, 88 insertions, 62 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index e1d7cd5..24f37e7 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -1,41 +1,49 @@ + package scala.async import scala.reflect.macros.Context class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { + import c.universe._ - import AsyncUtils._ object inline { def transformToList(tree: Tree): List[Tree] = { val stats :+ expr = anf.transformToList(tree) expr match { - - case Apply(fun, args) if fun.toString.startsWith("scala.async.Async.await") => - val liftedName = c.fresh("await$") - stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName) + case Apply(fun, args) if isAwait(fun) => + val valDef = defineVal("await", expr) + stats :+ valDef :+ Ident(valDef.name) case If(cond, thenp, elsep) => // if type of if-else is Unit don't introduce assignment, // but add Unit value to bring it into form expected by async transform if (expr.tpe =:= definitions.UnitTpe) { stats :+ expr :+ Literal(Constant(())) + } else { + val varDef = defineVar("ifres", expr.tpe) + def branchWithAssign(orig: Tree) = orig match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr)) + case _ => Assign(Ident(varDef.name), orig) + } + val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep)) + stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name) + } + + case Match(scrut, cases) => + // if type of match is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ Literal(Constant(())) } else { - val liftedName = c.fresh("ifres$") - val varDef = - ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) - val thenWithAssign = thenp match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) - case _ => Assign(Ident(liftedName), thenp) - } - val elseWithAssign = elsep match { - case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) - case _ => Assign(Ident(liftedName), elsep) + val varDef = defineVar("matchres", expr.tpe) + val casesWithAssign = cases map { + case CaseDef(pat, guard, Block(caseStats, caseExpr)) => CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr))) + case CaseDef(pat, guard, body) => CaseDef(pat, guard, Assign(Ident(varDef.name), body)) } - val ifWithAssign = - If(cond, thenWithAssign, elseWithAssign) - stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) + val matchWithAssign = Match(scrut, casesWithAssign) + stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name) } case _ => stats :+ expr @@ -44,58 +52,76 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { def transformToList(trees: List[Tree]): List[Tree] = trees match { case fst :: rest => transformToList(fst) ++ transformToList(rest) - case Nil => Nil + case Nil => Nil } - } - - object anf { - def transformToList(tree: Tree): List[Tree] = tree match { - case Select(qual, sel) => - val stats :+ expr = inline.transformToList(qual) - stats :+ Select(expr, sel) - case Apply(fun, args) => - val funStats :+ simpleFun = inline.transformToList(fun) - val argLists = args map inline.transformToList - val allArgStats = argLists flatMap (_.init) - val simpleArgs = argLists map (_.last) - funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs) - - case Block(stats, expr) => - inline.transformToList(stats) ++ inline.transformToList(expr) - - case ValDef(mods, name, tpt, rhs) => - val stats :+ expr = inline.transformToList(rhs) - stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol) - - case Assign(name, rhs) => - val stats :+ expr = inline.transformToList(rhs) - stats :+ Assign(name, expr) - - case If(cond, thenp, elsep) => - val stats :+ expr = inline.transformToList(cond) - val thenStats :+ thenExpr = inline.transformToList(thenp) - val elseStats :+ elseExpr = inline.transformToList(elsep) - stats :+ - c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr))) + def transformToBlock(tree: Tree): Block = transformToList(tree) match { + case stats :+ expr => Block(stats, expr) + } - //TODO - case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree) + def liftedName(prefix: String) = c.fresh(prefix + "$") - case TypeApply(fun, targs) => - val funStats :+ simpleFun = inline.transformToList(fun) - funStats :+ TypeApply(simpleFun, targs) + private def defineVar(prefix: String, tp: Type): ValDef = + ValDef(Modifiers(Flag.MUTABLE), liftedName(prefix), TypeTree(tp), defaultValue(tp)) - //TODO - case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree) + private def defineVal(prefix: String, lhs: Tree): ValDef = + ValDef(NoMods, liftedName(prefix), TypeTree(), lhs) + } - case ClassDef(mods, name, tparams, impl) => List(tree) + object anf { + def transformToList(tree: Tree): List[Tree] = { + def containsAwait = tree exists isAwait + tree match { + case Select(qual, sel) if containsAwait => + val stats :+ expr = inline.transformToList(qual) + stats :+ Select(expr, sel).setSymbol(tree.symbol) + + case Apply(fun, args) if containsAwait => + // we an assume that no await call appears in a by-name argument position, + // this has already been checked. + + val funStats :+ simpleFun = inline.transformToList(fun) + val argLists = args map inline.transformToList + val allArgStats = argLists flatMap (_.init) + val simpleArgs = argLists map (_.last) + funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) + + case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec. + inline.transformToList(stats :+ expr) + + case ValDef(mods, name, tpt, rhs) if containsAwait => + if (rhs exists isAwait) { + val stats :+ expr = inline.transformToList(rhs) + stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol) + } else List(tree) + case Assign(lhs, rhs) if containsAwait => + val stats :+ expr = inline.transformToList(rhs) + stats :+ Assign(lhs, expr) + + case If(cond, thenp, elsep) if containsAwait => + val stats :+ expr = inline.transformToList(cond) + val thenBlock = inline.transformToBlock(thenp) + val elseBlock = inline.transformToBlock(elsep) + stats :+ + c.typeCheck(If(expr, thenBlock, elseBlock)) + + case Match(scrut, cases) if containsAwait => + val scrutStats :+ scrutExpr = inline.transformToList(scrut) + val caseDefs = cases map { + case CaseDef(pat, guard, body) => + val block = inline.transformToBlock(body) + CaseDef(pat, guard, block) + } + scrutStats :+ c.typeCheck(Match(scrutExpr, caseDefs)) - case ModuleDef(mods, name, impl) => List(tree) + case TypeApply(fun, targs) if containsAwait => + val funStats :+ simpleFun = inline.transformToList(fun) + funStats :+ TypeApply(simpleFun, targs).setSymbol(tree.symbol) - case _ => - c.error(tree.pos, "Internal error while compiling `async` block") - ??? + case _ => + List(tree) + } } } + } |