diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2015-07-23 23:15:37 +1000 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2015-09-22 16:53:33 +1000 |
commit | e3ff0382ae4e015fc69da8335450718951714982 (patch) | |
tree | 3f89dace31be3cd125531c0ba24270aa45100d7e /src/main/scala/scala/async/internal/ExprBuilder.scala | |
parent | 93f207fee780652d08f93e1ea40e018db59fee99 (diff) | |
download | scala-async-e3ff0382ae4e015fc69da8335450718951714982.tar.gz scala-async-e3ff0382ae4e015fc69da8335450718951714982.tar.bz2 scala-async-e3ff0382ae4e015fc69da8335450718951714982.zip |
Enable a compiler plugin to use the async transform after patmat
Currently, the async transformation is performed during the typer
phase, like all other macros.
We have to levy a few artificial restrictions on whern an async
boundary may be: for instance we don't support await within a
pattern guard. A more natural home for the transform would be
after patterns have been translated.
The test case in this commit shows how to use the async transform
from a custom compiler phase after patmat.
The remainder of the commit updates the implementation to handle
the new tree shapes.
For states that correspond to a label definition, we use `-symbol.id`
as the state ID. This made it easier to emit the forward jumps to when
processing the label application before we had seen the label
definition.
I've also made the transformation more efficient in the way it checks
whether a given tree encloses an `await` call: we traverse the input
tree at the start of the macro, and decorate it with tree attachments
containig the answer to this question. Even after the ANF and state
machine transforms introduce new layers of synthetic trees, the
`containsAwait` code need only traverse shallowly through those
trees to find a child that has the cached answer from the original
traversal.
I had to special case the ANF transform for expressions that always
lead to a label jump: we avoids trying to push an assignment to a result
variable into `if (cond) jump1() else jump2()`, in trees of the form:
```
% cat sandbox/jump.scala
class Test {
def test = {
(null: Any) match {
case _: String => ""
case _ => ""
}
}
}
% qscalac -Xprint:patmat -Xprint-types sandbox/jump.scala
def test: String = {
case <synthetic> val x1: Any = (null{Null(null)}: Any){Any};
case5(){
if (x1.isInstanceOf{[T0]=> Boolean}[String]{Boolean})
matchEnd4{(x: String)String}(""{String("")}){String}
else
case6{()String}(){String}{String}
}{String};
case6(){
matchEnd4{(x: String)String}(""{String("")}){String}
}{String};
matchEnd4(x: String){
x{String}
}{String}
}{String}
```
Diffstat (limited to 'src/main/scala/scala/async/internal/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/ExprBuilder.scala | 80 |
1 files changed, 53 insertions, 27 deletions
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 164e85b..16b9207 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -146,6 +146,8 @@ trait ExprBuilder { private val stats = ListBuffer[Tree]() /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None + private var nextJumpSymbol: Symbol = NoSymbol + def effectiveNextState(nextState: Int) = nextJumpState.orElse(if (nextJumpSymbol == NoSymbol) None else Some(stateIdForLabel(nextJumpSymbol))).getOrElse(nextState) def +=(stat: Tree): this.type = { stat match { @@ -155,11 +157,16 @@ trait ExprBuilder { } def addStat() = stats += stat stat match { - case Apply(fun, Nil) => + case Apply(fun, args) if isLabel(fun.symbol) => // labelDefStates belongs to the current ExprBuilder labelDefStates get fun.symbol match { - case opt @ Some(nextState) => nextJumpState = opt // re-use object - case None => addStat() + case opt@Some(nextState) => + // A backward jump + nextJumpState = opt // re-use object + nextJumpSymbol = fun.symbol + case None => + // We haven't the corresponding LabelDef, this is a forward jump + nextJumpSymbol = fun.symbol } case _ => addStat() } @@ -169,13 +176,11 @@ trait ExprBuilder { def resultWithAwait(awaitable: Awaitable, onCompleteState: Int, nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup) + new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState(nextState), awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup) } def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { @@ -243,9 +248,17 @@ trait ExprBuilder { } import stateAssigner.nextState + def directlyAdjacentLabelDefs(t: Tree): List[Tree] = { + def isPatternCaseLabelDef(t: Tree) = t match { + case LabelDef(name, _, _) => name.toString.startsWith("case") + case _ => false + } + val (before, _ :: after) = (stats :+ expr).span(_ ne t) + before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + } // populate asyncStates - for (stat <- stats) stat match { + for (stat <- (stats :+ expr)) stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val onCompleteState = nextState() @@ -255,7 +268,7 @@ trait ExprBuilder { currState = afterAwaitState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) => + case If(cond, thenp, elsep) if containsAwait(stat) || containsForiegnLabelJump(stat) => checkForUnsupportedAwait(cond) val thenStartState = nextState() @@ -275,7 +288,7 @@ trait ExprBuilder { currState = afterIfState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case Match(scrutinee, cases) if stat exists isAwait => + case Match(scrutinee, cases) if containsAwait(stat) => checkForUnsupportedAwait(scrutinee) val caseStates = cases.map(_ => nextState()) @@ -293,24 +306,21 @@ trait ExprBuilder { currState = afterMatchState stateBuilder = new AsyncStateBuilder(currState, symLookup) + case ld @ LabelDef(name, params, rhs) + if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) => - case ld @ LabelDef(name, params, rhs) if rhs exists isAwait => - val startLabelState = nextState() + val startLabelState = stateIdForLabel(ld.symbol) val afterLabelState = nextState() asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) asyncStates ++= builder.asyncStates - currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } - // complete last state builder (representing the expressions after the last await) - stateBuilder += expr val lastState = stateBuilder.resultSimple(endState) asyncStates += lastState } @@ -383,18 +393,26 @@ trait ExprBuilder { * } * } */ - private def resumeFunTree[T: WeakTypeTag]: Tree = - Try( - Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ), - List( - CaseDef( - Bind(name.t, Ident(nme.WILDCARD)), - Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), { - val t = c.Expr[Throwable](Ident(name.t)) - val complete = futureSystemOps.completeProm[T]( + private def resumeFunTree[T: WeakTypeTag]: Tree = { + val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T])) + Try( + body, + List( + CaseDef( + Bind(name.t, Typed(Ident(nme.WILDCARD), Ident(defn.ThrowableClass))), + EmptyTree, { + val then = { + val t = c.Expr[Throwable](Ident(name.t)) + val complete = futureSystemOps.completeProm[T]( c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree - Block(toList(complete), Return(literalUnit)) - })), EmptyTree) + Block(toList(complete), Return(literalUnit)) + } + If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), then, Throw(Ident(name.t))) + then + })), EmptyTree) + + //body + } def forever(t: Tree): Tree = { val labelName = name.fresh("while$") @@ -435,6 +453,14 @@ trait ExprBuilder { private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = mkHandlerCase(num, adaptToUnit(rhs)) + // We use the convention that the state machine's ID for a state corresponding to + // a labeldef will a negative number be based on the symbol ID. This allows us + // to translate a forward jump to the label as a state transition to a known state + // ID, even though the state machine transform hasn't yet processed the target label + // def. Negative numbers are used so as as not to clash with regular state IDs, which + // are allocated in ascending order from 0. + private def stateIdForLabel(sym: Symbol): Int = -symId(sym) + private def tpeOf(t: Tree): Type = t match { case _ if t.tpe != null => t.tpe case Try(body, Nil, _) => tpeOf(body) |