aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/ExprBuilder.scala
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2013-07-02 15:55:34 +0200
committerJason Zaugg <jzaugg@gmail.com>2013-07-03 10:04:55 +0200
commit82232ec47effb4a6b67b3a0792e1c7600e2d31b7 (patch)
treeed9925418aa0a631d1d25fd1be30f5d508e81b24 /src/main/scala/scala/async/ExprBuilder.scala
parentd63b63f536aafa494c70835526174be1987050de (diff)
downloadscala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.tar.gz
scala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.tar.bz2
scala-async-82232ec47effb4a6b67b3a0792e1c7600e2d31b7.zip
An overdue overhaul of macro internals.
- Avoid reset + retypecheck, instead hang onto the original types/symbols - Eliminated duplication between AsyncDefinitionUseAnalyzer and ExprBuilder - Instead, decide what do lift *after* running ExprBuilder - Account for transitive references local classes/objects and lift them as needed. - Make the execution context an regular implicit parameter of the macro - Fixes interaction with existential skolems and singleton types Fixes #6, #13, #16, #17, #19, #21.
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala205
1 files changed, 99 insertions, 106 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index ca46a83..a3837d3 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -7,17 +7,17 @@ import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
import collection.mutable
import language.existentials
+import scala.reflect.api.Universe
+import scala.reflect.api
-private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) {
- builder =>
+trait ExprBuilder {
+ builder: AsyncMacro =>
- val utils = TransformUtils[c.type](c)
-
- import c.universe._
- import utils._
+ import global._
import defn._
- lazy val futureSystemOps = futureSystem.mkOps(c)
+ val futureSystem: FutureSystem
+ val futureSystemOps: futureSystem.Ops { val universe: global.type }
val stateAssigner = new StateAssigner
val labelDefStates = collection.mutable.Map[Symbol, Int]()
@@ -27,22 +27,27 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def mkHandlerCaseForState: CaseDef
- def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = None
+ def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None
def stats: List[Tree]
- final def body: c.Tree = stats match {
+ final def allStats: List[Tree] = this match {
+ case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef
+ case _ => stats
+ }
+
+ final def body: Tree = stats match {
case stat :: Nil => stat
case init :+ last => Block(init, last)
}
}
/** A sequence of statements the concludes with a unconditional transition to `nextState` */
- final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int)
+ final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup)
extends AsyncState {
def mkHandlerCaseForState: CaseDef =
- mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply)
+ mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup))
override val toString: String =
s"AsyncState #$state, next = $nextState"
@@ -51,7 +56,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/** A sequence of statements with a conditional transition to the next state, which will represent
* a branch of an `if` or a `match`.
*/
- final class AsyncStateWithoutAwait(val stats: List[c.Tree], val state: Int) extends AsyncState {
+ final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState {
override def mkHandlerCaseForState: CaseDef =
mkHandlerCase(state, stats)
@@ -62,25 +67,25 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/** A sequence of statements that concludes with an `await` call. The `onComplete`
* handler will unconditionally transition to `nestState`.``
*/
- final class AsyncStateWithAwait(val stats: List[c.Tree], val state: Int, nextState: Int,
- awaitable: Awaitable)
+ final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int,
+ val awaitable: Awaitable, symLookup: SymLookup)
extends AsyncState {
override def mkHandlerCaseForState: CaseDef = {
- val callOnComplete = futureSystemOps.onComplete(c.Expr(awaitable.expr),
- c.Expr(This(tpnme.EMPTY)), c.Expr(Ident(name.execContext))).tree
+ val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr),
+ Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree
mkHandlerCase(state, stats :+ callOnComplete)
}
- override def mkOnCompleteHandler[T: c.WeakTypeTag]: Option[CaseDef] = {
+ override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = {
val tryGetTree =
Assign(
Ident(awaitable.resultName),
- TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
+ TypeApply(Select(Select(Ident(symLookup.applyTrParam), Try_get), newTermName("asInstanceOf")), List(TypeTree(awaitable.resultType)))
)
/* if (tr.isFailure)
- * result$async.complete(tr.asInstanceOf[Try[T]])
+ * result.complete(tr.asInstanceOf[Try[T]])
* else {
* <resultName> = tr.get.asInstanceOf[<resultType>]
* <nextState>
@@ -88,13 +93,13 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* }
*/
val ifIsFailureTree =
- If(Select(Ident(name.tr), Try_isFailure),
+ If(Select(Ident(symLookup.applyTrParam), Try_isFailure),
futureSystemOps.completeProm[T](
- c.Expr[futureSystem.Prom[T]](Ident(name.result)),
- c.Expr[scala.util.Try[T]](
- TypeApply(Select(Ident(name.tr), newTermName("asInstanceOf")),
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)),
+ Expr[scala.util.Try[T]](
+ TypeApply(Select(Ident(symLookup.applyTrParam), newTermName("asInstanceOf")),
List(TypeTree(weakTypeOf[scala.util.Try[T]]))))).tree,
- Block(List(tryGetTree, mkStateTree(nextState)), mkResumeApply)
+ Block(List(tryGetTree, mkStateTree(nextState, symLookup)), mkResumeApply(symLookup))
)
Some(mkHandlerCase(state, List(ifIsFailureTree)))
@@ -107,19 +112,16 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
/*
* Builder for a single state of an async method.
*/
- final class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) {
+ final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) {
/* Statements preceding an await call. */
- private val stats = ListBuffer[c.Tree]()
+ private val stats = ListBuffer[Tree]()
/** The state of the target of a LabelDef application (while loop jump) */
private var nextJumpState: Option[Int] = None
- private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap))
-
- def +=(stat: c.Tree): this.type = {
+ def +=(stat: Tree): this.type = {
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
- def addStat() = stats += renameReset(stat)
+ def addStat() = stats += stat
stat match {
- case _: DefDef => // these have been lifted.
case Apply(fun, Nil) =>
labelDefStates get fun.symbol match {
case Some(nextState) => nextJumpState = Some(nextState)
@@ -132,22 +134,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def resultWithAwait(awaitable: Awaitable,
nextState: Int): AsyncState = {
- val sanitizedAwaitable = awaitable.copy(expr = renameReset(awaitable.expr))
val effectiveNextState = nextJumpState.getOrElse(nextState)
- new AsyncStateWithAwait(stats.toList, state, effectiveNextState, sanitizedAwaitable)
+ new AsyncStateWithAwait(stats.toList, state, effectiveNextState, awaitable, symLookup)
}
def resultSimple(nextState: Int): AsyncState = {
val effectiveNextState = nextJumpState.getOrElse(nextState)
- new SimpleAsyncState(stats.toList, state, effectiveNextState)
+ new SimpleAsyncState(stats.toList, state, effectiveNextState, symLookup)
}
- def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = {
- // 1. build changed if-else tree
- // 2. insert that tree at the end of the current state
- val cond = renameReset(condTree)
- def mkBranch(state: Int) = Block(mkStateTree(state) :: Nil, mkResumeApply)
- this += If(cond, mkBranch(thenState), mkBranch(elseState))
+ def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
+ def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup))
+ this += If(condTree, mkBranch(thenState), mkBranch(elseState))
new AsyncStateWithoutAwait(stats.toList, state)
}
@@ -161,23 +159,20 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* @param caseStates starting state of the right-hand side of the each case
* @return an `AsyncState` representing the match expression
*/
- def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = {
+ def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: List[Int], symLookup: SymLookup): AsyncState = {
// 1. build list of changed cases
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
case CaseDef(pat, guard, rhs) =>
- val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map {
- case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs)
- case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t")
- }
- CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply))
+ val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal)
+ CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num), symLookup), mkResumeApply(symLookup)))
}
// 2. insert changed match tree at the end of the current state
- this += Match(renameReset(scrutTree), newCases)
+ this += Match(scrutTree, newCases)
new AsyncStateWithoutAwait(stats.toList, state)
}
- def resultWithLabel(startLabelState: Int): AsyncState = {
- this += Block(mkStateTree(startLabelState) :: Nil, mkResumeApply)
+ def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
+ this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup))
new AsyncStateWithoutAwait(stats.toList, state)
}
@@ -194,24 +189,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* @param expr the last expression of the block
* @param startState the start state
* @param endState the state to continue with
- * @param toRename a `Map` for renaming the given key symbols to the mangled value names
*/
- final private class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
- private val toRename: Map[Symbol, c.Name]) {
+ final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int,
+ private val symLookup: SymLookup) {
val asyncStates = ListBuffer[AsyncState]()
- var stateBuilder = new AsyncStateBuilder(startState, toRename)
+ var stateBuilder = new AsyncStateBuilder(startState, symLookup)
var currState = startState
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
- def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
+ def checkForUnsupportedAwait(tree: Tree) = if (tree exists {
case Apply(fun, _) if isAwait(fun) => true
case _ => false
- }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
+ }) abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException
def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = {
val (nestedStats, nestedExpr) = statsAndExpr(nestedTree)
- new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, toRename)
+ new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup)
}
import stateAssigner.nextState
@@ -219,16 +213,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
- case ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
+ case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) =>
val afterAwaitState = nextState()
- val awaitable = Awaitable(arg, toRename(stat.symbol).toTermName, tpt.tpe)
+ val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd)
asyncStates += stateBuilder.resultWithAwait(awaitable, afterAwaitState) // complete with await
currState = afterAwaitState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
-
- case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol =>
- checkForUnsupportedAwait(rhs)
- stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case If(cond, thenp, elsep) if stat exists isAwait =>
checkForUnsupportedAwait(cond)
@@ -248,7 +238,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
currState = afterIfState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case Match(scrutinee, cases) if stat exists isAwait =>
checkForUnsupportedAwait(scrutinee)
@@ -257,7 +247,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
val afterMatchState = nextState()
asyncStates +=
- stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
+ stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup)
for ((cas, num) <- cases.zipWithIndex) {
val (stats, expr) = statsAndExpr(cas.body)
@@ -267,18 +257,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
currState = afterMatchState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case ld@LabelDef(name, params, rhs) if rhs exists isAwait =>
val startLabelState = nextState()
val afterLabelState = nextState()
- asyncStates += stateBuilder.resultWithLabel(startLabelState)
+ asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup)
labelDefStates(ld.symbol) = startLabelState
val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState)
asyncStates ++= builder.asyncStates
currState = afterLabelState
- stateBuilder = new AsyncStateBuilder(currState, toRename)
+ stateBuilder = new AsyncStateBuilder(currState, symLookup)
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
@@ -292,17 +282,23 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
trait AsyncBlock {
def asyncStates: List[AsyncState]
- def onCompleteHandler[T: c.WeakTypeTag]: Tree
+ def onCompleteHandler[T: WeakTypeTag]: Tree
+
+ def resumeFunTree[T]: DefDef
+ }
- def resumeFunTree[T]: Tree
+ case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) {
+ def stateMachineMember(name: TermName): Symbol =
+ stateMachineClass.info.member(name)
+ def memberRef(name: TermName) = gen.mkAttributedRef(stateMachineMember(name))
}
- def build(block: Block, toRename: Map[Symbol, c.Name]): AsyncBlock = {
+ def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = {
val Block(stats, expr) = block
val startState = stateAssigner.nextState()
val endState = Int.MaxValue
- val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, toRename)
+ val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup)
new AsyncBlock {
def asyncStates = blockBuilder.asyncStates.toList
@@ -310,9 +306,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def mkCombinedHandlerCases[T]: List[CaseDef] = {
val caseForLastState: CaseDef = {
val lastState = asyncStates.last
- val lastStateBody = c.Expr[T](lastState.body)
+ val lastStateBody = Expr[T](lastState.body)
val rhs = futureSystemOps.completeProm(
- c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice)))
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Success(lastStateBody.splice)))
mkHandlerCase(lastState.state, rhs.tree)
}
asyncStates.toList match {
@@ -327,18 +323,6 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
val initStates = asyncStates.init
/**
- * // assumes tr: Try[Any] is in scope.
- * //
- * state match {
- * case 0 => {
- * x11 = tr.get.asInstanceOf[Double];
- * state = 1;
- * resume()
- * }
- */
- def onCompleteHandler[T: c.WeakTypeTag]: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
-
- /**
* def resume(): Unit = {
* try {
* state match {
@@ -353,18 +337,31 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
* }
* }
*/
- def resumeFunTree[T]: Tree =
+ def resumeFunTree[T]: DefDef =
DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass),
Try(
- Match(Ident(name.state), mkCombinedHandlerCases[T]),
+ Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T]),
List(
CaseDef(
- Apply(Ident(defn.NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))),
- EmptyTree,
+ Bind(name.t, Ident(nme.WILDCARD)),
+ Apply(Ident(defn.NonFatalClass), List(Ident(name.t))),
Block(List({
- val t = c.Expr[Throwable](Ident(name.tr))
- futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree
- }), c.literalUnit.tree))), EmptyTree))
+ val t = Expr[Throwable](Ident(name.t))
+ futureSystemOps.completeProm[T](
+ Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), reify(scala.util.Failure(t.splice))).tree
+ }), literalUnit))), EmptyTree))
+
+ /**
+ * // assumes tr: Try[Any] is in scope.
+ * //
+ * state match {
+ * case 0 => {
+ * x11 = tr.get.asInstanceOf[Double];
+ * state = 1;
+ * resume()
+ * }
+ */
+ def onCompleteHandler[T: WeakTypeTag]: Tree = Match(symLookup.memberRef(name.state), initStates.flatMap(_.mkOnCompleteHandler[T]).toList)
}
}
@@ -373,22 +370,18 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
case _ => false
}
- private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type)
-
- private val internalSyms = origTree.collect {
- case dt: DefTree => dt.symbol
- }
+ case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef)
- private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms)
+ private def mkResumeApply(symLookup: SymLookup) = Apply(symLookup.memberRef(name.resume), Nil)
- private def mkResumeApply = Apply(Ident(name.resume), Nil)
+ private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree =
+ Assign(symLookup.memberRef(name.state), Literal(Constant(nextState)))
- private def mkStateTree(nextState: Int): c.Tree =
- Assign(Ident(name.state), c.literal(nextState).tree)
+ private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef =
+ mkHandlerCase(num, Block(rhs, literalUnit))
- private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef =
- mkHandlerCase(num, Block(rhs, c.literalUnit.tree))
+ private def mkHandlerCase(num: Int, rhs: Tree): CaseDef =
+ CaseDef(Literal(Constant(num)), EmptyTree, rhs)
- private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef =
- CaseDef(c.literal(num).tree, EmptyTree, rhs)
+ private def literalUnit = Literal(Constant(()))
}