From 11fe7a0cb359d59752d7ce442a9ccaad31ad157d Mon Sep 17 00:00:00 2001 From: phaller Date: Mon, 5 Nov 2012 17:29:48 +0100 Subject: Support await inside match expressions --- src/main/scala/scala/async/ExprBuilder.scala | 102 ++++++++++++++++----- src/test/scala/scala/async/run/match0/Match0.scala | 73 +++++++++++++++ 2 files changed, 152 insertions(+), 23 deletions(-) create mode 100644 src/test/scala/scala/async/run/match0/Match0.scala (limited to 'src') diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 655c26f..deb8d28 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -131,7 +131,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { s"AsyncState #$state, next = $nextState" } - class AsyncStateWithIf(stats: List[c.Tree], state: Int) + class AsyncStateWithoutAwait(stats: List[c.Tree], state: Int) extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches @@ -387,17 +387,50 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { this += If(cond, Block(mkStateTree(thenState), Apply(Ident("resume"), List())), Block(mkStateTree(elseState), Apply(Ident("resume"), List()))) - new AsyncStateWithIf(stats.toList, state) { + new AsyncStateWithoutAwait(stats.toList, state) { override val varDefs = self.varDefs.toList } } - + + /** + * Build `AsyncState` ending with a match expression. + * + * The cases of the match simply resume at the state of their corresponding right-hand side. + * + * @param scrutTree tree of the scrutinee + * @param cases list of case definitions + * @param stateFirstCase state of the right-hand side of the first case + * @param perCaseBudget maximum number of states per case + * @return an `AsyncState` representing the match expression + */ + def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], stateFirstCase: Int, perCasebudget: Int): AsyncState = { + // 1. build list of changed cases + val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { + case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), Apply(Ident("resume"), List()))) + } + // 2. insert changed match tree at the end of the current state + this += Match(c.resetAllAttrs(scrutTree.duplicate), newCases) + new AsyncStateWithoutAwait(stats.toList, state) { + override val varDefs = self.varDefs.toList + } + } + override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" } } + /** + * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). + * + * @param stats a list of expressions + * @param expr the last expression of the block + * @param startState the start state + * @param endState the state to continue with + * @param budget the maximum number of states in this block + * @param toRename a `Map` for renaming the given key symbols to the mangled value names + */ class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int, private var toRename: Map[c.Symbol, c.Name]) { val asyncStates = ListBuffer[builder.AsyncState]() @@ -413,7 +446,15 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { case Apply(fun, _) if fun.symbol == awaitMethod => true case _ => false }) throw new FallbackToCpsException - + + def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = { + val (branchStats, branchExpr) = tree match { + case Block(s, e) => (s, e) + case _ => (List(tree), Literal(Constant(()))) + } + new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap) + } + // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern @@ -450,29 +491,44 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget) - - val thenBuilder = thenp match { - case Block(thenStats, thenExpr) => - new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename) - case _ => - new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget, toRename) + + List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) => + val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename) + asyncStates ++= builder.asyncStates + toRename ++= builder.toRename } - asyncStates ++= thenBuilder.asyncStates - toRename ++= thenBuilder.toRename - - val elseBuilder = elsep match { - case Block(elseStats, elseExpr) => - new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename) - case _ => - new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget, toRename) - } - asyncStates ++= elseBuilder.asyncStates - toRename ++= elseBuilder.toRename - + // create new state builder for state `currState + ifBudget` currState = currState + ifBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - + + case Match(scrutinee, cases) => + vprintln("transforming match expr: " + stat) + checkForUnsupportedAwait(scrutinee) + + val matchBudget: Int = remainingBudget / 2 + remainingBudget -= matchBudget //TODO test if budget > 0 + // state that we continue with after match: currState + matchBudget + + val perCaseBudget: Int = matchBudget / cases.size + asyncStates += + // the two Int arguments are the start state of the first case and the per-case state budget, respectively + stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget) + + for ((cas, num) <- cases.zipWithIndex) { + val (casStats, casExpr) = cas match { + case CaseDef(_, _, Block(s, e)) => (s, e) + case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(()))) + } + val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename) + asyncStates ++= builder.asyncStates + toRename ++= builder.toRename + } + + // create new state builder for state `currState + matchBudget` + currState = currState + matchBudget + stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + case _ => checkForUnsupportedAwait(stat) stateBuilder += stat diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala new file mode 100644 index 0000000..3c7e297 --- /dev/null +++ b/src/test/scala/scala/async/run/match0/Match0.scala @@ -0,0 +1,73 @@ +/** + * Copyright (C) 2012 Typesafe Inc. + */ + +package scala.async +package run +package match0 + +import language.{reflectiveCalls, postfixOps} +import scala.concurrent.{Future, ExecutionContext, future, Await} +import scala.concurrent.duration._ +import scala.async.Async.{async, await} +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test + + +class TestMatchClass { + + import ExecutionContext.Implicits.global + + def m1(x: Int): Future[Int] = future { + Thread.sleep(1000) + x + 2 + } + + def m2(y: Int): Future[Int] = async { + val f = m1(y) + var z = 0 + y match { + case 10 => + val x1 = await(f) + z = x1 + 2 + case 20 => + val x2 = await(f) + z = x2 - 2 + } + z + } + + def m3(y: Int): Future[Int] = async { + val f = m1(y) + var z = 0 + y match { + case 0 => + val x2 = await(f) + z = x2 - 2 + case 1 => + val x1 = await(f) + z = x1 + 2 + } + z + } +} + + +@RunWith(classOf[JUnit4]) +class MatchSpec { + + @Test def `support await in a simple match expression`() { + val o = new TestMatchClass + val fut = o.m2(10) // matches first case + val res = Await.result(fut, 2 seconds) + res mustBe (14) + } + + @Test def `support await in a simple match expression 2`() { + val o = new TestMatchClass + val fut = o.m3(1) // matches second case + val res = Await.result(fut, 2 seconds) + res mustBe (5) + } +} -- cgit v1.2.3