From 7f8e9876ca83db4ed7792f09f218b341fbc6c2b2 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 00:28:21 +0100 Subject: Minimize lifting of vars. --- src/main/scala/scala/async/Async.scala | 10 +- src/main/scala/scala/async/ExprBuilder.scala | 127 ++++++++++----------- src/test/scala/scala/async/TreeInterrogation.scala | 3 +- 3 files changed, 68 insertions(+), 72 deletions(-) diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 94f42c0..bad693d 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -83,6 +83,10 @@ abstract class AsyncBase { val traverser = new builder.LiftableVarTraverser traverser.traverse(btree) + val renameMap = traverser.liftable.map { + vd => + (vd.symbol, builder.name.fresh(vd.name)) + }.toMap AsyncUtils.vprintln(s"In file '${c.macroApplication.pos.source.path}':") AsyncUtils.vprintln(s"${c.macroApplication}") @@ -93,7 +97,7 @@ abstract class AsyncBase { case tree => (Nil, tree) } - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map()) + val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, renameMap) asyncBlockBuilder.asyncStates foreach (s => AsyncUtils.vprintln(s)) @@ -101,10 +105,6 @@ abstract class AsyncBase { val initStates = asyncBlockBuilder.asyncStates.init val localVarTrees = asyncBlockBuilder.asyncStates.flatMap(_.allVarDefs).toList - val renameMap = traverser.liftable.map { - vd => - (vd.symbol, c.fresh(vd.name)) - }.toMap /* lazy val onCompleteHandler = (tr: Try[Any]) => state match { diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 1ca9e8f..f65e481 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -135,7 +135,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val /* * Builder for a single state of an async method. */ - class AsyncStateBuilder(state: Int, private var nameMap: Map[String, c.Name]) { + class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { self => /* Statements preceding an await call. */ @@ -156,8 +156,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val private val renamer = new Transformer { override def transform(tree: Tree) = tree match { - case Ident(_) if nameMap.keySet contains tree.symbol.toString => - Ident(nameMap(tree.symbol.toString)) + case Ident(_) if nameMap.keySet contains tree.symbol => + Ident(nameMap(tree.symbol)) case _ => super.transform(tree) } @@ -169,9 +169,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } //TODO do not ignore `mods` - def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[String, c.Name]): this.type = { + def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree): this.type = { varDefs += (name -> tpt.tpe) - nameMap ++= extNameMap // update name map this += Assign(Ident(name), rhs) this } @@ -197,8 +196,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val * @param awaitResultType the type of the result of await */ def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, - extNameMap: Map[String, c.Name], nextState: Int = state + 1): this.type = { - nameMap ++= extNameMap.map { case (k, v) => (k.toString, v) } + nextState: Int = state + 1): this.type = { val renamed = renamer.transform(awaitArg) awaitable = resetDuplicate(renamed) resultName = awaitResultName @@ -264,7 +262,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val * @param toRename a `Map` for renaming the given key symbols to the mangled value names */ class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, - budget: Int, private var toRename: Map[String, c.Name]) { + budget: Int, private val toRename: Map[Symbol, c.Name]) { val asyncStates = ListBuffer[builder.AsyncState]() private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) @@ -279,22 +277,19 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val case _ => false }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException - def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[String, c.Name]): AsyncBlockBuilder = { + def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int): AsyncBlockBuilder = { val (branchStats, branchExpr) = tree match { case Block(s, e) => (s, e) case _ => (List(tree), c.literalUnit.tree) } - new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap) + new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, toRename) } // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await => - val newName = builder.name.fresh(name) - toRename += (stat.symbol.toString -> newName) - - asyncStates += stateBuilder.complete(args.head, newName, tpt, toRename).result // complete with await + asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt).result // complete with await if (remainingBudget > 0) remainingBudget -= 1 else @@ -302,13 +297,11 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val currState += 1 stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - case ValDef(mods, name, tpt, rhs) => + case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => checkForUnsupportedAwait(rhs) - val newName = builder.name.fresh(name) - toRename += (stat.symbol.toString -> newName) // when adding assignment need to take `toRename` into account - stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) + stateBuilder.addVarDef(mods, toRename(stat.symbol).toTermName, tpt, rhs) case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) @@ -326,9 +319,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) => - val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename) + val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget) asyncStates ++= builder.asyncStates - toRename ++= builder.toRename } // create new state builder for state `currState + ifBudget` @@ -354,7 +346,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename) asyncStates ++= builder.asyncStates - toRename ++= builder.toRename } // create new state builder for state `currState + matchBudget` @@ -433,50 +424,56 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } } - override def traverse(tree: Tree) = tree match { - case cd: ClassDef => - val kind = if (cd.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(tree, s"nested ${kind}") - case md: ModuleDef => - reportUnsupportedAwait(tree, "nested object") - case _: Function => - reportUnsupportedAwait(tree, "nested anonymous function") - case If(cond, thenp, elsep) if tree exists isAwait => - traverse(cond) - blockId += 1 - traverse(thenp) - blockId += 1 - traverse(elsep) - blockId += 1 - case Match(selector, cases) if tree exists isAwait => - traverse(selector) - blockId += 1 - cases foreach {c => traverse(c); blockId += 1} - case Apply(fun, args) if isAwait(fun) => - traverseTrees(args) - traverse(fun) - blockId += 1 - case Apply(fun, args) => - val isInByName = isByName(fun) - for ((arg, index) <- args.zipWithIndex) { - if (!isInByName(index)) traverse(arg) - else reportUnsupportedAwait(arg, "by-name argument") - } - traverse(fun) - case vd: ValDef => - super.traverse(tree) - valDefBlockId += (vd.symbol -> (vd, blockId)) - if (vd.rhs.symbol == Async_await) liftable += vd - case as: Assign => - if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1 - case rt: RefTree => - valDefBlockId.get(rt.symbol) match { - case Some((vd, defBlockId)) if defBlockId != blockId => - liftable += vd - case _ => - } - super.traverse(tree) - case _ => super.traverse(tree) + override def traverse(tree: Tree) = { + tree match { + case cd: ClassDef => + val kind = if (cd.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(tree, s"nested ${kind}") + case md: ModuleDef => + reportUnsupportedAwait(tree, "nested object") + case _: Function => + reportUnsupportedAwait(tree, "nested anonymous function") + case If(cond, thenp, elsep) if tree exists isAwait => + traverse(cond) + blockId += 1 + traverse(thenp) + blockId += 1 + traverse(elsep) + blockId += 1 + case Match(selector, cases) if tree exists isAwait => + traverse(selector) + blockId += 1 + cases foreach { + c => traverse(c); blockId += 1 + } + case Apply(fun, args) if isAwait(fun) => + traverseTrees(args) + traverse(fun) + blockId += 1 + case Apply(fun, args) => + val isInByName = isByName(fun) + for ((arg, index) <- args.zipWithIndex) { + if (!isInByName(index)) traverse(arg) + else reportUnsupportedAwait(arg, "by-name argument") + } + traverse(fun) + case vd: ValDef => + super.traverse(tree) + valDefBlockId += (vd.symbol ->(vd, blockId)) + if (vd.rhs.symbol == Async_await) liftable += vd + case as: Assign => + if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1 + + super.traverse(tree) + case rt: RefTree => + valDefBlockId.get(rt.symbol) match { + case Some((vd, defBlockId)) if defBlockId != blockId => + liftable += vd + case _ => + } + super.traverse(tree) + case _ => super.traverse(tree) + } } } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 1293bdf..1ed9be2 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -32,7 +32,6 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - // TODO no need to lift `y` as it is only accessed from a single state. - varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1", "y$1")) + varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1")) } } -- cgit v1.2.3