aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/ExprBuilder.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala56
1 files changed, 41 insertions, 15 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 8ea7ecf..60430c4 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -54,8 +54,21 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef =
mkHandlerCase(num, Block(rhs: _*))
- private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef =
- CaseDef(c.literal(num).tree, EmptyTree, rhs)
+ private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = {
+ val rhs1 = new Transformer {
+ override def transform(tree: Tree): Tree = tree match {
+ case Apply(Ident(name), args) =>
+ val jumpTarget = labelDefStates get name // TODO attempt to be symful
+ jumpTarget match {
+ case Some(state) => Return(Block(mkStateTree(state), mkResumeApply))
+ case None => super.transform(tree)
+ }
+ case _ => super.transform(tree)
+ }
+ }.transform(rhs)
+
+ CaseDef(c.literal(num).tree, EmptyTree, rhs1)
+ }
class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) {
val body: c.Tree = stats match {
@@ -209,7 +222,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = {
// 1. build changed if-else tree
// 2. insert that tree at the end of the current state
- val cond = resetDuplicate(condTree)
+ val cond = resetDuplicate(renamer.transform(condTree))
this += If(cond,
Block(mkStateTree(thenState), mkResumeApply),
Block(mkStateTree(elseState), mkResumeApply))
@@ -240,6 +253,13 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
}
}
+ def resultWithLabel(startLabelState: Int): AsyncState = {
+ this += Block(mkStateTree(startLabelState), mkResumeApply)
+ new AsyncStateWithoutAwait(stats.toList, state) {
+ override val varDefs = self.varDefs.toList
+ }
+ }
+
override def toString: String = {
val statsBeforeAwait = stats.mkString("\n")
s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
@@ -248,6 +268,8 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
val stateAssigner = new StateAssigner
+ val labelDefStates = collection.mutable.Map[Name, Int]()
+
/**
* An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
*
@@ -262,7 +284,6 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
val asyncStates = ListBuffer[builder.AsyncState]()
private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
- // current state builder
private var currState = startState
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
@@ -272,10 +293,7 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
}) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = {
- val (branchStats, branchExpr) = tree match {
- case Block(s, e) => (s, e)
- case _ => (List(tree), c.literalUnit.tree)
- }
+ val (branchStats, branchExpr) = statsAndExpr(tree)
new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename)
}
@@ -324,18 +342,26 @@ final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C
stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
for ((cas, num) <- cases.zipWithIndex) {
- val (casStats, casExpr) = cas match {
- case CaseDef(_, _, Block(s, e)) => (s, e)
- case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree)
- }
+ val (casStats, casExpr) = statsAndExpr(cas.body)
val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename)
asyncStates ++= builder.asyncStates
}
currState = afterMatchState
- stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
-
- case _ =>
+ stateBuilder = new AsyncStateBuilder(currState, toRename)
+
+ case ld@LabelDef(name, params, rhs) if rhs exists isAwait =>
+ val startLabelState = stateAssigner.nextState()
+ val afterLabelState = stateAssigner.nextState()
+ asyncStates += stateBuilder.resultWithLabel(startLabelState)
+ val (stats, expr) = statsAndExpr(rhs)
+ labelDefStates(ld.symbol.name) = startLabelState
+ val builder = new AsyncBlockBuilder(stats, expr, startLabelState, afterLabelState, toRename)
+ asyncStates ++= builder.asyncStates
+
+ currState = afterLabelState
+ stateBuilder = new AsyncStateBuilder(currState, toRename)
+ case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
}