diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2013-07-02 15:55:34 +0200 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2013-07-03 10:04:55 +0200 |
commit | 82232ec47effb4a6b67b3a0792e1c7600e2d31b7 (patch) | |
tree | ed9925418aa0a631d1d25fd1be30f5d508e81b24 /src/main/scala/scala/async/ExprBuilder.scala | |
parent | d63b63f536aafa494c70835526174be1987050de (diff) | |
download | scala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.tar.gz scala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.tar.bz2 scala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.zip |
An overdue overhaul of macro internals.
- Avoid reset + retypecheck, instead hang onto the original types/symbols
- Eliminated duplication between AsyncDefinitionUseAnalyzer and ExprBuilder
- Instead, decide what do lift *after* running ExprBuilder
- Account for transitive references local classes/objects and lift them
as needed.
- Make the execution context an regular implicit parameter of the macro
- Fixes interaction with existential skolems and singleton types
Fixes #6, #13, #16, #17, #19, #21.
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 205 |
1 files changed, 99 insertions, 106 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index ca46a83..a3837d3 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -7,17 +7,17 @@ import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable import language.existentials +import scala.reflect.api.Universe +import scala.reflect.api -private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) { - builder => +trait ExprBuilder { + builder: AsyncMacro => - val utils = TransformUtils[c.type](c) - - import c.universe._ - import utils._ + import global._ import defn._ - lazy val futureSystemOps = futureSystem.mkOps(c) + val futureSystem: FutureSystem + val futureSystemOps: futureSystem.Ops { val universe: global.type } val stateAssigner = new StateAssigner val labelDefStates = collection.mutable.Map[Symbol, Int]() @@ -27,22 +27,27 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def mkHandlerCaseForState: CaseDef - def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = None + def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None def stats: List[Tree] - final def body: c.Tree = stats match { + final def allStats: List[Tree] = this match { + case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef + case _ => stats + } + + final def body: Tree = stats match { case stat :: Nil => stat case init :+ last => Block(init, last) } } /** A sequence of statements the concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int) + final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) extends AsyncState { def mkHandlerCaseForState: CaseDef = - mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) + mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) override val toString: String = s"AsyncState #$state, next = $nextState" @@ -51,7 +56,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState { + final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { override def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats) @@ -62,25 +67,25 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nestState`.`` */ - final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int, - awaitable: Awaitable) + final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { override def mkHandlerCaseForState: CaseDef = { - val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr), - c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree + val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), + Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree mkHandlerCase(state, stats :+ callOnComplete) } - override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = { + override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { val tryGetTree = Assign( Ident(awaitable.resultName), - TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) + TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) ) /* if (tr.isFailure) - * result$async.complete(tr.asInstanceOf[Try[T]]) + * result.complete(tr.asInstanceOf[Try[T]]) * else { * <resultName> = tr.get.asInstanceOf[<resultType>] * <nextState> @@ -88,13 +93,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * } */ val ifIsFailureTree = - If(Select(Ident(name.tr), Try_isFailure), + If(Select(Ident(symLookup.applyTrParam), Try_isFailure), futureSystemOps.completeProm[T]( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), - c.Expr[scala.util.Try[T]]( - TypeApply(Select(Ident(name.tr), newTermName("asInstanceOf")), + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), + Expr[scala.util.Try[T]]( + TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")), List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree, - Block(List(tryGetTree, mkStateTree(nextState)), mkResumeApply) + Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup)) ) Some(mkHandlerCase(state, List(ifIsFailureTree))) @@ -107,19 +112,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /* * Builder for a single state of an async method. */ - final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { + final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { /* Statements preceding an await call. */ - private val stats = ListBuffer[c.Tree]() + private val stats = ListBuffer[Tree]() /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None - private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap)) - - def +=(stat: c.Tree): this.type = { + def +=(stat: Tree): this.type = { assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") - def addStat() = stats += renameReset(stat) + def addStat() = stats += stat stat match { - case _: DefDef => // these have been lifted. case Apply(fun, Nil) => labelDefStates get fun.symbol match { case Some(nextState) => nextJumpState = Some(nextState) @@ -132,22 +134,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def resultWithAwait(awaitable: Awaitable, nextState: Int): AsyncState = { - val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr)) val effectiveNextState = nextJumpState.getOrElse(nextState) - new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable) + new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup) } def resultSimple(nextState: Int): AsyncState = { val effectiveNextState = nextJumpState.getOrElse(nextState) - new SimpleAsyncState(stats.toList, state, effectiveNextState) + new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup) } - def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { - // 1. build changed if-else tree - // 2. insert that tree at the end of the current state - val cond = renameReset(condTree) - def mkBranch(state: Int) = Block(mkStateTree(state) :: Nil, mkResumeApply) - this += If(cond, mkBranch(thenState), mkBranch(elseState)) + def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { + def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) + this += If(condTree, mkBranch(thenState), mkBranch(elseState)) new AsyncStateWithoutAwait(stats.toList, state) } @@ -161,23 +159,20 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * @param caseStates starting state of the right-hand side of the each case * @return an `AsyncState` representing the match expression */ - def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { + def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { case CaseDef(pat, guard, rhs) => - val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map { - case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs) - case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t") - } - CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply)) + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup))) } // 2. insert changed match tree at the end of the current state - this += Match(renameReset(scrutTree), newCases) + this += Match(scrutTree, newCases) new AsyncStateWithoutAwait(stats.toList, state) } - def resultWithLabel(startLabelState: Int): AsyncState = { - this += Block(mkStateTree(startLabelState) :: Nil, mkResumeApply) + def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { + this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) new AsyncStateWithoutAwait(stats.toList, state) } @@ -194,24 +189,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * @param expr the last expression of the block * @param startState the start state * @param endState the state to continue with - * @param toRename a `Map` for renaming the given key symbols to the mangled value names */ - final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - private val toRename: Map[Symbol, c.Name]) { + final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, + private val symLookup: SymLookup) { val asyncStates = ListBuffer[AsyncState]() - var stateBuilder = new AsyncStateBuilder(startState, toRename) + var stateBuilder = new AsyncStateBuilder(startState, symLookup) var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ - def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { + def checkForUnsupportedAwait(tree: Tree) = if (tree exists { case Apply(fun, _) if isAwait(fun) => true case _ => false - }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException + }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename) + new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) } import stateAssigner.nextState @@ -219,16 +213,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => + case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val afterAwaitState = nextState() - val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe) + val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await currState = afterAwaitState - stateBuilder = new AsyncStateBuilder(currState, toRename) - - case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => - checkForUnsupportedAwait(rhs) - stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) @@ -248,7 +238,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } currState = afterIfState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) @@ -257,7 +247,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val afterMatchState = nextState() asyncStates += - stateBuilder.resultWithMatch(scrutinee, cases, caseStates) + stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) for ((cas, num) <- cases.zipWithIndex) { val (stats, expr) = statsAndExpr(cas.body) @@ -267,18 +257,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } currState = afterMatchState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case ld@LabelDef(name, params, rhs) if rhs exists isAwait => val startLabelState = nextState() val afterLabelState = nextState() - asyncStates += stateBuilder.resultWithLabel(startLabelState) + asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) labelDefStates(ld.symbol) = startLabelState val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) asyncStates ++= builder.asyncStates currState = afterLabelState - stateBuilder = new AsyncStateBuilder(currState, toRename) + stateBuilder = new AsyncStateBuilder(currState, symLookup) case _ => checkForUnsupportedAwait(stat) stateBuilder += stat @@ -292,17 +282,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: trait AsyncBlock { def asyncStates: List[AsyncState] - def onCompleteHandler[T: c.WeakTypeTag]: Tree + def onCompleteHandler[T: WeakTypeTag]: Tree + + def resumeFunTree[T]: DefDef + } - def resumeFunTree[T]: Tree + case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { + def stateMachineMember(name: TermName): Symbol = + stateMachineClass.info.member(name) + def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name)) } - def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = { + def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { val Block(stats, expr) = block val startState = stateAssigner.nextState() val endState = Int.MaxValue - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename) + val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) new AsyncBlock { def asyncStates = blockBuilder.asyncStates.toList @@ -310,9 +306,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def mkCombinedHandlerCases[T]: List[CaseDef] = { val caseForLastState: CaseDef = { val lastState = asyncStates.last - val lastStateBody = c.Expr[T](lastState.body) + val lastStateBody = Expr[T](lastState.body) val rhs = futureSystemOps.completeProm( - c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice))) mkHandlerCase(lastState.state, rhs.tree) } asyncStates.toList match { @@ -327,18 +323,6 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: val initStates = asyncStates.init /** - * // assumes tr: Try[Any] is in scope. - * // - * state match { - * case 0 => { - * x11 = tr.get.asInstanceOf[Double]; - * state = 1; - * resume() - * } - */ - def onCompleteHandler[T: c.WeakTypeTag]: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) - - /** * def resume(): Unit = { * try { * state match { @@ -353,18 +337,31 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * } * } */ - def resumeFunTree[T]: Tree = + def resumeFunTree[T]: DefDef = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), Try( - Match(Ident(name.state), mkCombinedHandlerCases[T]), + Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]), List( CaseDef( - Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))), - EmptyTree, + Bind(name.t, Ident(nme.WILDCARD)), + Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), Block(List({ - val t = c.Expr[Throwable](Ident(name.tr)) - futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree - }), c.literalUnit.tree))), EmptyTree)) + val t = Expr[Throwable](Ident(name.t)) + futureSystemOps.completeProm[T]( + Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree + }), literalUnit))), EmptyTree)) + + /** + * // assumes tr: Try[Any] is in scope. + * // + * state match { + * case 0 => { + * x11 = tr.get.asInstanceOf[Double]; + * state = 1; + * resume() + * } + */ + def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList) } } @@ -373,22 +370,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: case _ => false } - private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) - - private val internalSyms = origTree.collect { - case dt: DefTree => dt.symbol - } + case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) - private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms) + private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil) - private def mkResumeApply = Apply(Ident(name.resume), Nil) + private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = + Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) - private def mkStateTree(nextState: Int): c.Tree = - Assign(Ident(name.state), c.literal(nextState).tree) + private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = + mkHandlerCase(num, Block(rhs, literalUnit)) - private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = - mkHandlerCase(num, Block(rhs, c.literalUnit.tree)) + private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = + CaseDef(Literal(Constant(num)), EmptyTree, rhs) - private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef(c.literal(num).tree, EmptyTree, rhs) + private def literalUnit = Literal(Constant(())) } |