diff options
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 56 |
1 files changed, 41 insertions, 15 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 8ea7ecf..60430c4 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -54,8 +54,21 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = mkHandlerCase(num, Block(rhs: _*)) - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = { + val rhs1 = new Transformer { + override def transform(tree: Tree): Tree = tree match { + case Apply(Ident(name), args) => + val jumpTarget = labelDefStates get name // TODO attempt to be symful + jumpTarget match { + case Some(state) => Return(Block(mkStateTree(state), mkResumeApply)) + case None => super.transform(tree) + } + case _ => super.transform(tree) + } + }.transform(rhs) + + CaseDef(c.literal(num).tree, EmptyTree, rhs1) + } class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = stats match { @@ -209,7 +222,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { // 1. build changed if-else tree // 2. insert that tree at the end of the current state - val cond = resetDuplicate(condTree) + val cond = resetDuplicate(renamer.transform(condTree)) this += If(cond, Block(mkStateTree(thenState), mkResumeApply), Block(mkStateTree(elseState), mkResumeApply)) @@ -240,6 +253,13 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C } } + def resultWithLabel(startLabelState: Int): AsyncState = { + this += Block(mkStateTree(startLabelState), mkResumeApply) + new AsyncStateWithoutAwait(stats.toList, state) { + override val varDefs = self.varDefs.toList + } + } + override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" @@ -248,6 +268,8 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C val stateAssigner = new StateAssigner + val labelDefStates = collection.mutable.Map[Name, Int]() + /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). * @@ -262,7 +284,6 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C val asyncStates = ListBuffer[builder.AsyncState]() private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) - // current state builder private var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ @@ -272,10 +293,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = { - val (branchStats, branchExpr) = tree match { - case Block(s, e) => (s, e) - case _ => (List(tree), c.literalUnit.tree) - } + val (branchStats, branchExpr) = statsAndExpr(tree) new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename) } @@ -324,18 +342,26 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { - val (casStats, casExpr) = cas match { - case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) - } + val (casStats, casExpr) = statsAndExpr(cas.body) val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename) asyncStates ++= builder.asyncStates } currState = afterMatchState - stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - - case _ => + stateBuilder = new AsyncStateBuilder(currState, toRename) + + case ld@LabelDef(name, params, rhs) if rhs exists isAwait => + val startLabelState = stateAssigner.nextState() + val afterLabelState = stateAssigner.nextState() + asyncStates += stateBuilder.resultWithLabel(startLabelState) + val (stats, expr) = statsAndExpr(rhs) + labelDefStates(ld.symbol.name) = startLabelState + val builder = new AsyncBlockBuilder(stats, expr, startLabelState, afterLabelState, toRename) + asyncStates ++= builder.asyncStates + + currState = afterLabelState + stateBuilder = new AsyncStateBuilder(currState, toRename) + case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } |