aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/AnfTransform.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/AnfTransform.scala')
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala150
1 files changed, 88 insertions, 62 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index e1d7cd5..24f37e7 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -1,41 +1,49 @@
+
package scala.async
import scala.reflect.macros.Context
class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) {
+
import c.universe._
- import AsyncUtils._
object inline {
def transformToList(tree: Tree): List[Tree] = {
val stats :+ expr = anf.transformToList(tree)
expr match {
-
- case Apply(fun, args) if fun.toString.startsWith("scala.async.Async.await") =>
- val liftedName = c.fresh("await$")
- stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName)
+ case Apply(fun, args) if isAwait(fun) =>
+ val valDef = defineVal("await", expr)
+ stats :+ valDef :+ Ident(valDef.name)
case If(cond, thenp, elsep) =>
// if type of if-else is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
stats :+ expr :+ Literal(Constant(()))
+ } else {
+ val varDef = defineVar("ifres", expr.tpe)
+ def branchWithAssign(orig: Tree) = orig match {
+ case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr))
+ case _ => Assign(Ident(varDef.name), orig)
+ }
+ val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep))
+ stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
+ }
+
+ case Match(scrut, cases) =>
+ // if type of match is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ Literal(Constant(()))
}
else {
- val liftedName = c.fresh("ifres$")
- val varDef =
- ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe))
- val thenWithAssign = thenp match {
- case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr))
- case _ => Assign(Ident(liftedName), thenp)
- }
- val elseWithAssign = elsep match {
- case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr))
- case _ => Assign(Ident(liftedName), elsep)
+ val varDef = defineVar("matchres", expr.tpe)
+ val casesWithAssign = cases map {
+ case CaseDef(pat, guard, Block(caseStats, caseExpr)) => CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))
+ case CaseDef(pat, guard, body) => CaseDef(pat, guard, Assign(Ident(varDef.name), body))
}
- val ifWithAssign =
- If(cond, thenWithAssign, elseWithAssign)
- stats :+ varDef :+ ifWithAssign :+ Ident(liftedName)
+ val matchWithAssign = Match(scrut, casesWithAssign)
+ stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
}
case _ =>
stats :+ expr
@@ -44,58 +52,76 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) {
def transformToList(trees: List[Tree]): List[Tree] = trees match {
case fst :: rest => transformToList(fst) ++ transformToList(rest)
- case Nil => Nil
+ case Nil => Nil
}
- }
-
- object anf {
- def transformToList(tree: Tree): List[Tree] = tree match {
- case Select(qual, sel) =>
- val stats :+ expr = inline.transformToList(qual)
- stats :+ Select(expr, sel)
- case Apply(fun, args) =>
- val funStats :+ simpleFun = inline.transformToList(fun)
- val argLists = args map inline.transformToList
- val allArgStats = argLists flatMap (_.init)
- val simpleArgs = argLists map (_.last)
- funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs)
-
- case Block(stats, expr) =>
- inline.transformToList(stats) ++ inline.transformToList(expr)
-
- case ValDef(mods, name, tpt, rhs) =>
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)
-
- case Assign(name, rhs) =>
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ Assign(name, expr)
-
- case If(cond, thenp, elsep) =>
- val stats :+ expr = inline.transformToList(cond)
- val thenStats :+ thenExpr = inline.transformToList(thenp)
- val elseStats :+ elseExpr = inline.transformToList(elsep)
- stats :+
- c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr)))
+ def transformToBlock(tree: Tree): Block = transformToList(tree) match {
+ case stats :+ expr => Block(stats, expr)
+ }
- //TODO
- case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree)
+ def liftedName(prefix: String) = c.fresh(prefix + "$")
- case TypeApply(fun, targs) =>
- val funStats :+ simpleFun = inline.transformToList(fun)
- funStats :+ TypeApply(simpleFun, targs)
+ private def defineVar(prefix: String, tp: Type): ValDef =
+ ValDef(Modifiers(Flag.MUTABLE), liftedName(prefix), TypeTree(tp), defaultValue(tp))
- //TODO
- case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree)
+ private def defineVal(prefix: String, lhs: Tree): ValDef =
+ ValDef(NoMods, liftedName(prefix), TypeTree(), lhs)
+ }
- case ClassDef(mods, name, tparams, impl) => List(tree)
+ object anf {
+ def transformToList(tree: Tree): List[Tree] = {
+ def containsAwait = tree exists isAwait
+ tree match {
+ case Select(qual, sel) if containsAwait =>
+ val stats :+ expr = inline.transformToList(qual)
+ stats :+ Select(expr, sel).setSymbol(tree.symbol)
+
+ case Apply(fun, args) if containsAwait =>
+ // we an assume that no await call appears in a by-name argument position,
+ // this has already been checked.
+
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ val argLists = args map inline.transformToList
+ val allArgStats = argLists flatMap (_.init)
+ val simpleArgs = argLists map (_.last)
+ funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)
+
+ case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec.
+ inline.transformToList(stats :+ expr)
+
+ case ValDef(mods, name, tpt, rhs) if containsAwait =>
+ if (rhs exists isAwait) {
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)
+ } else List(tree)
+ case Assign(lhs, rhs) if containsAwait =>
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ Assign(lhs, expr)
+
+ case If(cond, thenp, elsep) if containsAwait =>
+ val stats :+ expr = inline.transformToList(cond)
+ val thenBlock = inline.transformToBlock(thenp)
+ val elseBlock = inline.transformToBlock(elsep)
+ stats :+
+ c.typeCheck(If(expr, thenBlock, elseBlock))
+
+ case Match(scrut, cases) if containsAwait =>
+ val scrutStats :+ scrutExpr = inline.transformToList(scrut)
+ val caseDefs = cases map {
+ case CaseDef(pat, guard, body) =>
+ val block = inline.transformToBlock(body)
+ CaseDef(pat, guard, block)
+ }
+ scrutStats :+ c.typeCheck(Match(scrutExpr, caseDefs))
- case ModuleDef(mods, name, impl) => List(tree)
+ case TypeApply(fun, targs) if containsAwait =>
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ funStats :+ TypeApply(simpleFun, targs).setSymbol(tree.symbol)
- case _ =>
- c.error(tree.pos, "Internal error while compiling `async` block")
- ???
+ case _ =>
+ List(tree)
+ }
}
}
+
}