diff options
Diffstat (limited to 'src/main/scala/scala/async/internal/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/ExprBuilder.scala | 80 |
1 files changed, 53 insertions, 27 deletions
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 164e85b..16b9207 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -146,6 +146,8 @@ trait ExprBuilder { private val stats = ListBuffer[Tree]() /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None + private var nextJumpSymbol: Symbol = NoSymbol + def effectiveNextState(nextState: Int) = nextJumpState.orElse(if (nextJumpSymbol == NoSymbol) None else Some(stateIdForLabel(nextJumpSymbol))).getOrElse(nextState) def +=(stat: Tree): this.type = { stat match { @@ -155,11 +157,16 @@ trait ExprBuilder { } def addStat() = stats += stat stat match { - case Apply(fun, Nil) => + case Apply(fun, args) if isLabel(fun.symbol) => // labelDefStates belongs to the current ExprBuilder labelDefStates get fun.symbol match { - case opt @ Some(nextState) => nextJumpState = opt // re-use object - case None => addStat() + case opt@Some(nextState) => + // A backward jump + nextJumpState = opt // re-use object + nextJumpSymbol = fun.symbol + case None => + // We haven't the corresponding LabelDef, this is a forward jump + nextJumpSymbol = fun.symbol } case _ => addStat() } @@ -169,13 +176,11 @@ trait ExprBuilder { def resultWithAwait(awaitable: Awaitable, onCompleteState: Int, nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup) + new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState(nextState), awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup) } def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { @@ -243,9 +248,17 @@ trait ExprBuilder { } import stateAssigner.nextState + def directlyAdjacentLabelDefs(t: Tree): List[Tree] = { + def isPatternCaseLabelDef(t: Tree) = t match { + case LabelDef(name, _, _) => name.toString.startsWith("case") + case _ => false + } + val (before, _ :: after) = (stats :+ expr).span(_ ne t) + before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + } // populate asyncStates - for (stat <- stats) stat match { + for (stat <- (stats :+ expr)) stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val onCompleteState = nextState() @@ -255,7 +268,7 @@ trait ExprBuilder { currState = afterAwaitState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) => + case If(cond, thenp, elsep) if containsAwait(stat) || containsForiegnLabelJump(stat) => checkForUnsupportedAwait(cond) val thenStartState = nextState() @@ -275,7 +288,7 @@ trait ExprBuilder { currState = afterIfState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case Match(scrutinee, cases) if stat exists isAwait => + case Match(scrutinee, cases) if containsAwait(stat) => checkForUnsupportedAwait(scrutinee) val caseStates = cases.map(_ => nextState()) @@ -293,24 +306,21 @@ trait ExprBuilder { currState = afterMatchState stateBuilder = new AsyncStateBuilder(currState, symLookup) + case ld @ LabelDef(name, params, rhs) + if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) => - case ld @ LabelDef(name, params, rhs) if rhs exists isAwait => - val startLabelState = nextState() + val startLabelState = stateIdForLabel(ld.symbol) val afterLabelState = nextState() asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) asyncStates ++= builder.asyncStates - currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } - // complete last state builder (representing the expressions after the last await) - stateBuilder += expr val lastState = stateBuilder.resultSimple(endState) asyncStates += lastState } @@ -383,18 +393,26 @@ trait ExprBuilder { * } * } */ - private def resumeFunTree[T: WeakTypeTag]: Tree = - Try( - Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ), - List( - CaseDef( - Bind(name.t, Ident(nme.WILDCARD)), - Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), { - val t = c.Expr[Throwable](Ident(name.t)) - val complete = futureSystemOps.completeProm[T]( + private def resumeFunTree[T: WeakTypeTag]: Tree = { + val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T])) + Try( + body, + List( + CaseDef( + Bind(name.t, Typed(Ident(nme.WILDCARD), Ident(defn.ThrowableClass))), + EmptyTree, { + val then = { + val t = c.Expr[Throwable](Ident(name.t)) + val complete = futureSystemOps.completeProm[T]( c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree - Block(toList(complete), Return(literalUnit)) - })), EmptyTree) + Block(toList(complete), Return(literalUnit)) + } + If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), then, Throw(Ident(name.t))) + then + })), EmptyTree) + + //body + } def forever(t: Tree): Tree = { val labelName = name.fresh("while$") @@ -435,6 +453,14 @@ trait ExprBuilder { private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = mkHandlerCase(num, adaptToUnit(rhs)) + // We use the convention that the state machine's ID for a state corresponding to + // a labeldef will a negative number be based on the symbol ID. This allows us + // to translate a forward jump to the label as a state transition to a known state + // ID, even though the state machine transform hasn't yet processed the target label + // def. Negative numbers are used so as as not to clash with regular state IDs, which + // are allocated in ascending order from 0. + private def stateIdForLabel(sym: Symbol): Int = -symId(sym) + private def tpeOf(t: Tree): Type = t match { case _ if t.tpe != null => t.tpe case Try(body, Nil, _) => tpeOf(body) |