aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala102
-rw-r--r--src/test/scala/scala/async/run/match0/Match0.scala73
2 files changed, 152 insertions, 23 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 655c26f..deb8d28 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -131,7 +131,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
s"AsyncState #$state, next = $nextState"
}
- class AsyncStateWithIf(stats: List[c.Tree], state: Int)
+ class AsyncStateWithoutAwait(stats: List[c.Tree], state: Int)
extends AsyncState(stats, state, 0) {
// nextState unused, since encoded in then and else branches
@@ -387,17 +387,50 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
this += If(cond,
Block(mkStateTree(thenState), Apply(Ident("resume"), List())),
Block(mkStateTree(elseState), Apply(Ident("resume"), List())))
- new AsyncStateWithIf(stats.toList, state) {
+ new AsyncStateWithoutAwait(stats.toList, state) {
override val varDefs = self.varDefs.toList
}
}
-
+
+ /**
+ * Build `AsyncState` ending with a match expression.
+ *
+ * The cases of the match simply resume at the state of their corresponding right-hand side.
+ *
+ * @param scrutTree tree of the scrutinee
+ * @param cases list of case definitions
+ * @param stateFirstCase state of the right-hand side of the first case
+ * @param perCaseBudget maximum number of states per case
+ * @return an `AsyncState` representing the match expression
+ */
+ def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], stateFirstCase: Int, perCasebudget: Int): AsyncState = {
+ // 1. build list of changed cases
+ val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
+ case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), Apply(Ident("resume"), List())))
+ }
+ // 2. insert changed match tree at the end of the current state
+ this += Match(c.resetAllAttrs(scrutTree.duplicate), newCases)
+ new AsyncStateWithoutAwait(stats.toList, state) {
+ override val varDefs = self.varDefs.toList
+ }
+ }
+
override def toString: String = {
val statsBeforeAwait = stats.mkString("\n")
s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
}
}
+ /**
+ * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
+ *
+ * @param stats a list of expressions
+ * @param expr the last expression of the block
+ * @param startState the start state
+ * @param endState the state to continue with
+ * @param budget the maximum number of states in this block
+ * @param toRename a `Map` for renaming the given key symbols to the mangled value names
+ */
class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
budget: Int, private var toRename: Map[c.Symbol, c.Name]) {
val asyncStates = ListBuffer[builder.AsyncState]()
@@ -413,7 +446,15 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
case Apply(fun, _) if fun.symbol == awaitMethod => true
case _ => false
}) throw new FallbackToCpsException
-
+
+ def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = {
+ val (branchStats, branchExpr) = tree match {
+ case Block(s, e) => (s, e)
+ case _ => (List(tree), Literal(Constant(())))
+ }
+ new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap)
+ }
+
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
@@ -450,29 +491,44 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
asyncStates +=
// the two Int arguments are the start state of the then branch and the else branch, respectively
stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget)
-
- val thenBuilder = thenp match {
- case Block(thenStats, thenExpr) =>
- new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename)
- case _ =>
- new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget, toRename)
+
+ List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) =>
+ val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename)
+ asyncStates ++= builder.asyncStates
+ toRename ++= builder.toRename
}
- asyncStates ++= thenBuilder.asyncStates
- toRename ++= thenBuilder.toRename
-
- val elseBuilder = elsep match {
- case Block(elseStats, elseExpr) =>
- new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename)
- case _ =>
- new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget, toRename)
- }
- asyncStates ++= elseBuilder.asyncStates
- toRename ++= elseBuilder.toRename
-
+
// create new state builder for state `currState + ifBudget`
currState = currState + ifBudget
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
-
+
+ case Match(scrutinee, cases) =>
+ vprintln("transforming match expr: " + stat)
+ checkForUnsupportedAwait(scrutinee)
+
+ val matchBudget: Int = remainingBudget / 2
+ remainingBudget -= matchBudget //TODO test if budget > 0
+ // state that we continue with after match: currState + matchBudget
+
+ val perCaseBudget: Int = matchBudget / cases.size
+ asyncStates +=
+ // the two Int arguments are the start state of the first case and the per-case state budget, respectively
+ stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget)
+
+ for ((cas, num) <- cases.zipWithIndex) {
+ val (casStats, casExpr) = cas match {
+ case CaseDef(_, _, Block(s, e)) => (s, e)
+ case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(())))
+ }
+ val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename)
+ asyncStates ++= builder.asyncStates
+ toRename ++= builder.toRename
+ }
+
+ // create new state builder for state `currState + matchBudget`
+ currState = currState + matchBudget
+ stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
+
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala
new file mode 100644
index 0000000..3c7e297
--- /dev/null
+++ b/src/test/scala/scala/async/run/match0/Match0.scala
@@ -0,0 +1,73 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+package scala.async
+package run
+package match0
+
+import language.{reflectiveCalls, postfixOps}
+import scala.concurrent.{Future, ExecutionContext, future, Await}
+import scala.concurrent.duration._
+import scala.async.Async.{async, await}
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+
+class TestMatchClass {
+
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m2(y: Int): Future[Int] = async {
+ val f = m1(y)
+ var z = 0
+ y match {
+ case 10 =>
+ val x1 = await(f)
+ z = x1 + 2
+ case 20 =>
+ val x2 = await(f)
+ z = x2 - 2
+ }
+ z
+ }
+
+ def m3(y: Int): Future[Int] = async {
+ val f = m1(y)
+ var z = 0
+ y match {
+ case 0 =>
+ val x2 = await(f)
+ z = x2 - 2
+ case 1 =>
+ val x1 = await(f)
+ z = x1 + 2
+ }
+ z
+ }
+}
+
+
+@RunWith(classOf[JUnit4])
+class MatchSpec {
+
+ @Test def `support await in a simple match expression`() {
+ val o = new TestMatchClass
+ val fut = o.m2(10) // matches first case
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (14)
+ }
+
+ @Test def `support await in a simple match expression 2`() {
+ val o = new TestMatchClass
+ val fut = o.m3(1) // matches second case
+ val res = Await.result(fut, 2 seconds)
+ res mustBe (5)
+ }
+}