From 8e4a8ecdff955c4faa1dec344a2b93543ffe7d45 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 13:33:09 +0100 Subject: Cleanups and docs. - Move now-working duplicate definition tests from `neg` to `run`. - Renames and small code beautification around the var lifting analysis --- src/main/scala/scala/async/Async.scala | 62 +++++++++-------- src/main/scala/scala/async/ExprBuilder.scala | 78 ++++++++++++---------- .../scala/async/neg/AnfTransformNegSpec.scala | 59 ---------------- .../scala/async/run/anf/AnfTransformSpec.scala | 32 +++++++++ 4 files changed, 108 insertions(+), 123 deletions(-) delete mode 100644 src/test/scala/scala/async/neg/AnfTransformNegSpec.scala diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index d088b45..bd766f2 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -71,45 +71,37 @@ abstract class AsyncBase { import builder.name import builder.futureSystemOps - val btree: Tree = { + // Transform to A-normal form: + // - no await calls in qualifiers or arguments, + // - if/match only used in statement position. + val anfTree: Block = { val transform = new AnfTransform[c.type](c) val stats1 :+ expr1 = transform.anf.transformToList(body.tree) - c.typeCheck(Block(stats1, expr1)) + c.typeCheck(Block(stats1, expr1)).asInstanceOf[Block] } - val traverser = new builder.LiftableVarTraverser - traverser.traverse(btree) - val renameMap = traverser.liftable.map { - vd => - (vd.symbol, builder.name.fresh(vd.name)) - }.toMap - - def location = try { - c.macroApplication.pos.source.path - } catch { - case _: UnsupportedOperationException => - c.macroApplication.pos.toString + // Analyze the block to find locals that will be accessed from multiple + // states of our generated state machine, e.g. a value assigned before + // an `await` and read afterwards. + val renameMap: Map[Symbol, TermName] = { + val analyzer = new builder.AsyncAnalyzer + analyzer.traverse(anfTree) + analyzer.valDefsToLift.map { + vd => + (vd.symbol, builder.name.fresh(vd.name)) + }.toMap } - AsyncUtils.vprintln(s"In file '$location':") - AsyncUtils.vprintln(s"${c.macroApplication}") - AsyncUtils.vprintln(s"ANF transform expands to:\n $btree") - - val (stats, expr) = btree match { - case Block(stats, expr) => (stats, expr) - case tree => (Nil, tree) - } val startState = builder.stateAssigner.nextState() val endState = Int.MaxValue - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, startState, endState, renameMap) - - asyncBlockBuilder.asyncStates foreach (s => AsyncUtils.vprintln(s)) - + val asyncBlockBuilder = new builder.AsyncBlockBuilder(anfTree.stats, anfTree.expr, startState, endState, renameMap) val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]() - val initStates = asyncBlockBuilder.asyncStates.init - val localVarTrees = asyncBlockBuilder.asyncStates.flatMap(_.allVarDefs).toList + import asyncBlockBuilder.asyncStates + logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) + val initStates = asyncStates.init + val localVarTrees = asyncStates.flatMap(_.allVarDefs).toList /* lazy val onCompleteHandler = (tr: Try[Any]) => state match { @@ -186,4 +178,18 @@ abstract class AsyncBase { result } + + def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { + def location = try { + c.macroApplication.pos.source.path + } catch { + case _: UnsupportedOperationException => + c.macroApplication.pos.toString + } + + AsyncUtils.vprintln(s"In file '$location':") + AsyncUtils.vprintln(s"${c.macroApplication}") + AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") + states foreach (s => AsyncUtils.vprintln(s)) + } } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 255349f..7a9c98d 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -5,6 +5,7 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer +import collection.mutable /* * @author Philipp Haller @@ -266,7 +267,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { - case Apply(fun, _) if fun.symbol == Async_await => true + case Apply(fun, _) if isAwait(fun) => true case _ => false }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException @@ -281,7 +282,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await => + case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) => val afterAwaitState = stateAssigner.nextState() asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await currState = afterAwaitState @@ -390,21 +391,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val fun.symbol == defn.Async_await } - private[async] class LiftableVarTraverser extends Traverser { - var blockId = 0 - var valDefBlockId = Map[Symbol, (ValDef, Int)]() - val liftable = collection.mutable.Set[ValDef]() - + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + */ + private[async] class AsyncAnalyzer extends Traverser { + private var chunkId = 0 + private def nextChunk() = chunkId += 1 + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { - case rt: RefTree if rt.symbol == Async_await => rt - } - badAwaits foreach { - tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") - } - } + val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() override def traverse(tree: Tree) = { tree match { @@ -416,22 +414,13 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val case _: Function => reportUnsupportedAwait(tree, "nested anonymous function") case If(cond, thenp, elsep) if tree exists isAwait => - traverse(cond) - blockId += 1 - traverse(thenp) - blockId += 1 - traverse(elsep) - blockId += 1 + traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => - traverse(selector) - blockId += 1 - cases foreach { - c => traverse(c); blockId += 1 - } + traverseChunks(selector :: cases) case Apply(fun, args) if isAwait(fun) => traverseTrees(args) traverse(fun) - blockId += 1 + nextChunk() case Apply(fun, args) => val isInByName = isByName(fun) for ((arg, index) <- args.zipWithIndex) { @@ -441,28 +430,45 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val traverse(fun) case vd: ValDef => super.traverse(tree) - valDefBlockId += (vd.symbol ->(vd, blockId)) - if (vd.rhs.symbol == Async_await) liftable += vd + valDefChunkId += (vd.symbol ->(vd, chunkId)) + if (isAwait(vd.rhs)) valDefsToLift += vd case as: Assign => - if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1 - + if (isAwait(as.rhs)) { + // TODO test the orElse case, try to remove the restriction. + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block.")) + valDefsToLift += vd + } super.traverse(tree) case rt: RefTree => - valDefBlockId.get(rt.symbol) match { - case Some((vd, defBlockId)) if defBlockId != blockId => - liftable += vd + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) if defChunkId != chunkId => + valDefsToLift += vd case _ => } super.traverse(tree) case _ => super.traverse(tree) } } + + private def traverseChunks(trees: List[Tree]) { + trees.foreach {t => traverse(t); nextChunk()} + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } } /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? + val tree2: Tree = c.typeCheck(apply.tree) tree2.collect { case s: SymTree if s.symbol.isMethod => s.symbol }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) diff --git a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala b/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala deleted file mode 100644 index 0678429..0000000 --- a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright (C) 2012 Typesafe Inc. - */ -package scala.async -package neg - -import org.junit.runner.RunWith -import org.junit.runners.JUnit4 -import org.junit.{Ignore, Test} - -@RunWith(classOf[JUnit4]) -class AnfTransformNegSpec { - - @Test - @Ignore - def `inlining block produces duplicate definition`() { - expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.concurrent.Future - | import scala.async.Async._ - | - | async { - | val f = Future { 12 } - | val x = await(f) - | - | { - | val x = 42 - | println(x) - | } - | - | x - | } - """.stripMargin - } - } - - @Test - @Ignore - def `inlining block in tail position produces duplicate definition`() { - expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") { - """ - | import scala.concurrent.ExecutionContext.Implicits.global - | import scala.concurrent.Future - | import scala.async.Async._ - | - | async { - | val f = Future { 12 } - | val x = await(f) - | - | { - | val x = 42 - | x - | } - | } - """.stripMargin - } - } -} diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index f2fc2d7..0abb937 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -111,4 +111,36 @@ class AnfTransformSpec { Await.result(fut, 2 seconds) State.result mustBe (14) } + + @Test + def `inlining block produces duplicate definition`() { + import scala.async.AsyncId + + AsyncId.async { + val f = 12 + val x = AsyncId.await(f) + + { + val x = 42 + println(x) + } + + x + } + } + @Test + def `inlining block in tail position produces duplicate definition`() { + import scala.async.AsyncId + + AsyncId.async { + val f = 12 + val x = AsyncId.await(f) + + { + val x = 42 // TODO should we rename the symbols when we collapse them into the same scope? + x + } + } mustBe (42) + + } } -- cgit v1.2.3