From 4855a5ff60a0b4992da141054c074438b456c3fc Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 23 Nov 2012 23:03:59 +0100 Subject: Support await in a while loop. --- src/main/scala/scala/async/AnfTransform.scala | 3 ++ src/main/scala/scala/async/AsyncAnalysis.scala | 2 + src/main/scala/scala/async/ExprBuilder.scala | 56 ++++++++++++++++------ src/main/scala/scala/async/TransformUtils.scala | 6 +++ src/test/scala/scala/async/TreeInterrogation.scala | 29 ++++++----- src/test/scala/scala/async/neg/NakedAwait.scala | 16 ------- .../scala/scala/async/run/ifelse0/WhileSpec.scala | 43 +++++++++++++++++ 7 files changed, 112 insertions(+), 43 deletions(-) create mode 100644 src/test/scala/scala/async/run/ifelse0/WhileSpec.scala diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 4c78b5a..0756baf 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -204,6 +204,9 @@ private[async] final case class AnfTransform[C <: Context](override val c: C) ex } scrutStats :+ c.typeCheck(attachCopy.Match(tree)(scrutExpr, caseDefs)) + case LabelDef(name, params, rhs) if containsAwait => + List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(()))))) + case TypeApply(fun, targs) if containsAwait => val funStats :+ simpleFun = inline.transformToList(fun) funStats :+ attachCopy.TypeApply(tree)(simpleFun, targs).setSymbol(tree.symbol) diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 88a1bb0..9e24130 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -93,6 +93,8 @@ private[async] final case class AsyncAnalysis[C <: Context](override val c: C) e traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) + case LabelDef(name, params, rhs) if rhs exists isAwait => + traverseChunks(rhs :: Nil) case Apply(fun, args) if isAwait(fun) => super.traverse(tree) nextChunk() diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 8ea7ecf..60430c4 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -54,8 +54,21 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = mkHandlerCase(num, Block(rhs: _*)) - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = { + val rhs1 = new Transformer { + override def transform(tree: Tree): Tree = tree match { + case Apply(Ident(name), args) => + val jumpTarget = labelDefStates get name // TODO attempt to be symful + jumpTarget match { + case Some(state) => Return(Block(mkStateTree(state), mkResumeApply)) + case None => super.transform(tree) + } + case _ => super.transform(tree) + } + }.transform(rhs) + + CaseDef(c.literal(num).tree, EmptyTree, rhs1) + } class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = stats match { @@ -209,7 +222,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { // 1. build changed if-else tree // 2. insert that tree at the end of the current state - val cond = resetDuplicate(condTree) + val cond = resetDuplicate(renamer.transform(condTree)) this += If(cond, Block(mkStateTree(thenState), mkResumeApply), Block(mkStateTree(elseState), mkResumeApply)) @@ -240,6 +253,13 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C } } + def resultWithLabel(startLabelState: Int): AsyncState = { + this += Block(mkStateTree(startLabelState), mkResumeApply) + new AsyncStateWithoutAwait(stats.toList, state) { + override val varDefs = self.varDefs.toList + } + } + override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" @@ -248,6 +268,8 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C val stateAssigner = new StateAssigner + val labelDefStates = collection.mutable.Map[Name, Int]() + /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). * @@ -262,7 +284,6 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C val asyncStates = ListBuffer[builder.AsyncState]() private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) - // current state builder private var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ @@ -272,10 +293,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = { - val (branchStats, branchExpr) = tree match { - case Block(s, e) => (s, e) - case _ => (List(tree), c.literalUnit.tree) - } + val (branchStats, branchExpr) = statsAndExpr(tree) new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename) } @@ -324,18 +342,26 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { - val (casStats, casExpr) = cas match { - case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) - } + val (casStats, casExpr) = statsAndExpr(cas.body) val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename) asyncStates ++= builder.asyncStates } currState = afterMatchState - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - - case _ => + stateBuilder = new AsyncStateBuilder(currState, toRename) + + case ld@LabelDef(name, params, rhs) if rhs exists isAwait => + val startLabelState = stateAssigner.nextState() + val afterLabelState = stateAssigner.nextState() + asyncStates += stateBuilder.resultWithLabel(startLabelState) + val (stats, expr) = statsAndExpr(rhs) + labelDefStates(ld.symbol.name) = startLabelState + val builder = new AsyncBlockBuilder(stats, expr, startLabelState, afterLabelState, toRename) + asyncStates ++= builder.asyncStates + + currState = afterLabelState + stateBuilder = new AsyncStateBuilder(currState, toRename) + case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index b79be87..e37f66d 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -78,6 +78,11 @@ class TransformUtils[C <: Context](val c: C) { } } + protected def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { + case Block(stats, expr) => (stats, expr) + case _ => (List(tree), Literal(Constant(()))) + } + private[async] object defn { def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) @@ -157,4 +162,5 @@ class TransformUtils[C <: Context](val c: C) { def ValDef(tree: Tree)(mods: Modifiers, name: TermName, tpt: Tree, rhs: Tree): ValDef = copyAttach(tree, c.universe.ValDef(mods, name, tpt, rhs)) } + } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index cf5948c..9ac0dce 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -36,23 +36,28 @@ class TreeInterrogation { } - // @Test + @Test def sandbox() { val cm = reflect.runtime.currentMirror val tb = mkToolbox("-cp target/scala-2.10/classes") val tree = tb.parse( - """| import _root_.scala.async.AsyncId._ + """ import _root_.scala.async.AsyncId._ | async { - | var x = 0 - | var y = 0 - | while (x <= 2) { - | y = await(x) - | x += 1 - | } - | y + | var xxx: Int = 0 + | var y = 0 + | println("before while") + | while (xxx < 3) { + | println("in while before await") + | y = await(xxx) + | println("in while after await") + | xxx = xxx + 1 + | } + | println("after while") + | y | }""".stripMargin) - val tree1 = tb.typeCheck(tree) - - println(cm.universe.show(tree1)) + //println(tree) + val tree1 = tb.typeCheck(tree.duplicate) + //println(cm.universe.show(tree1)) + //println(tb.eval(tree)) } } diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index d400729..a0c4e4d 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -117,20 +117,4 @@ class NakedAwait { """.stripMargin } } - - @Test - def whileBody() { - expectError("await must not be used in this position") { - """ import _root_.scala.async.AsyncId._ - | async { - | var x = 0 - | var y = 0 - | while (x <= 2) { - | y = await(x) - | x += 1 - | } - | y - | }""".stripMargin - } - } } diff --git a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala new file mode 100644 index 0000000..d08e2c5 --- /dev/null +++ b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala @@ -0,0 +1,43 @@ +package scala.async +package run +package ifelse0 + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test + +@RunWith(classOf[JUnit4]) +class WhileSpec { + + @Test + def whiling1() { + import AsyncId._ + + val result = async { + var xxx: Int = 0 + var y = 0 + while (xxx < 3) { + y = await(xxx) + xxx = xxx + 1 + } + y + } + result mustBe (2) + } + + @Test + def whiling2() { + import AsyncId._ + + val result = async { + var xxx: Int = 0 + var y = 0 + while (false) { + y = await(xxx) + xxx = xxx + 1 + } + y + } + result mustBe (0) + } +} \ No newline at end of file -- cgit v1.2.3