diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2015-09-24 10:28:07 +1000 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2015-09-24 10:28:07 +1000 |
commit | 7263aaad02a75978a0a48f90bf171c66cda4328c (patch) | |
tree | f3e876db8c7b7b4d5d7311dc7e9b3742057cf233 /src/main | |
parent | 93f207fee780652d08f93e1ea40e018db59fee99 (diff) | |
parent | 168e10cd8b60789aa3c9c96aeb5d5522c3ec6922 (diff) | |
download | scala-async-7263aaad02a75978a0a48f90bf171c66cda4328c.tar.gz scala-async-7263aaad02a75978a0a48f90bf171c66cda4328c.tar.bz2 scala-async-7263aaad02a75978a0a48f90bf171c66cda4328c.zip |
Merge pull request #141 from retronym/ticket/await-extractorv0.9.6-RC1_2.11v0.9.5-RC1_2.11
Enable a compiler plugin to use the async transform after patmat
Diffstat (limited to 'src/main')
9 files changed, 287 insertions, 55 deletions
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala index f81f5af..4545ca6 100644 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ b/src/main/scala/scala/async/internal/AnfTransform.scala @@ -16,16 +16,18 @@ private[async] trait AnfTransform { import c.internal._ import decorators._ - def anfTransform(tree: Tree): Block = { + def anfTransform(tree: Tree, owner: Symbol): Block = { // Must prepend the () for issue #31. - val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe) + val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe) sealed abstract class AnfMode case object Anf extends AnfMode case object Linearizing extends AnfMode + val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner) + var mode: AnfMode = Anf - typingTransform(block)((tree, api) => { + typingTransform(tree1, owner)((tree, api) => { def blockToList(tree: Tree): List[Tree] = tree match { case Block(stats, expr) => stats :+ expr case t => t :: Nil @@ -34,7 +36,7 @@ private[async] trait AnfTransform { def listToBlock(trees: List[Tree]): Block = trees match { case trees @ (init :+ last) => val pos = trees.map(_.pos).reduceLeft(_ union _) - Block(init, last).setType(last.tpe).setPos(pos) + newBlock(init, last).setType(last.tpe).setPos(pos) } object linearize { @@ -66,6 +68,17 @@ private[async] trait AnfTransform { stats :+ valDef :+ atPos(tree.pos)(ref1) case If(cond, thenp, elsep) => + // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}` + // as though it was typed with `Unit`. + def isPatMatGeneratedJump(t: Tree): Boolean = t match { + case Block(_, expr) => isPatMatGeneratedJump(expr) + case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep) + case _: Apply if isLabel(t.symbol) => true + case _ => false + } + if (isPatMatGeneratedJump(expr)) { + internal.setType(expr, definitions.UnitTpe) + } // if type of if-else is Unit don't introduce assignment, // but add Unit value to bring it into form expected by async transform if (expr.tpe =:= definitions.UnitTpe) { @@ -77,7 +90,7 @@ private[async] trait AnfTransform { def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) { def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol)) orig match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) + case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr))) case _ => Assign(Ident(varDef.symbol), cast(orig)) } }) @@ -115,7 +128,7 @@ private[async] trait AnfTransform { } } - private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { + def defineVar(prefix: String, tp: Type, pos: Position): ValDef = { val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp)) valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos) } @@ -152,8 +165,7 @@ private[async] trait AnfTransform { } def _transformToList(tree: Tree): List[Tree] = trace(tree) { - val containsAwait = tree exists isAwait - if (!containsAwait) { + if (!containsAwait(tree)) { tree match { case Block(stats, expr) => // avoids nested block in `while(await(false)) ...`. @@ -207,10 +219,11 @@ private[async] trait AnfTransform { funStats ++ argStatss.flatten.flatten :+ typedNewApply case Block(stats, expr) => - (stats :+ expr).flatMap(linearize.transformToList) + val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr) + eliminateMatchEndLabelParameter(trees) case ValDef(mods, name, tpt, rhs) => - if (rhs exists isAwait) { + if (containsAwait(rhs)) { val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs)) stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner)) stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) @@ -247,7 +260,7 @@ private[async] trait AnfTransform { scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs) case LabelDef(name, params, rhs) => - List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) + List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) case TypeApply(fun, targs) => val funStats :+ simpleFun = linearize.transformToList(fun) @@ -259,6 +272,52 @@ private[async] trait AnfTransform { } } + // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable + // + // CaseDefs are translated to labels without parmeters. A terminal label, `matchEnd`, accepts + // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this). + // + // For our purposes, it is easier to: + // - extract a `matchRes` variable + // - rewrite the terminal label def to take no parameters, and instead read this temp variable + // - change jumps to the terminal label to an assignment and a no-arg label application + def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = { + import internal.{methodType, setInfo} + val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]() + + val matchResults = collection.mutable.Buffer[Tree]() + val statsExpr0 = statsExpr.reverseMap { + case ld @ LabelDef(_, param :: Nil, body) => + val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos) + matchResults += matchResult + caseDefToMatchResult(ld.symbol) = matchResult.symbol + val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)) + setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType)) + ld2 + case t => + if (caseDefToMatchResult.isEmpty) t + else typingTransform(t)((tree, api) => + tree match { + case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) => + api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil)))) + case Block(stats, expr) => + api.default(tree) match { + case Block(stats, Block(stats1, expr)) => + treeCopy.Block(tree, stats ::: stats1, expr) + case t => t + } + case _ => + api.default(tree) + } + ) + } + matchResults.toList match { + case Nil => statsExpr + case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol)) + case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr + } + } + def anfLinearize(tree: Tree): Block = { val trees: List[Tree] = mode match { case Anf => anf._transformToList(tree) diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala index 7464c42..7a1e274 100644 --- a/src/main/scala/scala/async/internal/AsyncBase.scala +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -43,9 +43,9 @@ abstract class AsyncBase { (body: c.Expr[T]) (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { import c.universe._, c.internal._, decorators._ - val asyncMacro = AsyncMacro(c, self) + val asyncMacro = AsyncMacro(c, self)(body.tree) - val code = asyncMacro.asyncTransform[T](body.tree, execContext.tree)(c.weakTypeTag[T]) + val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T]) AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") // Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala index 3afa55b..8654474 100644 --- a/src/main/scala/scala/async/internal/AsyncId.scala +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -41,11 +41,11 @@ object AsyncTestLV extends AsyncBase { * A trivial implementation of [[FutureSystem]] that performs computations * on the current thread. Useful for testing. */ +class Box[A] { + var a: A = _ +} object IdentityFutureSystem extends FutureSystem { - - class Prom[A] { - var a: A = _ - } + type Prom[A] = Box[A] type Fut[A] = A type ExecContext = Unit @@ -57,7 +57,7 @@ object IdentityFutureSystem extends FutureSystem { def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(()))) - def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]] + def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]] def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]] def execContextType: Type = weakTypeOf[Unit] diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala index e969f9b..e22407d 100644 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -1,15 +1,17 @@ package scala.async.internal object AsyncMacro { - def apply(c0: reflect.macros.Context, base: AsyncBase): AsyncMacro { val c: c0.type } = { + def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = { import language.reflectiveCalls new AsyncMacro { self => val c: c0.type = c0 + val body: c.Tree = body0 // This member is required by `AsyncTransform`: val asyncBase: AsyncBase = base // These members are required by `ExprBuilder`: val futureSystem: FutureSystem = base.futureSystem val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c) + val containsAwait: c.Tree => Boolean = containsAwaitCached(body0) } } } @@ -19,7 +21,10 @@ private[async] trait AsyncMacro with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables { val c: scala.reflect.macros.Context + val body: c.Tree + val containsAwait: c.Tree => Boolean lazy val macroPos = c.macroApplication.pos.makeTransparent def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t) + } diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index baa3fc2..af290e4 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -9,7 +9,7 @@ trait AsyncTransform { val asyncBase: AsyncBase - def asyncTransform[T](body: Tree, execContext: Tree) + def asyncTransform[T](execContext: Tree) (resultType: WeakTypeTag[T]): Tree = { // We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce @@ -22,7 +22,7 @@ trait AsyncTransform { // Transform to A-normal form: // - no await calls in qualifiers or arguments, // - if/match only used in statement position. - val anfTree0: Block = anfTransform(body) + val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner) val anfTree = futureSystemOps.postAnfTransform(anfTree0) @@ -35,7 +35,7 @@ trait AsyncTransform { val stateMachine: ClassDef = { val body: List[Tree] = { val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial))) - val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) + val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) val apply0DefDef: DefDef = { @@ -43,7 +43,7 @@ trait AsyncTransform { // See SI-1247 for the the optimization that avoids creation. DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil)) } - List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) + List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) } val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit]) @@ -98,10 +98,11 @@ trait AsyncTransform { } val isSimple = asyncBlock.asyncStates.size == 1 - if (isSimple) + val result = if (isSimple) futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` else startStateMachine + cleanupContainsAwaitAttachments(result) } def logDiagnostics(anfTree: Tree, states: Seq[String]) { diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 164e85b..16b9207 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -146,6 +146,8 @@ trait ExprBuilder { private val stats = ListBuffer[Tree]() /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None + private var nextJumpSymbol: Symbol = NoSymbol + def effectiveNextState(nextState: Int) = nextJumpState.orElse(if (nextJumpSymbol == NoSymbol) None else Some(stateIdForLabel(nextJumpSymbol))).getOrElse(nextState) def +=(stat: Tree): this.type = { stat match { @@ -155,11 +157,16 @@ trait ExprBuilder { } def addStat() = stats += stat stat match { - case Apply(fun, Nil) => + case Apply(fun, args) if isLabel(fun.symbol) => // labelDefStates belongs to the current ExprBuilder labelDefStates get fun.symbol match { - case opt @ Some(nextState) => nextJumpState = opt // re-use object - case None => addStat() + case opt@Some(nextState) => + // A backward jump + nextJumpState = opt // re-use object + nextJumpSymbol = fun.symbol + case None => + // We haven't the corresponding LabelDef, this is a forward jump + nextJumpSymbol = fun.symbol } case _ => addStat() } @@ -169,13 +176,11 @@ trait ExprBuilder { def resultWithAwait(awaitable: Awaitable, onCompleteState: Int, nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState, awaitable, symLookup) + new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState(nextState), awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { - val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) + new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup) } def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { @@ -243,9 +248,17 @@ trait ExprBuilder { } import stateAssigner.nextState + def directlyAdjacentLabelDefs(t: Tree): List[Tree] = { + def isPatternCaseLabelDef(t: Tree) = t match { + case LabelDef(name, _, _) => name.toString.startsWith("case") + case _ => false + } + val (before, _ :: after) = (stats :+ expr).span(_ ne t) + before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + } // populate asyncStates - for (stat <- stats) stat match { + for (stat <- (stats :+ expr)) stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val onCompleteState = nextState() @@ -255,7 +268,7 @@ trait ExprBuilder { currState = afterAwaitState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case If(cond, thenp, elsep) if (stat exists isAwait) || containsForiegnLabelJump(stat) => + case If(cond, thenp, elsep) if containsAwait(stat) || containsForiegnLabelJump(stat) => checkForUnsupportedAwait(cond) val thenStartState = nextState() @@ -275,7 +288,7 @@ trait ExprBuilder { currState = afterIfState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case Match(scrutinee, cases) if stat exists isAwait => + case Match(scrutinee, cases) if containsAwait(stat) => checkForUnsupportedAwait(scrutinee) val caseStates = cases.map(_ => nextState()) @@ -293,24 +306,21 @@ trait ExprBuilder { currState = afterMatchState stateBuilder = new AsyncStateBuilder(currState, symLookup) + case ld @ LabelDef(name, params, rhs) + if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) => - case ld @ LabelDef(name, params, rhs) if rhs exists isAwait => - val startLabelState = nextState() + val startLabelState = stateIdForLabel(ld.symbol) val afterLabelState = nextState() asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) asyncStates ++= builder.asyncStates - currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, symLookup) - case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } - // complete last state builder (representing the expressions after the last await) - stateBuilder += expr val lastState = stateBuilder.resultSimple(endState) asyncStates += lastState } @@ -383,18 +393,26 @@ trait ExprBuilder { * } * } */ - private def resumeFunTree[T: WeakTypeTag]: Tree = - Try( - Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ), - List( - CaseDef( - Bind(name.t, Ident(nme.WILDCARD)), - Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), { - val t = c.Expr[Throwable](Ident(name.t)) - val complete = futureSystemOps.completeProm[T]( + private def resumeFunTree[T: WeakTypeTag]: Tree = { + val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T])) + Try( + body, + List( + CaseDef( + Bind(name.t, Typed(Ident(nme.WILDCARD), Ident(defn.ThrowableClass))), + EmptyTree, { + val then = { + val t = c.Expr[Throwable](Ident(name.t)) + val complete = futureSystemOps.completeProm[T]( c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree - Block(toList(complete), Return(literalUnit)) - })), EmptyTree) + Block(toList(complete), Return(literalUnit)) + } + If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), then, Throw(Ident(name.t))) + then + })), EmptyTree) + + //body + } def forever(t: Tree): Tree = { val labelName = name.fresh("while$") @@ -435,6 +453,14 @@ trait ExprBuilder { private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = mkHandlerCase(num, adaptToUnit(rhs)) + // We use the convention that the state machine's ID for a state corresponding to + // a labeldef will a negative number be based on the symbol ID. This allows us + // to translate a forward jump to the label as a state transition to a known state + // ID, even though the state machine transform hasn't yet processed the target label + // def. Negative numbers are used so as as not to clash with regular state IDs, which + // are allocated in ascending order from 0. + private def stateIdForLabel(sym: Symbol): Int = -symId(sym) + private def tpeOf(t: Tree): Type = t match { case _ if t.tpe != null => t.tpe case Try(body, Nil, _) => tpeOf(body) diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala index 4242a8e..2998baf 100644 --- a/src/main/scala/scala/async/internal/Lifter.scala +++ b/src/main/scala/scala/async/internal/Lifter.scala @@ -40,6 +40,7 @@ trait Lifter { val defs: Map[Tree, Int] = { /** Collect the DefTrees directly enclosed within `t` that have the same owner */ def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { + case ld: LabelDef => Nil case dt: DefTree => dt :: Nil case _: Function => Nil case t => diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala index 55e7a51..2b74e8d 100644 --- a/src/main/scala/scala/async/internal/StateAssigner.scala +++ b/src/main/scala/scala/async/internal/StateAssigner.scala @@ -7,8 +7,7 @@ package scala.async.internal private[async] final class StateAssigner { private var current = StateAssigner.Initial - def nextState(): Int = - try current finally current += 1 + def nextState(): Int = try current finally current += 1 } object StateAssigner { diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 547f980..90419d3 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -41,6 +41,19 @@ private[async] trait TransformUtils { def isAwait(fun: Tree) = fun.symbol == defn.Async_await + def newBlock(stats: List[Tree], expr: Tree): Block = { + Block(stats, expr) + } + + def isLiteralUnit(t: Tree) = t match { + case Literal(Constant(())) => + true + case _ => false + } + + def isPastTyper = + c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper + // Copy pasted from TreeInfo in the compiler. // Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not // sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match @@ -150,6 +163,7 @@ private[async] trait TransformUtils { } val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") + val ThrowableClass = rootMirror.staticClass("java.lang.Throwable") val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol) val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException") } @@ -161,16 +175,26 @@ private[async] trait TransformUtils { val labelDefs = t.collect { case ld: LabelDef => ld.symbol }.toSet - t.exists { + val result = t.exists { case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol) case _ => false } + result } - private def isLabel(sym: Symbol): Boolean = { + def isLabel(sym: Symbol): Boolean = { val LABEL = 1L << 17 // not in the public reflection API. (internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L } + def symId(sym: Symbol): Int = { + val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] + sym.asInstanceOf[symtab.Symbol].id + } + def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = { + val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] + val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]]) + subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree] + } /** Map a list of arguments to: @@ -362,4 +386,121 @@ private[async] trait TransformUtils { else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap())) } // ===================================== + + /** + * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`, + * and return a function that can be used on derived trees to efficiently test the + * same condition. + * + * If the derived tree contains synthetic wrapper trees, these will be recursed into + * in search of a sub tree that was decorated with the cached answer. + */ + final def containsAwaitCached(t: Tree): Tree => Boolean = { + def treeCannotContainAwait(t: Tree) = t match { + case _: Ident | _: TypeTree | _: Literal => true + case _ => false + } + def shouldAttach(t: Tree) = !treeCannotContainAwait(t) + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + def attachContainsAwait(t: Tree): Unit = if (shouldAttach(t)) { + val t1 = t.asInstanceOf[symtab.Tree] + t1.updateAttachment(ContainsAwait) + t1.removeAttachment[NoAwait.type] + } + def attachNoAwait(t: Tree): Unit = if (shouldAttach(t)) { + val t1 = t.asInstanceOf[symtab.Tree] + t1.updateAttachment(NoAwait) + } + object markContainsAwaitTraverser extends Traverser { + var stack: List[Tree] = Nil + + override def traverse(tree: Tree): Unit = { + stack ::= tree + try { + if (isAwait(tree)) + stack.foreach(attachContainsAwait) + else + attachNoAwait(tree) + super.traverse(tree) + } finally stack = stack.tail + } + } + markContainsAwaitTraverser.traverse(t) + + (t: Tree) => { + object traverser extends Traverser { + var containsAwait = false + override def traverse(tree: Tree): Unit = { + def castTree = tree.asInstanceOf[symtab.Tree] + if (!castTree.hasAttachment[NoAwait.type]) { + if (castTree.hasAttachment[ContainsAwait.type]) + containsAwait = true + else if (!treeCannotContainAwait(t)) + super.traverse(tree) + } + } + } + traverser.traverse(t) + traverser.containsAwait + } + } + + final def cleanupContainsAwaitAttachments(t: Tree): t.type = { + val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] + t.foreach {t => + t.asInstanceOf[symtab.Tree].removeAttachment[ContainsAwait.type] + t.asInstanceOf[symtab.Tree].removeAttachment[NoAwait.type] + } + t + } + + // First modification to translated patterns: + // - Set the type of label jumps to `Unit` + // - Propagate this change to trees known to directly enclose them: + // ``If` / `Block`) adjust types of enclosing + final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = { + import definitions.UnitTpe + typingTransform(t, owner) { + (tree, api) => + tree match { + case Block(stats, expr) => + val stats1 = stats map api.recur + val expr1 = api.recur(expr) + if (expr1.tpe =:= UnitTpe) + internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe) + else + treeCopy.Block(tree, stats1, expr1) + case If(cond, thenp, elsep) => + val cond1 = api.recur(cond) + val thenp1 = api.recur(thenp) + val elsep1 = api.recur(elsep) + if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe) + internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe) + else + treeCopy.If(tree, cond1, thenp1, elsep1) + case Apply(fun, args) if isLabel(fun.symbol) => + internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe) + case t => api.default(t) + } + } + } + + final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = { + if (isPastTyper) { + // If we are running after the typer phase (ie being called from a compiler plugin) + // we have to create the trio of members manually. + val ACCESSOR = (1L << 27).asInstanceOf[FlagSet] + val STABLE = (1L << 22).asInstanceOf[FlagSet] + val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name + " ", TypeTree(tpt), init) + val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(tpnme.EMPTY), field.name)) + val setter = DefDef(Modifiers(ACCESSOR), name + "_=", Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(tpnme.EMPTY), field.name), Ident(TermName("x")))) + field :: getter :: setter :: Nil + } else { + val result = ValDef(NoMods, name, TypeTree(tpt), init) + result :: Nil + } + } } + +case object ContainsAwait +case object NoAwait
\ No newline at end of file |