diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2013-07-07 10:48:11 +1000 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2013-07-07 10:48:11 +1000 |
commit | 2d8506a64392cd7192b6831c38798cc9a7c8bfed (patch) | |
tree | 84eafcf1a9a179eeaa97dd1e3595c18351b2b814 /src/main/scala/scala/async/internal/ExprBuilder.scala | |
parent | c60c38ca6098402f7a9cc6d6746b664bb2b1306c (diff) | |
download | scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.tar.gz scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.tar.bz2 scala-async-2d8506a64392cd7192b6831c38798cc9a7c8bfed.zip |
Move implementation details to scala.async.internal._.
If we intend to keep CPS fallback around for any length of time
it should probably move there too.
Diffstat (limited to 'src/main/scala/scala/async/internal/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/ExprBuilder.scala | 388 |
1 files changed, 388 insertions, 0 deletions
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala new file mode 100644 index 0000000..1ce30e6 --- /dev/null +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -0,0 +1,388 @@ +/* + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async.internal + +import scala.reflect.macros.Context +import scala.collection.mutable.ListBuffer +import collection.mutable +import language.existentials +import scala.reflect.api.Universe +import scala.reflect.api +import scala.Some + +trait ExprBuilder { + builder: AsyncMacro => + + import global._ + import defn._ + + val futureSystem: FutureSystem + val futureSystemOps: futureSystem.Ops { val universe: global.type } + + val stateAssigner = new StateAssigner + val labelDefStates = collection.mutable.Map[Symbol, Int]() + + trait AsyncState { + def state: Int + + def mkHandlerCaseForState: CaseDef + + def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None + + def stats: List[Tree] + + final def allStats: List[Tree] = this match { + case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef + case _ => stats + } + + final def body: Tree = stats match { + case stat :: Nil => stat + case init :+ last => Block(init, last) + } + } + + /** A sequence of statements the concludes with a unconditional transition to `nextState` */ + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) + extends AsyncState { + + def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) + + override val toString: String = + s"AsyncState #$state, next = $nextState" + } + + /** A sequence of statements with a conditional transition to the next state, which will represent + * a branch of an `if` or a `match`. + */ + final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { + override def mkHandlerCaseForState: CaseDef = + mkHandlerCase(state, stats) + + override val toString: String = + s"AsyncStateWithoutAwait #$state" + } + + /** A sequence of statements that concludes with an `await` call. The `onComplete` + * handler will unconditionally transition to `nestState`.`` + */ + final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + val awaitable: Awaitable, symLookup: SymLookup) + extends AsyncState { + + override def mkHandlerCaseForState: CaseDef = { + val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), + Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree + mkHandlerCase(state, stats :+ callOnComplete) + } + + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { + val tryGetTree = + Assign( + Ident(awaitable.resultName), + TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + ) + + /* if (tr.isFailure) + * result.complete(tr.asInstanceOf[Try[T]]) + * else { + * <resultName> = tr.get.asInstanceOf[<resultType>] + * <nextState> + * <mkResumeApply> + * } + */ + val ifIsFailureTree = + If(Select(Ident(symLookup.applyTrParam), Try_isFailure), + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[scala.util.Try[T]]( + TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), + List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, + Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) + ) + + Some(mkHandlerCase(state, List(ifIsFailureTree))) + } + + override val toString: String = + s"AsyncStateWithAwait #$state, next = $nextState" + } + + /* + * Builder for a single state of an async method. + */ + final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { + /* Statements preceding an await call. */ + private val stats = ListBuffer[Tree]() + /** The state of the target of a LabelDef application (while loop jump) */ + private var nextJumpState: Option[Int] = None + + def +=(stat: Tree): this.type = { + assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") + def addStat() = stats += stat + stat match { + case Apply(fun, Nil) => + labelDefStates get fun.symbol match { + case Some(nextState) => nextJumpState = Some(nextState) + case None => addStat() + } + case _ => addStat() + } + this + } + + def resultWithAwait(awaitable: Awaitable, + nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) + } + + def resultSimple(nextState: Int): AsyncState = { + val effectiveNextState = nextJumpState.getOrElse(nextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + } + + def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { + def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + this += If(condTree, mkBranch(thenState), mkBranch(elseState)) + new AsyncStateWithoutAwait(stats.toList, state) + } + + /** + * Build `AsyncState` ending with a match expression. + * + * The cases of the match simply resume at the state of their corresponding right-hand side. + * + * @param scrutTree tree of the scrutinee + * @param cases list of case definitions + * @param caseStates starting state of the right-hand side of the each case + * @return an `AsyncState` representing the match expression + */ + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { + // 1. build list of changed cases + val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { + case CaseDef(pat, guard, rhs) => + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) + } + // 2. insert changed match tree at the end of the current state + this += Match(scrutTree, newCases) + new AsyncStateWithoutAwait(stats.toList, state) + } + + def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { + this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) + new AsyncStateWithoutAwait(stats.toList, state) + } + + override def toString: String = { + val statsBeforeAwait = stats.mkString("\n") + s"ASYNC STATE:\n$statsBeforeAwait" + } + } + + /** + * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). + * + * @param stats a list of expressions + * @param expr the last expression of the block + * @param startState the start state + * @param endState the state to continue with + */ + final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, + private val symLookup: SymLookup) { + val asyncStates = ListBuffer[AsyncState]() + + var stateBuilder = new AsyncStateBuilder(startState, symLookup) + var currState = startState + + /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ + def checkForUnsupportedAwait(tree: Tree) = if (tree exists { + case Apply(fun, _) if isAwait(fun) => true + case _ => false + }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException + + def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { + val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) + } + + import stateAssigner.nextState + + // populate asyncStates + for (stat <- stats) stat match { + // the val name = await(..) pattern + case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + val afterAwaitState = nextState() + val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) + asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await + currState = afterAwaitState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case If(cond, thenp, elsep) if stat exists isAwait => + checkForUnsupportedAwait(cond) + + val thenStartState = nextState() + val elseStartState = nextState() + val afterIfState = nextState() + + asyncStates += + // the two Int arguments are the start state of the then branch and the else branch, respectively + stateBuilder.resultWithIf(cond, thenStartState, elseStartState) + + List((thenp, thenStartState), (elsep, elseStartState)) foreach { + case (branchTree, state) => + val builder = nestedBlockBuilder(branchTree, state, afterIfState) + asyncStates ++= builder.asyncStates + } + + currState = afterIfState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case Match(scrutinee, cases) if stat exists isAwait => + checkForUnsupportedAwait(scrutinee) + + val caseStates = cases.map(_ => nextState()) + val afterMatchState = nextState() + + asyncStates += + stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) + + for ((cas, num) <- cases.zipWithIndex) { + val (stats, expr) = statsAndExpr(cas.body) + val stats1 = stats.dropWhile(isSyntheticBindVal) + val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) + asyncStates ++= builder.asyncStates + } + + currState = afterMatchState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + + case ld@LabelDef(name, params, rhs) if rhs exists isAwait => + val startLabelState = nextState() + val afterLabelState = nextState() + asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) + labelDefStates(ld.symbol) = startLabelState + val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) + asyncStates ++= builder.asyncStates + + currState = afterLabelState + stateBuilder = new AsyncStateBuilder(currState, symLookup) + case _ => + checkForUnsupportedAwait(stat) + stateBuilder += stat + } + // complete last state builder (representing the expressions after the last await) + stateBuilder += expr + val lastState = stateBuilder.resultSimple(endState) + asyncStates += lastState + } + + trait AsyncBlock { + def asyncStates: List[AsyncState] + + def onCompleteHandler[T: WeakTypeTag]: Tree + + def resumeFunTree[T]: DefDef + } + + case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { + def stateMachineMember(name: TermName): Symbol = + stateMachineClass.info.member(name) + def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name)) + } + + def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { + val Block(stats, expr) = block + val startState = stateAssigner.nextState() + val endState = Int.MaxValue + + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) + + new AsyncBlock { + def asyncStates = blockBuilder.asyncStates.toList + + def mkCombinedHandlerCases[T]: List[CaseDef] = { + val caseForLastState: CaseDef = { + val lastState = asyncStates.last + val lastStateBody = Expr[T](lastState.body) + val rhs = futureSystemOps.completeProm( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice))) + mkHandlerCase(lastState.state, rhs.tree) + } + asyncStates.toList match { + case s :: Nil => + List(caseForLastState) + case _ => + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState + initCases :+ caseForLastState + } + } + + val initStates = asyncStates.init + + /** + * def resume(): Unit = { + * try { + * state match { + * case 0 => { + * f11 = exprReturningFuture + * f11.onComplete(onCompleteHandler)(context) + * } + * ... + * } + * } catch { + * case NonFatal(t) => result.failure(t) + * } + * } + */ + def resumeFunTree[T]: DefDef = + DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + Try( + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), + List( + CaseDef( + Bind(name.t, Ident(nme.WILDCARD)), + Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), + Block(List({ + val t = Expr[Throwable](Ident(name.t)) + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree + }), literalUnit))), EmptyTree)) + + /** + * // assumes tr: Try[Any] is in scope. + * // + * state match { + * case 0 => { + * x11 = tr.get.asInstanceOf[Double]; + * state = 1; + * resume() + * } + */ + def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) + } + } + + private def isSyntheticBindVal(tree: Tree) = tree match { + case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case _ => false + } + + case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) + + private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil) + + private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = + Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) + + private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = + mkHandlerCase(num, Block(rhs, literalUnit)) + + private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = + CaseDef(Literal(Constant(num)), EmptyTree, rhs) + + private def literalUnit = Literal(Constant(())) +} |