diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-11-22 13:33:09 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-11-22 13:33:09 +0100 |
commit | 8e4a8ecdff955c4faa1dec344a2b93543ffe7d45 (patch) | |
tree | 8733f9b854baa83194b1688fa30ed5fc90fd249c /src/main | |
parent | a30ba69777a83d77b3924081f8b70d76c4a3ed59 (diff) | |
download | scala-async-8e4a8ecdff955c4faa1dec344a2b93543ffe7d45.tar.gz scala-async-8e4a8ecdff955c4faa1dec344a2b93543ffe7d45.tar.bz2 scala-async-8e4a8ecdff955c4faa1dec344a2b93543ffe7d45.zip |
Cleanups and docs.
- Move now-working duplicate definition tests from `neg` to `run`.
- Renames and small code beautification around the var lifting analysis
Diffstat (limited to 'src/main')
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 62 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 78 |
2 files changed, 76 insertions, 64 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index d088b45..bd766f2 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -71,45 +71,37 @@ abstract class AsyncBase { import builder.name import builder.futureSystemOps - val btree: Tree = { + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree: Block = { val transform = new AnfTransform[c.type](c) val stats1 :+ expr1 = transform.anf.transformToList(body.tree) - c.typeCheck(Block(stats1, expr1)) + c.typeCheck(Block(stats1, expr1)).asInstanceOf[Block] } - val traverser = new builder.LiftableVarTraverser - traverser.traverse(btree) - val renameMap = traverser.liftable.map { - vd => - (vd.symbol, builder.name.fresh(vd.name)) - }.toMap - - def location = try { - c.macroApplication.pos.source.path - } catch { - case _: UnsupportedOperationException => - c.macroApplication.pos.toString + // Analyze the block to find locals that will be accessed from multiple + // states of our generated state machine, e.g. a value assigned before + // an `await` and read afterwards. + val renameMap: Map[Symbol, TermName] = { + val analyzer = new builder.AsyncAnalyzer + analyzer.traverse(anfTree) + analyzer.valDefsToLift.map { + vd => + (vd.symbol, builder.name.fresh(vd.name)) + }.toMap } - AsyncUtils.vprintln(s"In file '$location':") - AsyncUtils.vprintln(s"${c.macroApplication}") - AsyncUtils.vprintln(s"ANF transform expands to:\n $btree") - - val (stats, expr) = btree match { - case Block(stats, expr) => (stats, expr) - case tree => (Nil, tree) - } val startState = builder.stateAssigner.nextState() val endState = Int.MaxValue - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, startState, endState, renameMap) - - asyncBlockBuilder.asyncStates foreach (s => AsyncUtils.vprintln(s)) - + val asyncBlockBuilder = new builder.AsyncBlockBuilder(anfTree.stats, anfTree.expr, startState, endState, renameMap) val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]() - val initStates = asyncBlockBuilder.asyncStates.init - val localVarTrees = asyncBlockBuilder.asyncStates.flatMap(_.allVarDefs).toList + import asyncBlockBuilder.asyncStates + logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) + val initStates = asyncStates.init + val localVarTrees = asyncStates.flatMap(_.allVarDefs).toList /* lazy val onCompleteHandler = (tr: Try[Any]) => state match { @@ -186,4 +178,18 @@ abstract class AsyncBase { result } + + def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { + def location = try { + c.macroApplication.pos.source.path + } catch { + case _: UnsupportedOperationException => + c.macroApplication.pos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${c.macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 255349f..7a9c98d 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -5,6 +5,7 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer +import collection.mutable /* * @author Philipp Haller @@ -266,7 +267,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { - case Apply(fun, _) if fun.symbol == Async_await => true + case Apply(fun, _) if isAwait(fun) => true case _ => false }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException @@ -281,7 +282,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await => + case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) => val afterAwaitState = stateAssigner.nextState() asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await currState = afterAwaitState @@ -390,21 +391,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val fun.symbol == defn.Async_await } - private[async] class LiftableVarTraverser extends Traverser { - var blockId = 0 - var valDefBlockId = Map[Symbol, (ValDef, Int)]() - val liftable = collection.mutable.Set[ValDef]() - + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + */ + private[async] class AsyncAnalyzer extends Traverser { + private var chunkId = 0 + private def nextChunk() = chunkId += 1 + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { - case rt: RefTree if rt.symbol == Async_await => rt - } - badAwaits foreach { - tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") - } - } + val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() override def traverse(tree: Tree) = { tree match { @@ -416,22 +414,13 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val case _: Function => reportUnsupportedAwait(tree, "nested anonymous function") case If(cond, thenp, elsep) if tree exists isAwait => - traverse(cond) - blockId += 1 - traverse(thenp) - blockId += 1 - traverse(elsep) - blockId += 1 + traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => - traverse(selector) - blockId += 1 - cases foreach { - c => traverse(c); blockId += 1 - } + traverseChunks(selector :: cases) case Apply(fun, args) if isAwait(fun) => traverseTrees(args) traverse(fun) - blockId += 1 + nextChunk() case Apply(fun, args) => val isInByName = isByName(fun) for ((arg, index) <- args.zipWithIndex) { @@ -441,28 +430,45 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val traverse(fun) case vd: ValDef => super.traverse(tree) - valDefBlockId += (vd.symbol ->(vd, blockId)) - if (vd.rhs.symbol == Async_await) liftable += vd + valDefChunkId += (vd.symbol ->(vd, chunkId)) + if (isAwait(vd.rhs)) valDefsToLift += vd case as: Assign => - if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1 - + if (isAwait(as.rhs)) { + // TODO test the orElse case, try to remove the restriction. + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block.")) + valDefsToLift += vd + } super.traverse(tree) case rt: RefTree => - valDefBlockId.get(rt.symbol) match { - case Some((vd, defBlockId)) if defBlockId != blockId => - liftable += vd + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) if defChunkId != chunkId => + valDefsToLift += vd case _ => } super.traverse(tree) case _ => super.traverse(tree) } } + + private def traverseChunks(trees: List[Tree]) { + trees.foreach {t => traverse(t); nextChunk()} + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } } /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? + val tree2: Tree = c.typeCheck(apply.tree) tree2.collect { case s: SymTree if s.symbol.isMethod => s.symbol }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) |