aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2015-07-23 23:15:37 +1000
committerJason Zaugg <jzaugg@gmail.com>2015-09-22 16:53:33 +1000
commite3ff0382ae4e015fc69da8335450718951714982 (patch)
tree3f89dace31be3cd125531c0ba24270aa45100d7e /src
parent93f207fee780652d08f93e1ea40e018db59fee99 (diff)
downloadscala-async-e3ff0382ae4e015fc69da8335450718951714982.tar.gz
scala-async-e3ff0382ae4e015fc69da8335450718951714982.tar.bz2
scala-async-e3ff0382ae4e015fc69da8335450718951714982.zip
Enable a compiler plugin to use the async transform after patmat
Currently, the async transformation is performed during the typer phase, like all other macros. We have to levy a few artificial restrictions on whern an async boundary may be: for instance we don't support await within a pattern guard. A more natural home for the transform would be after patterns have been translated. The test case in this commit shows how to use the async transform from a custom compiler phase after patmat. The remainder of the commit updates the implementation to handle the new tree shapes. For states that correspond to a label definition, we use `-symbol.id` as the state ID. This made it easier to emit the forward jumps to when processing the label application before we had seen the label definition. I've also made the transformation more efficient in the way it checks whether a given tree encloses an `await` call: we traverse the input tree at the start of the macro, and decorate it with tree attachments containig the answer to this question. Even after the ANF and state machine transforms introduce new layers of synthetic trees, the `containsAwait` code need only traverse shallowly through those trees to find a child that has the cached answer from the original traversal. I had to special case the ANF transform for expressions that always lead to a label jump: we avoids trying to push an assignment to a result variable into `if (cond) jump1() else jump2()`, in trees of the form: ``` % cat sandbox/jump.scala class Test { def test = { (null: Any) match { case _: String => "" case _ => "" } } } % qscalac -Xprint:patmat -Xprint-types sandbox/jump.scala def test: String = { case <synthetic> val x1: Any = (null{Null(null)}: Any){Any}; case5(){ if (x1.isInstanceOf{[T0]=> Boolean}[String]{Boolean}) matchEnd4{(x: String)String}(""{String("")}){String} else case6{()String}(){String}{String} }{String}; case6(){ matchEnd4{(x: String)String}(""{String("")}){String} }{String}; matchEnd4(x: String){ x{String} }{String} }{String} ```
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/scala/async/internal/AnfTransform.scala73
-rw-r--r--src/main/scala/scala/async/internal/AsyncBase.scala4
-rw-r--r--src/main/scala/scala/async/internal/AsyncId.scala10
-rw-r--r--src/main/scala/scala/async/internal/AsyncMacro.scala7
-rw-r--r--src/main/scala/scala/async/internal/AsyncTransform.scala8
-rw-r--r--src/main/scala/scala/async/internal/ExprBuilder.scala80
-rw-r--r--src/main/scala/scala/async/internal/Lifter.scala1
-rw-r--r--src/main/scala/scala/async/internal/StateAssigner.scala3
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala126
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala2
-rw-r--r--src/test/scala/scala/async/run/late/LateExpansion.scala170
11 files changed, 430 insertions, 54 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
index f81f5af..cc77ec7 100644
--- a/src/main/scala/scala/async/internal/AnfTransform.scala
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -16,16 +16,18 @@ private[async] trait AnfTransform {
import c.internal._
import decorators._
- def anfTransform(tree: Tree): Block = {
+ def anfTransform(tree: Tree, owner: Symbol): Block = {
// Must prepend the () for issue #31.
- val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe)
+ val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe)
sealed abstract class AnfMode
case object Anf extends AnfMode
case object Linearizing extends AnfMode
+ val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
+
var mode: AnfMode = Anf
- typingTransform(block)((tree, api) => {
+ typingTransform(tree1, owner)((tree, api) => {
def blockToList(tree: Tree): List[Tree] = tree match {
case Block(stats, expr) => stats :+ expr
case t => t :: Nil
@@ -34,7 +36,7 @@ private[async] trait AnfTransform {
def listToBlock(trees: List[Tree]): Block = trees match {
case trees @ (init :+ last) =>
val pos = trees.map(_.pos).reduceLeft(_ union _)
- Block(init, last).setType(last.tpe).setPos(pos)
+ newBlock(init, last).setType(last.tpe).setPos(pos)
}
object linearize {
@@ -66,6 +68,17 @@ private[async] trait AnfTransform {
stats :+ valDef :+ atPos(tree.pos)(ref1)
case If(cond, thenp, elsep) =>
+ // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}`
+ // as though it was typed with `Unit`.
+ def isPatMatGeneratedJump(t: Tree): Boolean = t match {
+ case Block(_, expr) => isPatMatGeneratedJump(expr)
+ case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep)
+ case _: Apply if isLabel(t.symbol) => true
+ case _ => false
+ }
+ if (isPatMatGeneratedJump(expr)) {
+ internal.setType(expr, definitions.UnitTpe)
+ }
// if type of if-else is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
@@ -77,7 +90,7 @@ private[async] trait AnfTransform {
def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol))
orig match {
- case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
+ case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
case _ => Assign(Ident(varDef.symbol), cast(orig))
}
})
@@ -115,7 +128,7 @@ private[async] trait AnfTransform {
}
}
- private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
+ def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
}
@@ -152,8 +165,7 @@ private[async] trait AnfTransform {
}
def _transformToList(tree: Tree): List[Tree] = trace(tree) {
- val containsAwait = tree exists isAwait
- if (!containsAwait) {
+ if (!containsAwait(tree)) {
tree match {
case Block(stats, expr) =>
// avoids nested block in `while(await(false)) ...`.
@@ -207,10 +219,11 @@ private[async] trait AnfTransform {
funStats ++ argStatss.flatten.flatten :+ typedNewApply
case Block(stats, expr) =>
- (stats :+ expr).flatMap(linearize.transformToList)
+ val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
+ eliminateLabelParameters(trees)
case ValDef(mods, name, tpt, rhs) =>
- if (rhs exists isAwait) {
+ if (containsAwait(rhs)) {
val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs))
stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner))
stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
@@ -247,7 +260,7 @@ private[async] trait AnfTransform {
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)
case LabelDef(name, params, rhs) =>
- List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
+ List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
case TypeApply(fun, targs) =>
val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -259,6 +272,44 @@ private[async] trait AnfTransform {
}
}
+ // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
+ def eliminateLabelParameters(statsExpr: List[Tree]): List[Tree] = {
+ import internal.{methodType, setInfo}
+ val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()
+
+ val matchResults = collection.mutable.Buffer[Tree]()
+ val statsExpr0 = statsExpr.reverseMap {
+ case ld @ LabelDef(_, param :: Nil, body) =>
+ val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
+ matchResults += matchResult
+ caseDefToMatchResult(ld.symbol) = matchResult.symbol
+ val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
+ setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
+ ld2
+ case t =>
+ if (caseDefToMatchResult.isEmpty) t
+ else typingTransform(t)((tree, api) =>
+ tree match {
+ case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
+ api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
+ case Block(stats, expr) =>
+ api.default(tree) match {
+ case Block(stats, Block(stats1, expr)) =>
+ treeCopy.Block(tree, stats ::: stats1, expr)
+ case t => t
+ }
+ case _ =>
+ api.default(tree)
+ }
+ )
+ }
+ matchResults.toList match {
+ case Nil => statsExpr0.reverse
+ case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
+ case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
+ }
+ }
+
def anfLinearize(tree: Tree): Block = {
val trees: List[Tree] = mode match {
case Anf => anf._transformToList(tree)
diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala
index 7464c42..7a1e274 100644
--- a/src/main/scala/scala/async/internal/AsyncBase.scala
+++ b/src/main/scala/scala/async/internal/AsyncBase.scala
@@ -43,9 +43,9 @@ abstract class AsyncBase {
(body: c.Expr[T])
(execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
import c.universe._, c.internal._, decorators._
- val asyncMacro = AsyncMacro(c, self)
+ val asyncMacro = AsyncMacro(c, self)(body.tree)
- val code = asyncMacro.asyncTransform[T](body.tree, execContext.tree)(c.weakTypeTag[T])
+ val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T])
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
// Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges
diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala
index 3afa55b..8654474 100644
--- a/src/main/scala/scala/async/internal/AsyncId.scala
+++ b/src/main/scala/scala/async/internal/AsyncId.scala
@@ -41,11 +41,11 @@ object AsyncTestLV extends AsyncBase {
* A trivial implementation of [[FutureSystem]] that performs computations
* on the current thread. Useful for testing.
*/
+class Box[A] {
+ var a: A = _
+}
object IdentityFutureSystem extends FutureSystem {
-
- class Prom[A] {
- var a: A = _
- }
+ type Prom[A] = Box[A]
type Fut[A] = A
type ExecContext = Unit
@@ -57,7 +57,7 @@ object IdentityFutureSystem extends FutureSystem {
def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(())))
- def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
+ def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]]
def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]]
def execContextType: Type = weakTypeOf[Unit]
diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala
index e969f9b..e22407d 100644
--- a/src/main/scala/scala/async/internal/AsyncMacro.scala
+++ b/src/main/scala/scala/async/internal/AsyncMacro.scala
@@ -1,15 +1,17 @@
package scala.async.internal
object AsyncMacro {
- def apply(c0: reflect.macros.Context, base: AsyncBase): AsyncMacro { val c: c0.type } = {
+ def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = {
import language.reflectiveCalls
new AsyncMacro { self =>
val c: c0.type = c0
+ val body: c.Tree = body0
// This member is required by `AsyncTransform`:
val asyncBase: AsyncBase = base
// These members are required by `ExprBuilder`:
val futureSystem: FutureSystem = base.futureSystem
val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c)
+ val containsAwait: c.Tree => Boolean = containsAwaitCached(body0)
}
}
}
@@ -19,7 +21,10 @@ private[async] trait AsyncMacro
with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables {
val c: scala.reflect.macros.Context
+ val body: c.Tree
+ val containsAwait: c.Tree => Boolean
lazy val macroPos = c.macroApplication.pos.makeTransparent
def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t)
+
}
diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala
index baa3fc2..f491403 100644
--- a/src/main/scala/scala/async/internal/AsyncTransform.scala
+++ b/src/main/scala/scala/async/internal/AsyncTransform.scala
@@ -9,7 +9,7 @@ trait AsyncTransform {
val asyncBase: AsyncBase
- def asyncTransform[T](body: Tree, execContext: Tree)
+ def asyncTransform[T](execContext: Tree)
(resultType: WeakTypeTag[T]): Tree = {
// We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce
@@ -22,7 +22,7 @@ trait AsyncTransform {
// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
// - if/match only used in statement position.
- val anfTree0: Block = anfTransform(body)
+ val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner)
val anfTree = futureSystemOps.postAnfTransform(anfTree0)
@@ -35,7 +35,7 @@ trait AsyncTransform {
val stateMachine: ClassDef = {
val body: List[Tree] = {
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial)))
- val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
+ val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
val apply0DefDef: DefDef = {
@@ -43,7 +43,7 @@ trait AsyncTransform {
// See SI-1247 for the the optimization that avoids creation.
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil))
}
- List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
+ List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
}
val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala
index 164e85b..16b9207 100644
--- a/src/main/scala/scala/async/internal/ExprBuilder.scala
+++ b/src/main/scala/scala/async/internal/ExprBuilder.scala
@@ -146,6 +146,8 @@ trait ExprBuilder {
private val stats = ListBuffer[Tree]()
/** The state of the target of a LabelDef application (while loop jump) */
private var nextJumpState: Option[Int] = None
+ private var nextJumpSymbol: Symbol = NoSymbol
+ def effectiveNextState(nextState: Int) = nextJumpState.orElse(if (nextJumpSymbol == NoSymbol) None else Some(stateIdForLabel(nextJumpSymbol))).getOrElse(nextState)
def +=(stat: Tree): this.type = {
stat match {
@@ -155,11 +157,16 @@ trait ExprBuilder {
}
def addStat() = stats += stat
stat match {
- case Apply(fun, Nil) =>
+ case Apply(fun, args) if isLabel(fun.symbol) =>
// labelDefStates belongs to the current ExprBuilder
labelDefStates get fun.symbol match {
- case opt @ Some(nextState) => nextJumpState = opt // re-use object
- case None => addStat()
+ case opt@Some(nextState) =>
+ // A backward jump
+ nextJumpState = opt // re-use object
+ nextJumpSymbol = fun.symbol
+ case None =>
+ // We haven't the corresponding LabelDef, this is a forward jump
+ nextJumpSymbol = fun.symbol
}
case _ => addStat()
}
@@ -169,13 +176,11 @@ trait ExprBuilder {
def resultWithAwait(awaitable: Awaitable,
onCompleteState: Int,
nextState: Int): AsyncState = {
- val effectiveNextState = nextJumpState.getOrElse(nextState)
- new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup)
+ new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState(nextState), awaitable, symLookup)
}
def resultSimple(nextState: Int): AsyncState = {
- val effectiveNextState = nextJumpState.getOrElse(nextState)
- new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup)
+ new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup)
}
def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
@@ -243,9 +248,17 @@ trait ExprBuilder {
}
import stateAssigner.nextState
+ def directlyAdjacentLabelDefs(t: Tree): List[Tree] = {
+ def isPatternCaseLabelDef(t: Tree) = t match {
+ case LabelDef(name, _, _) => name.toString.startsWith("case")
+ case _ => false
+ }
+ val (before, _ :: after) = (stats :+ expr).span(_ ne t)
+ before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef)
+ }
// populate asyncStates
- for (stat <- stats) stat match {
+ for (stat <- (stats :+ expr)) stat match {
// the val name = await(..) pattern
case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
val onCompleteState = nextState()
@@ -255,7 +268,7 @@ trait ExprBuilder {
currState = afterAwaitState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
- case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) =>
+ case If(cond, thenp, elsep) if containsAwait(stat) || containsForiegnLabelJump(stat) =>
checkForUnsupportedAwait(cond)
val thenStartState = nextState()
@@ -275,7 +288,7 @@ trait ExprBuilder {
currState = afterIfState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
- case Match(scrutinee, cases) if stat exists isAwait =>
+ case Match(scrutinee, cases) if containsAwait(stat) =>
checkForUnsupportedAwait(scrutinee)
val caseStates = cases.map(_ => nextState())
@@ -293,24 +306,21 @@ trait ExprBuilder {
currState = afterMatchState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
+ case ld @ LabelDef(name, params, rhs)
+ if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) =>
- case ld @ LabelDef(name, params, rhs) if rhs exists isAwait =>
- val startLabelState = nextState()
+ val startLabelState = stateIdForLabel(ld.symbol)
val afterLabelState = nextState()
asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
labelDefStates(ld.symbol) = startLabelState
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
asyncStates ++= builder.asyncStates
-
currState = afterLabelState
stateBuilder = new AsyncStateBuilder(currState, symLookup)
-
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
}
- // complete last state builder (representing the expressions after the last await)
- stateBuilder += expr
val lastState = stateBuilder.resultSimple(endState)
asyncStates += lastState
}
@@ -383,18 +393,26 @@ trait ExprBuilder {
* }
* }
*/
- private def resumeFunTree[T: WeakTypeTag]: Tree =
- Try(
- Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ),
- List(
- CaseDef(
- Bind(name.t, Ident(nme.WILDCARD)),
- Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), {
- val t = c.Expr[Throwable](Ident(name.t))
- val complete = futureSystemOps.completeProm[T](
+ private def resumeFunTree[T: WeakTypeTag]: Tree = {
+ val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]))
+ Try(
+ body,
+ List(
+ CaseDef(
+ Bind(name.t, Typed(Ident(nme.WILDCARD), Ident(defn.ThrowableClass))),
+ EmptyTree, {
+ val then = {
+ val t = c.Expr[Throwable](Ident(name.t))
+ val complete = futureSystemOps.completeProm[T](
c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree
- Block(toList(complete), Return(literalUnit))
- })), EmptyTree)
+ Block(toList(complete), Return(literalUnit))
+ }
+ If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), then, Throw(Ident(name.t)))
+ then
+ })), EmptyTree)
+
+ //body
+ }
def forever(t: Tree): Tree = {
val labelName = name.fresh("while$")
@@ -435,6 +453,14 @@ trait ExprBuilder {
private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef =
mkHandlerCase(num, adaptToUnit(rhs))
+ // We use the convention that the state machine's ID for a state corresponding to
+ // a labeldef will a negative number be based on the symbol ID. This allows us
+ // to translate a forward jump to the label as a state transition to a known state
+ // ID, even though the state machine transform hasn't yet processed the target label
+ // def. Negative numbers are used so as as not to clash with regular state IDs, which
+ // are allocated in ascending order from 0.
+ private def stateIdForLabel(sym: Symbol): Int = -symId(sym)
+
private def tpeOf(t: Tree): Type = t match {
case _ if t.tpe != null => t.tpe
case Try(body, Nil, _) => tpeOf(body)
diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala
index 4242a8e..2998baf 100644
--- a/src/main/scala/scala/async/internal/Lifter.scala
+++ b/src/main/scala/scala/async/internal/Lifter.scala
@@ -40,6 +40,7 @@ trait Lifter {
val defs: Map[Tree, Int] = {
/** Collect the DefTrees directly enclosed within `t` that have the same owner */
def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match {
+ case ld: LabelDef => Nil
case dt: DefTree => dt :: Nil
case _: Function => Nil
case t =>
diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala
index 55e7a51..2b74e8d 100644
--- a/src/main/scala/scala/async/internal/StateAssigner.scala
+++ b/src/main/scala/scala/async/internal/StateAssigner.scala
@@ -7,8 +7,7 @@ package scala.async.internal
private[async] final class StateAssigner {
private var current = StateAssigner.Initial
- def nextState(): Int =
- try current finally current += 1
+ def nextState(): Int = try current finally current += 1
}
object StateAssigner {
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index 547f980..ed8b103 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -41,6 +41,19 @@ private[async] trait TransformUtils {
def isAwait(fun: Tree) =
fun.symbol == defn.Async_await
+ def newBlock(stats: List[Tree], expr: Tree): Block = {
+ Block(stats, expr)
+ }
+
+ def isLiteralUnit(t: Tree) = t match {
+ case Literal(Constant(())) =>
+ true
+ case _ => false
+ }
+
+ def isPastTyper =
+ c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper
+
// Copy pasted from TreeInfo in the compiler.
// Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not
// sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match
@@ -150,6 +163,7 @@ private[async] trait TransformUtils {
}
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
+ val ThrowableClass = rootMirror.staticClass("java.lang.Throwable")
val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
}
@@ -161,16 +175,26 @@ private[async] trait TransformUtils {
val labelDefs = t.collect {
case ld: LabelDef => ld.symbol
}.toSet
- t.exists {
+ val result = t.exists {
case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol)
case _ => false
}
+ result
}
- private def isLabel(sym: Symbol): Boolean = {
+ def isLabel(sym: Symbol): Boolean = {
val LABEL = 1L << 17 // not in the public reflection API.
(internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L
}
+ def symId(sym: Symbol): Int = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ sym.asInstanceOf[symtab.Symbol].id
+ }
+ def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = {
+ val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable]
+ val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]])
+ subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree]
+ }
/** Map a list of arguments to:
@@ -362,4 +386,102 @@ private[async] trait TransformUtils {
else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap()))
}
// =====================================
+
+ /**
+ * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`,
+ * and return a function that can be used on derived trees to efficiently test the
+ * same condition.
+ *
+ * If the derived tree contains synthetic wrapper trees, these will be recursed into
+ * in search of a sub tree that was decorated with the cached answer.
+ */
+ final def containsAwaitCached(t: Tree): Tree => Boolean = {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ def attachContainsAwait(t: Tree): Unit = {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(ContainsAwait)
+ t1.removeAttachment[NoAwait.type]
+ }
+ def attachNoAwait(t: Tree): Unit = {
+ val t1 = t.asInstanceOf[symtab.Tree]
+ t1.updateAttachment(NoAwait)
+ }
+ object markContainsAwaitTraverser extends Traverser {
+ var stack: List[Tree] = Nil
+
+ override def traverse(tree: Tree): Unit = {
+ stack ::= tree
+ try {
+ if (isAwait(tree))
+ stack.foreach(attachContainsAwait)
+ else
+ attachNoAwait(tree)
+ super.traverse(tree)
+ } finally stack = stack.tail
+ }
+ }
+ markContainsAwaitTraverser.traverse(t)
+
+ (t: Tree) => {
+ val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
+ object traverser extends Traverser {
+ var containsAwait = false
+ override def traverse(tree: Tree): Unit = {
+ if (tree.asInstanceOf[symtab.Tree].hasAttachment[NoAwait.type])
+ ()
+ else if (tree.asInstanceOf[symtab.Tree].hasAttachment[ContainsAwait.type])
+ containsAwait = true
+ else super.traverse(tree)
+ }
+ }
+ traverser.traverse(t)
+ traverser.containsAwait
+ }
+ }
+
+ final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = {
+ import definitions.UnitTpe
+ typingTransform(t, owner) {
+ (tree, api) =>
+ tree match {
+ case Block(stats, expr) =>
+ val stats1 = stats map api.recur
+ val expr1 = api.recur(expr)
+ if (expr1.tpe =:= UnitTpe)
+ internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe)
+ else
+ treeCopy.Block(tree, stats1, expr1)
+ case If(cond, thenp, elsep) =>
+ val cond1 = api.recur(cond)
+ val thenp1 = api.recur(thenp)
+ val elsep1 = api.recur(elsep)
+ if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe)
+ internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe)
+ else
+ treeCopy.If(tree, cond1, thenp1, elsep1)
+ case Apply(fun, args) if isLabel(fun.symbol) =>
+ internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe)
+ case t => api.default(t)
+ }
+ }
+ }
+
+ final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = {
+ if (isPastTyper) {
+ // If we are running after the typer phase (ie being called from a compiler plugin)
+ // we have to create the trio of members manually.
+ val ACCESSOR = (1L << 27).asInstanceOf[FlagSet]
+ val STABLE = (1L << 22).asInstanceOf[FlagSet]
+ val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name + " ", TypeTree(tpt), init)
+ val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(tpnme.EMPTY), field.name))
+ val setter = DefDef(Modifiers(ACCESSOR), name + "_=", Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(tpnme.EMPTY), field.name), Ident(TermName("x"))))
+ field :: getter :: setter :: Nil
+ } else {
+ val result = ValDef(NoMods, name, TypeTree(tpt), init)
+ result :: Nil
+ }
+ }
}
+
+case object ContainsAwait
+case object NoAwait \ No newline at end of file
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index d6c619f..09fa69e 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -82,6 +82,8 @@ object TreeInterrogation extends App {
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
println(cm.universe.show(tree1))
+
println(tb.eval(tree))
}
+
}
diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala
new file mode 100644
index 0000000..b866527
--- /dev/null
+++ b/src/test/scala/scala/async/run/late/LateExpansion.scala
@@ -0,0 +1,170 @@
+package scala.async.run.late
+
+import java.io.File
+
+import junit.framework.Assert.assertEquals
+import org.junit.Test
+
+import scala.annotation.StaticAnnotation
+import scala.async.internal.{AsyncId, AsyncMacro}
+import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
+import scala.tools.nsc._
+import scala.tools.nsc.plugins.{Plugin, PluginComponent}
+import scala.tools.nsc.reporters.StoreReporter
+import scala.tools.nsc.transform.TypingTransformers
+
+// Tests for customized use of the async transform from a compiler plugin, which
+// calls it from a new phase that runs after patmat.
+class LateExpansion {
+ @Test def test0(): Unit = {
+ val result = wrapAndRun(
+ """
+ | @autoawait def id(a: String) = a
+ | id("foo") + id("bar")
+ | """.stripMargin)
+ assertEquals("foobar", result)
+ }
+ @Test def testGuard(): Unit = {
+ val result = wrapAndRun(
+ """
+ | @autoawait def id[A](a: A) = a
+ | "" match { case _ if id(false) => ???; case _ => "okay" }
+ | """.stripMargin)
+ assertEquals("okay", result)
+ }
+
+ @Test def testExtractor(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
+ | "" match { case Extractor(a, b) if "".isEmpty => a == b }
+ | """.stripMargin)
+ assertEquals(true, result)
+ }
+
+ @Test def testNestedMatchExtractor(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) }
+ | "" match {
+ | case _ if "".isEmpty =>
+ | "" match { case Extractor(a, b) => a == b }
+ | }
+ | """.stripMargin)
+ assertEquals(true, result)
+ }
+
+ @Test def testCombo(): Unit = {
+ val result = wrapAndRun(
+ """
+ | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) }
+ | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) }
+ | @autoawait def id(a: String) = a
+ | println("Test.test")
+ | val r1 = Predef.identity("blerg") match {
+ | case x if " ".isEmpty => "case 2: " + x
+ | case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y
+ | x match {
+ | case Extractor1(Extractor2(x), y: String) =>
+ | case _ =>
+ | }
+ | case Extractor2(x) => "case 3: " + x
+ | }
+ | r1
+ | """.stripMargin)
+ assertEquals("case 3: blerg3", result)
+ }
+
+ def wrapAndRun(code: String): Any = {
+ run(
+ s"""
+ |import scala.async.run.late.{autoawait,lateasync}
+ |object Test {
+ | @lateasync
+ | def test: Any = {
+ | $code
+ | }
+ |}
+ | """.stripMargin)
+ }
+
+ def run(code: String): Any = {
+ val reporter = new StoreReporter
+ val settings = new Settings(println(_))
+ settings.outdir.value = sys.props("java.io.tmpdir")
+ settings.embeddedDefaults(getClass.getClassLoader)
+ val isInSBT = !settings.classpath.isSetByUser
+ if (isInSBT) settings.usejavacp.value = true
+ val global = new Global(settings, reporter) {
+ self =>
+
+ object late extends {
+ val global: self.type = self
+ } with LatePlugin
+
+ override protected def loadPlugins(): List[Plugin] = late :: Nil
+ }
+ import global._
+
+ val run = new Run
+ val source = newSourceFile(code)
+ run.compileSources(source :: Nil)
+ assert(!reporter.hasErrors, reporter.infos.mkString("\n"))
+ val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader)
+ val cls = loader.loadClass("Test")
+ cls.getMethod("test").invoke(null)
+ }
+}
+
+abstract class LatePlugin extends Plugin {
+ import global._
+
+ override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
+ val global: LatePlugin.this.global.type = LatePlugin.this.global
+
+ lazy val asyncIdSym = symbolOf[AsyncId.type]
+ lazy val asyncSym = asyncIdSym.info.member(TermName("async"))
+ lazy val awaitSym = asyncIdSym.info.member(TermName("await"))
+ lazy val autoAwaitSym = symbolOf[autoawait]
+ lazy val lateAsyncSym = symbolOf[lateasync]
+
+ def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) {
+ override def transform(tree: Tree): Tree = {
+ super.transform(tree) match {
+ case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) =>
+ localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
+ case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
+ val expandee = localTyper.context.withMacrosDisabled(
+ localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(dd.rhs.tpe) :: Nil), List(dd.rhs)))
+ )
+ val c = analyzer.macroContext(localTyper, gen.mkAttributedRef(asyncIdSym), expandee)
+ val asyncMacro = AsyncMacro(c, AsyncId)(dd.rhs)
+ val code = asyncMacro.asyncTransform[Any](localTyper.typed(Literal(Constant(()))))(c.weakTypeTag[Any])
+ deriveDefDef(dd)(_ => localTyper.typed(code))
+ }
+ case x => x
+ }
+ }
+ }
+
+ override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
+ override def apply(unit: CompilationUnit): Unit = {
+ val translated = newTransformer(unit).transformUnit(unit)
+ //println(show(unit.body))
+ translated
+ }
+ }
+
+ override val runsAfter: List[String] = "patmat" :: Nil
+ override val phaseName: String = "postpatmat"
+
+ })
+ override val description: String = "postpatmat"
+ override val name: String = "postpatmat"
+}
+
+// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { <original RHS> }`
+final class lateasync extends StaticAnnotation
+
+// Calls to methods with this annotation are translated to `AsyncId.await(<call>)`
+final class autoawait extends StaticAnnotation