aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-23 23:03:59 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-23 23:03:59 +0100
commit4855a5ff60a0b4992da141054c074438b456c3fc (patch)
treea90ce01f030413c252335836237d4a8b2b8856ce
parentdb5fd4638c0aac51d66244404dad4dd779f184fa (diff)
downloadscala-async-4855a5ff60a0b4992da141054c074438b456c3fc.tar.gz
scala-async-4855a5ff60a0b4992da141054c074438b456c3fc.tar.bz2
scala-async-4855a5ff60a0b4992da141054c074438b456c3fc.zip
Support await in a while loop.
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala3
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala2
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala56
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala6
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala29
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala16
-rw-r--r--src/test/scala/scala/async/run/ifelse0/WhileSpec.scala43
7 files changed, 112 insertions, 43 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index 4c78b5a..0756baf 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -204,6 +204,9 @@ private[async] final case class AnfTransform[C <: Context](override val c: C) ex
}
scrutStats :+ c.typeCheck(attachCopy.Match(tree)(scrutExpr, caseDefs))
+ case LabelDef(name, params, rhs) if containsAwait =>
+ List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))))
+
case TypeApply(fun, targs) if containsAwait =>
val funStats :+ simpleFun = inline.transformToList(fun)
funStats :+ attachCopy.TypeApply(tree)(simpleFun, targs).setSymbol(tree.symbol)
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 88a1bb0..9e24130 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -93,6 +93,8 @@ private[async] final case class AsyncAnalysis[C <: Context](override val c: C) e
traverseChunks(List(cond, thenp, elsep))
case Match(selector, cases) if tree exists isAwait =>
traverseChunks(selector :: cases)
+ case LabelDef(name, params, rhs) if rhs exists isAwait =>
+ traverseChunks(rhs :: Nil)
case Apply(fun, args) if isAwait(fun) =>
super.traverse(tree)
nextChunk()
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 8ea7ecf..60430c4 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -54,8 +54,21 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef =
mkHandlerCase(num, Block(rhs: _*))
- private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef =
- CaseDef(c.literal(num).tree, EmptyTree, rhs)
+ private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = {
+ val rhs1 = new Transformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case Apply(Ident(name), args) =>
+ val jumpTarget = labelDefStates get name // TODO attempt to be symful
+ jumpTarget match {
+ case Some(state) => Return(Block(mkStateTree(state), mkResumeApply))
+ case None => super.transform(tree)
+ }
+ case _ => super.transform(tree)
+ }
+ }.transform(rhs)
+
+ CaseDef(c.literal(num).tree, EmptyTree, rhs1)
+ }
class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) {
val body: c.Tree = stats match {
@@ -209,7 +222,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = {
// 1. build changed if-else tree
// 2. insert that tree at the end of the current state
- val cond = resetDuplicate(condTree)
+ val cond = resetDuplicate(renamer.transform(condTree))
this += If(cond,
Block(mkStateTree(thenState), mkResumeApply),
Block(mkStateTree(elseState), mkResumeApply))
@@ -240,6 +253,13 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
}
}
+ def resultWithLabel(startLabelState: Int): AsyncState = {
+ this += Block(mkStateTree(startLabelState), mkResumeApply)
+ new AsyncStateWithoutAwait(stats.toList, state) {
+ override val varDefs = self.varDefs.toList
+ }
+ }
+
override def toString: String = {
val statsBeforeAwait = stats.mkString("\n")
s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
@@ -248,6 +268,8 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
val stateAssigner = new StateAssigner
+ val labelDefStates = collection.mutable.Map[Name, Int]()
+
/**
* An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
*
@@ -262,7 +284,6 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
val asyncStates = ListBuffer[builder.AsyncState]()
private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
- // current state builder
private var currState = startState
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
@@ -272,10 +293,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
}) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = {
- val (branchStats, branchExpr) = tree match {
- case Block(s, e) => (s, e)
- case _ => (List(tree), c.literalUnit.tree)
- }
+ val (branchStats, branchExpr) = statsAndExpr(tree)
new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename)
}
@@ -324,18 +342,26 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
for ((cas, num) <- cases.zipWithIndex) {
- val (casStats, casExpr) = cas match {
- case CaseDef(_, _, Block(s, e)) => (s, e)
- case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree)
- }
+ val (casStats, casExpr) = statsAndExpr(cas.body)
val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename)
asyncStates ++= builder.asyncStates
}
currState = afterMatchState
- stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
-
- case _ =>
+ stateBuilder = new AsyncStateBuilder(currState, toRename)
+
+ case ld@LabelDef(name, params, rhs) if rhs exists isAwait =>
+ val startLabelState = stateAssigner.nextState()
+ val afterLabelState = stateAssigner.nextState()
+ asyncStates += stateBuilder.resultWithLabel(startLabelState)
+ val (stats, expr) = statsAndExpr(rhs)
+ labelDefStates(ld.symbol.name) = startLabelState
+ val builder = new AsyncBlockBuilder(stats, expr, startLabelState, afterLabelState, toRename)
+ asyncStates ++= builder.asyncStates
+
+ currState = afterLabelState
+ stateBuilder = new AsyncStateBuilder(currState, toRename)
+ case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index b79be87..e37f66d 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -78,6 +78,11 @@ class TransformUtils[C <: Context](val c: C) {
}
}
+ protected def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match {
+ case Block(stats, expr) => (stats, expr)
+ case _ => (List(tree), Literal(Constant(())))
+ }
+
private[async] object defn {
def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
@@ -157,4 +162,5 @@ class TransformUtils[C <: Context](val c: C) {
def ValDef(tree: Tree)(mods: Modifiers, name: TermName, tpt: Tree, rhs: Tree): ValDef =
copyAttach(tree, c.universe.ValDef(mods, name, tpt, rhs))
}
+
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index cf5948c..9ac0dce 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -36,23 +36,28 @@ class TreeInterrogation {
}
- // @Test
+ @Test
def sandbox() {
val cm = reflect.runtime.currentMirror
val tb = mkToolbox("-cp target/scala-2.10/classes")
val tree = tb.parse(
- """| import _root_.scala.async.AsyncId._
+ """ import _root_.scala.async.AsyncId._
| async {
- | var x = 0
- | var y = 0
- | while (x <= 2) {
- | y = await(x)
- | x += 1
- | }
- | y
+ | var xxx: Int = 0
+ | var y = 0
+ | println("before while")
+ | while (xxx < 3) {
+ | println("in while before await")
+ | y = await(xxx)
+ | println("in while after await")
+ | xxx = xxx + 1
+ | }
+ | println("after while")
+ | y
| }""".stripMargin)
- val tree1 = tb.typeCheck(tree)
-
- println(cm.universe.show(tree1))
+ //println(tree)
+ val tree1 = tb.typeCheck(tree.duplicate)
+ //println(cm.universe.show(tree1))
+ //println(tb.eval(tree))
}
}
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index d400729..a0c4e4d 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -117,20 +117,4 @@ class NakedAwait {
""".stripMargin
}
}
-
- @Test
- def whileBody() {
- expectError("await must not be used in this position") {
- """ import _root_.scala.async.AsyncId._
- | async {
- | var x = 0
- | var y = 0
- | while (x <= 2) {
- | y = await(x)
- | x += 1
- | }
- | y
- | }""".stripMargin
- }
- }
}
diff --git a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala
new file mode 100644
index 0000000..d08e2c5
--- /dev/null
+++ b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala
@@ -0,0 +1,43 @@
+package scala.async
+package run
+package ifelse0
+
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+@RunWith(classOf[JUnit4])
+class WhileSpec {
+
+ @Test
+ def whiling1() {
+ import AsyncId._
+
+ val result = async {
+ var xxx: Int = 0
+ var y = 0
+ while (xxx < 3) {
+ y = await(xxx)
+ xxx = xxx + 1
+ }
+ y
+ }
+ result mustBe (2)
+ }
+
+ @Test
+ def whiling2() {
+ import AsyncId._
+
+ val result = async {
+ var xxx: Int = 0
+ var y = 0
+ while (false) {
+ y = await(xxx)
+ xxx = xxx + 1
+ }
+ y
+ }
+ result mustBe (0)
+ }
+} \ No newline at end of file