diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-11-05 13:58:48 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-11-05 16:32:01 +0100 |
commit | 3f36c1ea4b95ba046fa378ade19ca368e6e5c21b (patch) | |
tree | 7c292dd5483b87169f1c9c66ea449cb8d598ac63 /src/main | |
parent | 610e649174ba7fa699ba076aa5996af3f6de9519 (diff) | |
download | scala-async-3f36c1ea4b95ba046fa378ade19ca368e6e5c21b.tar.gz scala-async-3f36c1ea4b95ba046fa378ade19ca368e6e5c21b.tar.bz2 scala-async-3f36c1ea4b95ba046fa378ade19ca368e6e5c21b.zip |
A minimal SBT build.
Doesn't execute tests yet.
Diffstat (limited to 'src/main')
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 135 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncUtils.scala | 36 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 519 |
3 files changed, 690 insertions, 0 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala new file mode 100644 index 0000000..2c81bc3 --- /dev/null +++ b/src/main/scala/scala/async/Async.scala @@ -0,0 +1,135 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async + +import language.experimental.macros + +import scala.reflect.macros.Context +import scala.collection.mutable.ListBuffer +import scala.concurrent.{ Future, Promise, ExecutionContext, future } +import ExecutionContext.Implicits.global +import scala.util.control.NonFatal +import scala.util.continuations.{ shift, reset, cpsParam } + +/* Extending `ControlThrowable`, by default, also avoids filling in the stack trace. */ +class FallbackToCpsException extends scala.util.control.ControlThrowable + +/* + * @author Philipp Haller + */ +object Async extends AsyncUtils { + + def async[T](body: T): Future[T] = macro asyncImpl[T] + + def await[T](awaitable: Future[T]): T = ??? + + /* Fall back for `await` when it is called at an unsupported position. + */ + def awaitCps[T, U](awaitable: Future[T], p: Promise[U]): T @cpsParam[U, Unit] = + shift { + (k: (T => U)) => + awaitable onComplete { + case tr => p.success(k(tr.get)) + } + } + + def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = { + import c.universe._ + import Flag._ + + val builder = new ExprBuilder[c.type](c) + val awaitMethod = awaitSym(c) + + try { + body.tree match { + case Block(stats, expr) => + val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map()) + + vprintln(s"states of current method:") + asyncBlockBuilder.asyncStates foreach vprintln + + val handlerExpr = asyncBlockBuilder.mkCombinedHandlerExpr() + + vprintln(s"GENERATED handler expr:") + vprintln(handlerExpr) + + val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { + val tree = Apply(Select(Ident("result"), newTermName("success")), + List(asyncBlockBuilder.asyncStates.last.body)) + builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree)) + } + + vprintln("GENERATED handler for last state:") + vprintln(handlerForLastState) + + val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList + + /* + def resume(): Unit = { + try { + (handlerExpr.splice orElse handlerForLastState.splice)(state) + } catch { + case NonFatal(t) => result.failure(t) + } + } + */ + val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), Ident(definitions.UnitClass), + Try(Apply(Select( + Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)), + newTermName("apply")), List(Ident(newTermName("state")))), + List( + CaseDef( + Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), + EmptyTree, + Block(List( + Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))), + Literal(Constant(()))))), EmptyTree)) + + reify { + val result = Promise[T]() + var state = 0 + future { + c.Expr(Block( + localVarTrees :+ resumeFunTree, + Apply(Ident(newTermName("resume")), List()))).splice + } + result.future + } + + case _ => + // issue error message + reify { + sys.error("expression not supported by async") + } + } + } catch { + case _: FallbackToCpsException => + // replace `await` invocations with `awaitCps` invocations + val awaitReplacer = new Transformer { + val awaitCpsMethod = awaitCpsSym(c) + override def transform(tree: Tree): Tree = tree match { + case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitMethod => + val typeApp = treeCopy.TypeApply(fun, Ident(awaitCpsMethod), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe))) + treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(newTermName("p"))) + + case _ => + super.transform(tree) + } + } + + val newBody = awaitReplacer.transform(body.tree) + + reify { + val p = Promise[T]() + future { + reset { + c.Expr(c.resetAllAttrs(newBody.duplicate)).asInstanceOf[c.Expr[T]].splice + } + } + p.future + } + } + } + +} diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala new file mode 100644 index 0000000..19e9d92 --- /dev/null +++ b/src/main/scala/scala/async/AsyncUtils.scala @@ -0,0 +1,36 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async + +import scala.reflect.macros.Context + +/* + * @author Philipp Haller + */ +trait AsyncUtils { + + val verbose = false + + protected def vprintln(s: Any): Unit = if (verbose) + println("[async] "+s) + + /* Symbol of the `Async.await` method in context `c`. + */ + protected def awaitSym(c: Context): c.universe.Symbol = { + val asyncMod = c.mirror.staticModule("scala.async.Async") + val tpe = asyncMod.moduleClass.asType.toType + tpe.member(c.universe.newTermName("await")) + } + + protected def awaitCpsSym(c: Context): c.universe.Symbol = { + val asyncMod = c.mirror.staticModule("scala.async.Async") + val tpe = asyncMod.moduleClass.asType.toType + tpe.member(c.universe.newTermName("awaitCps")) + } + + private var cnt = 0 + protected[async] def freshString(prefix: String): String = + prefix + "$async$" + { cnt += 1; cnt } + +} diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala new file mode 100644 index 0000000..b7d6446 --- /dev/null +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -0,0 +1,519 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async + +import scala.reflect.macros.Context +import scala.collection.mutable.{ ListBuffer, Builder } + +/* + * @author Philipp Haller + */ +class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { + builder => + + import c.universe._ + import Flag._ + + private val awaitMethod = awaitSym(c) + + /* Make a partial function literal handling case #num: + * + * { + * case any if any == num => rhs + * } + */ + def mkHandler(num: Int, rhs: c.Expr[Unit]): c.Expr[PartialFunction[Int, Unit]] = { +/* + val numLiteral = c.Expr[Int](Literal(Constant(num))) + + reify(new PartialFunction[Int, Unit] { + def isDefinedAt(`x$1`: Int) = + `x$1` == numLiteral.splice + def apply(`x$1`: Int) = `x$1` match { + case any: Int if any == numLiteral.splice => + rhs.splice + } + }) +*/ + val rhsTree = c.resetAllAttrs(rhs.tree.duplicate) + val handlerTree = mkHandlerTree(num, rhsTree) + c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + } + + def mkIncrStateTree(): c.Tree = + Assign( + Ident(newTermName("state")), + Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1))))) + + def mkStateTree(nextState: Int): c.Tree = + Assign( + Ident(newTermName("state")), + Literal(Constant(nextState))) + + def mkVarDefTree(resultType: c.universe.Type, resultName: c.universe.TermName): c.Tree = { + val rhs = + if (resultType <:< definitions.IntTpe) Literal(Constant(0)) + else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) + else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) + else Literal(Constant(null)) + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) + } + + def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = + CaseDef( + // pattern + Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))), + // guard + Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))), + rhs + ) + + def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = { + val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction")) + val intIdent = Ident(definitions.IntClass) + val unitIdent = Ident(definitions.UnitClass) + + val caseCheck = + Apply(Select(Apply(Select(Ident(newTermName("List")), newTermName("apply")), + cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1")))) + + Block(List( + // anonymous subclass of PartialFunction[Int, Unit] + ClassDef(Modifiers(FINAL), newTypeName("$anon"), List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), + emptyValDef, List( + DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(), + Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))), + + DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), + caseCheck), + + DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), + Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match + ) + )) + )), + Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List()) + ) + } + + def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = + mkHandlerTreeFor(List(mkHandlerCase(num, rhs) -> num)) + + class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { + val body: c.Tree = + if (stats.size == 1) stats.head + else Block(stats: _*) + + val varDefs: List[(c.universe.TermName, c.universe.Type)] = List() + + def mkHandlerCaseForState(): CaseDef = + mkHandlerCase(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + + def mkHandlerTreeForState(): c.Tree = + mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + + def mkHandlerTreeForState(nextState: Int): c.Tree = + mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + + def varDefForResult: Option[c.Tree] = + None + + def allVarDefs: List[c.Tree] = + varDefForResult.toList ++ varDefs.map(p => mkVarDefTree(p._2, p._1)) + + override val toString: String = + s"AsyncState #$state, next = $nextState" + } + + class AsyncStateWithIf(stats: List[c.Tree], state: Int) + extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches + + override def mkHandlerTreeForState(): c.Tree = + mkHandlerTree(state, Block(stats: _*)) + + //TODO mkHandlerTreeForState(nextState: Int) + + override def mkHandlerCaseForState(): CaseDef = + mkHandlerCase(state, Block(stats: _*)) + + override val toString: String = + s"AsyncStateWithIf #$state, next = $nextState" + } + + abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int) + extends AsyncState(stats, state, nextState) { + val awaitable: c.Tree + val resultName: c.universe.TermName + val resultType: c.universe.Type + + override val toString: String = + s"AsyncStateWithAwait #$state, next = $nextState" + + /* Make an `onComplete` invocation: + * + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * resume() + * } + */ + def mkOnCompleteTree: c.Tree = { + val assignTree = + Assign( + Ident(resultName.toString), + Select(Ident("tr"), c.universe.newTermName("get")) + ) + val handlerTree = + Match( + EmptyTree, + List( + CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + Block(assignTree, Apply(Ident("resume"), List())) // rhs of case + ) + ) + ) + Apply( + Select(awaitable, c.universe.newTermName("onComplete")), + List(handlerTree) + ) + } + + /* Make an `onComplete` invocation which increments the state upon resuming: + * + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * state += 1 + * resume() + * } + */ + def mkOnCompleteIncrStateTree: c.Tree = { + val tryGetTree = + Assign( + Ident(resultName.toString), + Select(Ident("tr"), c.universe.newTermName("get")) + ) + val handlerTree = + Match( + EmptyTree, + List( + CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case + ) + ) + ) + Apply( + Select(awaitable, c.universe.newTermName("onComplete")), + List(handlerTree) + ) + } + + /* Make an `onComplete` invocation which sets the state to `nextState` upon resuming: + * + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * state = `nextState` + * resume() + * } + */ + def mkOnCompleteStateTree(nextState: Int): c.Tree = { + val tryGetTree = + Assign( + Ident(resultName.toString), + Select(Ident("tr"), c.universe.newTermName("get")) + ) + val handlerTree = + Match( + EmptyTree, + List( + CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case + ) + ) + ) + Apply( + Select(awaitable, c.universe.newTermName("onComplete")), + List(handlerTree) + ) + } + + /* Make a partial function literal handling case #num: + * + * { + * case any if any == num => + * stats + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * resume() + * } + * } + */ + def mkHandlerForState(num: Int): c.Expr[PartialFunction[Int, Unit]] = { + assert(awaitable != null) + builder.mkHandler(num, c.Expr[Unit](Block((stats :+ mkOnCompleteTree): _*))) + } + + /* Make a partial function literal handling case #num: + * + * { + * case any if any == num => + * stats + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * state += 1 + * resume() + * } + * } + */ + override def mkHandlerTreeForState(): c.Tree = { + assert(awaitable != null) + mkHandlerTree(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + } + + override def mkHandlerTreeForState(nextState: Int): c.Tree = { + assert(awaitable != null) + mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*)) + } + + override def mkHandlerCaseForState(): CaseDef = { + assert(awaitable != null) + mkHandlerCase(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + } + + override def varDefForResult: Option[c.Tree] = { + val rhs = + if (resultType <:< definitions.IntTpe) Literal(Constant(0)) + else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) + else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) + else if (resultType <:< definitions.FloatTpe) Literal(Constant(0.0f)) + else if (resultType <:< definitions.DoubleTpe) Literal(Constant(0.0d)) + else if (resultType <:< definitions.CharTpe) Literal(Constant(0.toChar)) + else if (resultType <:< definitions.ShortTpe) Literal(Constant(0.toShort)) + else if (resultType <:< definitions.ByteTpe) Literal(Constant(0.toByte)) + else Literal(Constant(null)) + Some( + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) + ) + } + } + + /* + * Builder for a single state of an async method. + */ + class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) extends Builder[c.Tree, AsyncState] { + self => + + /* Statements preceding an await call. */ + private val stats = ListBuffer[c.Tree]() + + /* Argument of an await call. */ + var awaitable: c.Tree = null + + /* Result name of an await call. */ + var resultName: c.universe.TermName = null + + /* Result type of an await call. */ + var resultType: c.universe.Type = null + + var nextState: Int = state + 1 + + private val varDefs = ListBuffer[(c.universe.TermName, c.universe.Type)]() + + private val renamer = new Transformer { + override def transform(tree: Tree) = tree match { + case Ident(_) if nameMap.keySet contains tree.symbol => + Ident(nameMap(tree.symbol)) + case _ => + super.transform(tree) + } + } + + def += (stat: c.Tree): this.type = { + stats += c.resetAllAttrs(renamer.transform(stat).duplicate) + this + } + + //TODO do not ignore `mods` + def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { + varDefs += (name -> tpt.tpe) + nameMap ++= extNameMap // update name map + this += Assign(Ident(name), c.resetAllAttrs(renamer.transform(rhs).duplicate)) + this + } + + def result(): AsyncState = + if (awaitable == null) + new AsyncState(stats.toList, state, nextState) { + override val varDefs = self.varDefs.toList + } + else + new AsyncStateWithAwait(stats.toList, state, nextState) { + val awaitable = self.awaitable + val resultName = self.resultName + val resultType = self.resultType + override val varDefs = self.varDefs.toList + } + + def clear(): Unit = { + stats.clear() + awaitable = null + resultName = null + resultType = null + } + + /* Result needs to be created as a var at the beginning of the transformed method body, so that + * it is visible in subsequent states of the state machine. + * + * @param awaitArg the argument of await + * @param awaitResultName the name of the variable that the result of await is assigned to + * @param awaitResultType the type of the result of await + */ + def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { + nameMap ++= extNameMap + awaitable = c.resetAllAttrs(renamer.transform(awaitArg).duplicate) + resultName = awaitResultName + resultType = awaitResultType.tpe + this.nextState = nextState + this + } + + def complete(nextState: Int): this.type = { + this.nextState = nextState + this + } + + def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { + // 1. build changed if-else tree + // 2. insert that tree at the end of the current state + val cond = c.resetAllAttrs(condTree.duplicate) + this += If(cond, + Block(mkStateTree(thenState), Apply(Ident("resume"), List())), + Block(mkStateTree(elseState), Apply(Ident("resume"), List()))) + new AsyncStateWithIf(stats.toList, state) { + override val varDefs = self.varDefs.toList + } + } + + override def toString: String = { + val statsBeforeAwait = stats.mkString("\n") + s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" + } + } + + class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int, private var toRename: Map[c.Symbol, c.Name]) { + val asyncStates = ListBuffer[builder.AsyncState]() + + private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder + private var currState = startState + + private var remainingBudget = budget + + /* Fall back to CPS plug-in if tree contains an `await` call. */ + def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { + case Apply(fun, _) if fun.symbol == awaitMethod => true + case _ => false + }) throw new FallbackToCpsException + + // populate asyncStates + for (stat <- stats) stat match { + // the val name = await(..) pattern + case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => + val newName = newTermName(Async.freshString(name.toString())) + toRename += (stat.symbol -> newName) + + asyncStates += stateBuilder.complete(args(0), newName, tpt, toRename).result // complete with await + if (remainingBudget > 0) + remainingBudget -= 1 + else + assert(false, "too many invocations of `await` in current method") + currState += 1 + stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + + case ValDef(mods, name, tpt, rhs) => + checkForUnsupportedAwait(rhs) + + val newName = newTermName(Async.freshString(name.toString())) + toRename += (stat.symbol -> newName) + // when adding assignment need to take `toRename` into account + stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) + + case If(cond, thenp, elsep) => + checkForUnsupportedAwait(cond) + + val ifBudget: Int = remainingBudget / 2 + remainingBudget -= ifBudget //TODO test if budget > 0 + // state that we continue with after if-else: currState + ifBudget + + val thenBudget: Int = ifBudget / 2 + val elseBudget = ifBudget - thenBudget + + asyncStates += + // the two Int arguments are the start state of the then branch and the else branch, respectively + stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget) + + val thenBuilder = thenp match { + case Block(thenStats, thenExpr) => + new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename) + case _ => + new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget, toRename) + } + asyncStates ++= thenBuilder.asyncStates + toRename ++= thenBuilder.toRename + + val elseBuilder = elsep match { + case Block(elseStats, elseExpr) => + new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename) + case _ => + new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget, toRename) + } + asyncStates ++= elseBuilder.asyncStates + toRename ++= elseBuilder.toRename + + // create new state builder for state `currState + ifBudget` + currState = currState + ifBudget + stateBuilder = new builder.AsyncStateBuilder(currState, toRename) + + case _ => + checkForUnsupportedAwait(stat) + stateBuilder += stat + } + // complete last state builder (representing the expressions after the last await) + stateBuilder += expr + val lastState = stateBuilder.complete(endState).result + asyncStates += lastState + + def mkCombinedHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { + assert(asyncStates.size > 1) + + val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState() + c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + } + + /* Builds the handler expression for a sequence of async states. + */ + def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { + assert(asyncStates.size > 1) + + var handlerExpr = + c.Expr(asyncStates(0).mkHandlerTreeForState()).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + + if (asyncStates.size == 2) + handlerExpr + else { + for (asyncState <- asyncStates.tail.init) { // do not traverse first or last state + val handlerTreeForNextState = asyncState.mkHandlerTreeForState() + val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) + handlerExpr = c.Expr( + Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), + List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + } + handlerExpr + } + } + + } +} |