aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-22 00:28:21 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-22 00:28:21 +0100
commit7f8e9876ca83db4ed7792f09f218b341fbc6c2b2 (patch)
tree46aa3de1e8fc9e2924089417ac3359ef0e780517
parentb089630c223d510899ecf74f0cd57b0ae3ad3842 (diff)
downloadscala-async-7f8e9876ca83db4ed7792f09f218b341fbc6c2b2.tar.gz
scala-async-7f8e9876ca83db4ed7792f09f218b341fbc6c2b2.tar.bz2
scala-async-7f8e9876ca83db4ed7792f09f218b341fbc6c2b2.zip
Minimize lifting of vars.
-rw-r--r--src/main/scala/scala/async/Async.scala10
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala127
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala3
3 files changed, 68 insertions, 72 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index 94f42c0..bad693d 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -83,6 +83,10 @@ abstract class AsyncBase {
val traverser = new builder.LiftableVarTraverser
traverser.traverse(btree)
+ val renameMap = traverser.liftable.map {
+ vd =>
+ (vd.symbol, builder.name.fresh(vd.name))
+ }.toMap
AsyncUtils.vprintln(s"In file '${c.macroApplication.pos.source.path}':")
AsyncUtils.vprintln(s"${c.macroApplication}")
@@ -93,7 +97,7 @@ abstract class AsyncBase {
case tree => (Nil, tree)
}
- val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map())
+ val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, renameMap)
asyncBlockBuilder.asyncStates foreach (s => AsyncUtils.vprintln(s))
@@ -101,10 +105,6 @@ abstract class AsyncBase {
val initStates = asyncBlockBuilder.asyncStates.init
val localVarTrees = asyncBlockBuilder.asyncStates.flatMap(_.allVarDefs).toList
- val renameMap = traverser.liftable.map {
- vd =>
- (vd.symbol, c.fresh(vd.name))
- }.toMap
/*
lazy val onCompleteHandler = (tr: Try[Any]) => state match {
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 1ca9e8f..f65e481 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -135,7 +135,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
/*
* Builder for a single state of an async method.
*/
- class AsyncStateBuilder(state: Int, private var nameMap: Map[String, c.Name]) {
+ class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) {
self =>
/* Statements preceding an await call. */
@@ -156,8 +156,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
private val renamer = new Transformer {
override def transform(tree: Tree) = tree match {
- case Ident(_) if nameMap.keySet contains tree.symbol.toString =>
- Ident(nameMap(tree.symbol.toString))
+ case Ident(_) if nameMap.keySet contains tree.symbol =>
+ Ident(nameMap(tree.symbol))
case _ =>
super.transform(tree)
}
@@ -169,9 +169,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
}
//TODO do not ignore `mods`
- def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[String, c.Name]): this.type = {
+ def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree): this.type = {
varDefs += (name -> tpt.tpe)
- nameMap ++= extNameMap // update name map
this += Assign(Ident(name), rhs)
this
}
@@ -197,8 +196,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
* @param awaitResultType the type of the result of await
*/
def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree,
- extNameMap: Map[String, c.Name], nextState: Int = state + 1): this.type = {
- nameMap ++= extNameMap.map { case (k, v) => (k.toString, v) }
+ nextState: Int = state + 1): this.type = {
val renamed = renamer.transform(awaitArg)
awaitable = resetDuplicate(renamed)
resultName = awaitResultName
@@ -264,7 +262,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
* @param toRename a `Map` for renaming the given key symbols to the mangled value names
*/
class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int,
- budget: Int, private var toRename: Map[String, c.Name]) {
+ budget: Int, private val toRename: Map[Symbol, c.Name]) {
val asyncStates = ListBuffer[builder.AsyncState]()
private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
@@ -279,22 +277,19 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
case _ => false
}) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException
- def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[String, c.Name]): AsyncBlockBuilder = {
+ def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int): AsyncBlockBuilder = {
val (branchStats, branchExpr) = tree match {
case Block(s, e) => (s, e)
case _ => (List(tree), c.literalUnit.tree)
}
- new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap)
+ new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, toRename)
}
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await =>
- val newName = builder.name.fresh(name)
- toRename += (stat.symbol.toString -> newName)
-
- asyncStates += stateBuilder.complete(args.head, newName, tpt, toRename).result // complete with await
+ asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt).result // complete with await
if (remainingBudget > 0)
remainingBudget -= 1
else
@@ -302,13 +297,11 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
currState += 1
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
- case ValDef(mods, name, tpt, rhs) =>
+ case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol =>
checkForUnsupportedAwait(rhs)
- val newName = builder.name.fresh(name)
- toRename += (stat.symbol.toString -> newName)
// when adding assignment need to take `toRename` into account
- stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename)
+ stateBuilder.addVarDef(mods, toRename(stat.symbol).toTermName, tpt, rhs)
case If(cond, thenp, elsep) if stat exists isAwait =>
checkForUnsupportedAwait(cond)
@@ -326,9 +319,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach {
case (tree, state, branchBudget) =>
- val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename)
+ val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget)
asyncStates ++= builder.asyncStates
- toRename ++= builder.toRename
}
// create new state builder for state `currState + ifBudget`
@@ -354,7 +346,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
}
val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename)
asyncStates ++= builder.asyncStates
- toRename ++= builder.toRename
}
// create new state builder for state `currState + matchBudget`
@@ -433,50 +424,56 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
}
}
- override def traverse(tree: Tree) = tree match {
- case cd: ClassDef =>
- val kind = if (cd.symbol.asClass.isTrait) "trait" else "class"
- reportUnsupportedAwait(tree, s"nested ${kind}")
- case md: ModuleDef =>
- reportUnsupportedAwait(tree, "nested object")
- case _: Function =>
- reportUnsupportedAwait(tree, "nested anonymous function")
- case If(cond, thenp, elsep) if tree exists isAwait =>
- traverse(cond)
- blockId += 1
- traverse(thenp)
- blockId += 1
- traverse(elsep)
- blockId += 1
- case Match(selector, cases) if tree exists isAwait =>
- traverse(selector)
- blockId += 1
- cases foreach {c => traverse(c); blockId += 1}
- case Apply(fun, args) if isAwait(fun) =>
- traverseTrees(args)
- traverse(fun)
- blockId += 1
- case Apply(fun, args) =>
- val isInByName = isByName(fun)
- for ((arg, index) <- args.zipWithIndex) {
- if (!isInByName(index)) traverse(arg)
- else reportUnsupportedAwait(arg, "by-name argument")
- }
- traverse(fun)
- case vd: ValDef =>
- super.traverse(tree)
- valDefBlockId += (vd.symbol -> (vd, blockId))
- if (vd.rhs.symbol == Async_await) liftable += vd
- case as: Assign =>
- if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1
- case rt: RefTree =>
- valDefBlockId.get(rt.symbol) match {
- case Some((vd, defBlockId)) if defBlockId != blockId =>
- liftable += vd
- case _ =>
- }
- super.traverse(tree)
- case _ => super.traverse(tree)
+ override def traverse(tree: Tree) = {
+ tree match {
+ case cd: ClassDef =>
+ val kind = if (cd.symbol.asClass.isTrait) "trait" else "class"
+ reportUnsupportedAwait(tree, s"nested ${kind}")
+ case md: ModuleDef =>
+ reportUnsupportedAwait(tree, "nested object")
+ case _: Function =>
+ reportUnsupportedAwait(tree, "nested anonymous function")
+ case If(cond, thenp, elsep) if tree exists isAwait =>
+ traverse(cond)
+ blockId += 1
+ traverse(thenp)
+ blockId += 1
+ traverse(elsep)
+ blockId += 1
+ case Match(selector, cases) if tree exists isAwait =>
+ traverse(selector)
+ blockId += 1
+ cases foreach {
+ c => traverse(c); blockId += 1
+ }
+ case Apply(fun, args) if isAwait(fun) =>
+ traverseTrees(args)
+ traverse(fun)
+ blockId += 1
+ case Apply(fun, args) =>
+ val isInByName = isByName(fun)
+ for ((arg, index) <- args.zipWithIndex) {
+ if (!isInByName(index)) traverse(arg)
+ else reportUnsupportedAwait(arg, "by-name argument")
+ }
+ traverse(fun)
+ case vd: ValDef =>
+ super.traverse(tree)
+ valDefBlockId += (vd.symbol ->(vd, blockId))
+ if (vd.rhs.symbol == Async_await) liftable += vd
+ case as: Assign =>
+ if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1
+
+ super.traverse(tree)
+ case rt: RefTree =>
+ valDefBlockId.get(rt.symbol) match {
+ case Some((vd, defBlockId)) if defBlockId != blockId =>
+ liftable += vd
+ case _ =>
+ }
+ super.traverse(tree)
+ case _ => super.traverse(tree)
+ }
}
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index 1293bdf..1ed9be2 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -32,7 +32,6 @@ class TreeInterrogation {
val varDefs = tree1.collect {
case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name
}
- // TODO no need to lift `y` as it is only accessed from a single state.
- varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1", "y$1"))
+ varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1"))
}
}