From 4fc54635380eca4beb13678f24ba0a8e4d592bec Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 22 Nov 2013 11:17:22 +0100 Subject: Fix crashers in do/while and while(await(..)) The new tree shapes handled for do/while look like: // type checked async({ val b = false; doWhile$1(){ await(()); if (b) doWhile$1() else () }; () }) We had to change ExprBuilder to create states for the if/else that concludes the doWhile body, and also loosen the assertion that the label jump must be the last thing we see. We also have to look for more than just `containsAwait` when deciding whether an `If` needs to be transformed into states; it might also contain a jump to the enclosing label that is on the other side of an `await`, and hence needs to be a state transition instead. --- .../scala/scala/async/internal/AnfTransform.scala | 9 ++++- .../scala/scala/async/internal/ExprBuilder.scala | 8 +++-- .../scala/async/internal/TransformUtils.scala | 13 +++++++ src/test/scala/scala/async/TreeInterrogation.scala | 37 ++++---------------- .../scala/scala/async/run/ifelse0/WhileSpec.scala | 40 ++++++++++++++++++++++ 5 files changed, 73 insertions(+), 34 deletions(-) diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala index f19a87a..8518cf5 100644 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -162,7 +162,14 @@ private[async] trait AnfTransform { def _transformToList(tree: Tree): List[Tree] = trace(tree) { val containsAwait = tree exists isAwait if (!containsAwait) { - List(tree) + tree match { + case Block(stats, expr) => + // avoids nested block in `while(await(false)) ...`. + // TODO I think `containsAwait` really should return true if the code contains a label jump to an enclosing + // while/doWhile and there is an await *anywhere* inside that construct. + stats :+ expr + case _ => List(tree) + } } else tree match { case Select(qual, sel) => val stats :+ expr = linearize.transformToList(qual) diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 85e0953..b0cd914 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -127,7 +127,11 @@ trait ExprBuilder { private var nextJumpState: Option[Int] = None def +=(stat: Tree): this.type = { - assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") + stat match { + case Literal(Constant(())) => // This case occurs in do/while + case _ => + assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") + } def addStat() = stats += stat stat match { case Apply(fun, Nil) => @@ -228,7 +232,7 @@ trait ExprBuilder { currState = afterAwaitState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case If(cond, thenp, elsep) if stat exists isAwait => + case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) => checkForUnsupportedAwait(cond) val thenStartState = nextState() diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 71fddaa..e382c62 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -96,6 +96,19 @@ private[async] trait TransformUtils { treeInfo.isExprSafeToInline(tree) } + // `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops; + // we must break that `If` into states so that it convert the label jump into a state machine + // transition + final def containsForiegnLabelJump(t: Tree): Boolean = { + val labelDefs = t.collect { + case ld: LabelDef => ld.symbol + }.toSet + t.exists { + case rt: RefTree => !(labelDefs contains rt.symbol) + case _ => false + } + } + /** Map a list of arguments to: * - A list of argument Trees * - A list of auxillary results. diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index c8fe2d6..b42726b 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -66,43 +66,18 @@ object TreeInterrogation extends App { withDebug { val cm = reflect.runtime.currentMirror val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid") - import scala.async.internal.AsyncTestLV._ + import scala.async.internal.AsyncId._ val tree = tb.parse( """ - | import scala.async.internal.AsyncTestLV._ - | import scala.async.internal.AsyncTestLV - | - | case class MCell[T](var v: T) - | val f = async { MCell(1) } - | - | def m1(x: MCell[Int], y: Int): Int = - | async { x.v + y } - | case class Cell[T](v: T) - | + | import scala.async.internal.AsyncId._ | async { - | // state #1 - | val a: MCell[Int] = await(f) // await$13$1 - | // state #2 - | var y = MCell(0) - | - | while (a.v < 10) { - | // state #4 - | a.v = a.v + 1 - | y = MCell(await(a).v + 1) // await$14$1 - | // state #7 + | var b = true + | while(await(b)) { + | b = false | } - | - | // state #3 - | assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1")) - | - | val b = await(m1(a, y.v)) // await$15$1 - | // state #8 - | assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10)))) - | assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11)))) - | b + | await(b) | } | - | | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala index 4b3c2aa..76e6b1e 100644 --- a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala +++ b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala @@ -76,4 +76,44 @@ class WhileSpec { } result mustBe () } + + @Test def doWhile() { + import AsyncId._ + val result = async { + var b = 0 + var x = "" + await(do { + x += "1" + x += await("2") + x += "3" + b += await(1) + } while (b < 2)) + await(x) + } + result mustBe "123123" + } + + @Test def whileAwaitCondition() { + import AsyncId._ + val result = async { + var b = true + while(await(b)) { + b = false + } + await(b) + } + result mustBe false + } + + @Test def doWhileAwaitCondition() { + import AsyncId._ + val result = async { + var b = true + do { + b = false + } while(await(b)) + b + } + result mustBe false + } } -- cgit v1.2.3