aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-22 13:33:09 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-22 13:33:09 +0100
commit8e4a8ecdff955c4faa1dec344a2b93543ffe7d45 (patch)
tree8733f9b854baa83194b1688fa30ed5fc90fd249c
parenta30ba69777a83d77b3924081f8b70d76c4a3ed59 (diff)
downloadscala-async-8e4a8ecdff955c4faa1dec344a2b93543ffe7d45.tar.gz
scala-async-8e4a8ecdff955c4faa1dec344a2b93543ffe7d45.tar.bz2
scala-async-8e4a8ecdff955c4faa1dec344a2b93543ffe7d45.zip
Cleanups and docs.
- Move now-working duplicate definition tests from `neg` to `run`. - Renames and small code beautification around the var lifting analysis
-rw-r--r--src/main/scala/scala/async/Async.scala62
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala78
-rw-r--r--src/test/scala/scala/async/neg/AnfTransformNegSpec.scala59
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala32
4 files changed, 108 insertions, 123 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index d088b45..bd766f2 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -71,45 +71,37 @@ abstract class AsyncBase {
import builder.name
import builder.futureSystemOps
- val btree: Tree = {
+ // Transform to A-normal form:
+ // - no await calls in qualifiers or arguments,
+ // - if/match only used in statement position.
+ val anfTree: Block = {
val transform = new AnfTransform[c.type](c)
val stats1 :+ expr1 = transform.anf.transformToList(body.tree)
- c.typeCheck(Block(stats1, expr1))
+ c.typeCheck(Block(stats1, expr1)).asInstanceOf[Block]
}
- val traverser = new builder.LiftableVarTraverser
- traverser.traverse(btree)
- val renameMap = traverser.liftable.map {
- vd =>
- (vd.symbol, builder.name.fresh(vd.name))
- }.toMap
-
- def location = try {
- c.macroApplication.pos.source.path
- } catch {
- case _: UnsupportedOperationException =>
- c.macroApplication.pos.toString
+ // Analyze the block to find locals that will be accessed from multiple
+ // states of our generated state machine, e.g. a value assigned before
+ // an `await` and read afterwards.
+ val renameMap: Map[Symbol, TermName] = {
+ val analyzer = new builder.AsyncAnalyzer
+ analyzer.traverse(anfTree)
+ analyzer.valDefsToLift.map {
+ vd =>
+ (vd.symbol, builder.name.fresh(vd.name))
+ }.toMap
}
- AsyncUtils.vprintln(s"In file '$location':")
- AsyncUtils.vprintln(s"${c.macroApplication}")
- AsyncUtils.vprintln(s"ANF transform expands to:\n $btree")
-
- val (stats, expr) = btree match {
- case Block(stats, expr) => (stats, expr)
- case tree => (Nil, tree)
- }
val startState = builder.stateAssigner.nextState()
val endState = Int.MaxValue
- val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, startState, endState, renameMap)
-
- asyncBlockBuilder.asyncStates foreach (s => AsyncUtils.vprintln(s))
-
+ val asyncBlockBuilder = new builder.AsyncBlockBuilder(anfTree.stats, anfTree.expr, startState, endState, renameMap)
val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]()
- val initStates = asyncBlockBuilder.asyncStates.init
- val localVarTrees = asyncBlockBuilder.asyncStates.flatMap(_.allVarDefs).toList
+ import asyncBlockBuilder.asyncStates
+ logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
+ val initStates = asyncStates.init
+ val localVarTrees = asyncStates.flatMap(_.allVarDefs).toList
/*
lazy val onCompleteHandler = (tr: Try[Any]) => state match {
@@ -186,4 +178,18 @@ abstract class AsyncBase {
result
}
+
+ def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) {
+ def location = try {
+ c.macroApplication.pos.source.path
+ } catch {
+ case _: UnsupportedOperationException =>
+ c.macroApplication.pos.toString
+ }
+
+ AsyncUtils.vprintln(s"In file '$location':")
+ AsyncUtils.vprintln(s"${c.macroApplication}")
+ AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
+ states foreach (s => AsyncUtils.vprintln(s))
+ }
}
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 255349f..7a9c98d 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -5,6 +5,7 @@ package scala.async
import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
+import collection.mutable
/*
* @author Philipp Haller
@@ -266,7 +267,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
- case Apply(fun, _) if fun.symbol == Async_await => true
+ case Apply(fun, _) if isAwait(fun) => true
case _ => false
}) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException
@@ -281,7 +282,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
- case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await =>
+ case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) =>
val afterAwaitState = stateAssigner.nextState()
asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await
currState = afterAwaitState
@@ -390,21 +391,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
fun.symbol == defn.Async_await
}
- private[async] class LiftableVarTraverser extends Traverser {
- var blockId = 0
- var valDefBlockId = Map[Symbol, (ValDef, Int)]()
- val liftable = collection.mutable.Set[ValDef]()
-
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
+ * on whether or not they are accessed only from a single state.
+ */
+ private[async] class AsyncAnalyzer extends Traverser {
+ private var chunkId = 0
+ private def nextChunk() = chunkId += 1
+ private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
- def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
- val badAwaits = tree collect {
- case rt: RefTree if rt.symbol == Async_await => rt
- }
- badAwaits foreach {
- tree =>
- c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
- }
- }
+ val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
override def traverse(tree: Tree) = {
tree match {
@@ -416,22 +414,13 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
case _: Function =>
reportUnsupportedAwait(tree, "nested anonymous function")
case If(cond, thenp, elsep) if tree exists isAwait =>
- traverse(cond)
- blockId += 1
- traverse(thenp)
- blockId += 1
- traverse(elsep)
- blockId += 1
+ traverseChunks(List(cond, thenp, elsep))
case Match(selector, cases) if tree exists isAwait =>
- traverse(selector)
- blockId += 1
- cases foreach {
- c => traverse(c); blockId += 1
- }
+ traverseChunks(selector :: cases)
case Apply(fun, args) if isAwait(fun) =>
traverseTrees(args)
traverse(fun)
- blockId += 1
+ nextChunk()
case Apply(fun, args) =>
val isInByName = isByName(fun)
for ((arg, index) <- args.zipWithIndex) {
@@ -441,28 +430,45 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
traverse(fun)
case vd: ValDef =>
super.traverse(tree)
- valDefBlockId += (vd.symbol ->(vd, blockId))
- if (vd.rhs.symbol == Async_await) liftable += vd
+ valDefChunkId += (vd.symbol ->(vd, chunkId))
+ if (isAwait(vd.rhs)) valDefsToLift += vd
case as: Assign =>
- if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1
-
+ if (isAwait(as.rhs)) {
+ // TODO test the orElse case, try to remove the restriction.
+ val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))
+ valDefsToLift += vd
+ }
super.traverse(tree)
case rt: RefTree =>
- valDefBlockId.get(rt.symbol) match {
- case Some((vd, defBlockId)) if defBlockId != blockId =>
- liftable += vd
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) if defChunkId != chunkId =>
+ valDefsToLift += vd
case _ =>
}
super.traverse(tree)
case _ => super.traverse(tree)
}
}
+
+ private def traverseChunks(trees: List[Tree]) {
+ trees.foreach {t => traverse(t); nextChunk()}
+ }
+
+ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
+ val badAwaits = tree collect {
+ case rt: RefTree if isAwait(rt) => rt
+ }
+ badAwaits foreach {
+ tree =>
+ c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ }
}
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
private def methodSym(apply: c.Expr[Any]): Symbol = {
- val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed?
+ val tree2: Tree = c.typeCheck(apply.tree)
tree2.collect {
case s: SymTree if s.symbol.isMethod => s.symbol
}.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
diff --git a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala b/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala
deleted file mode 100644
index 0678429..0000000
--- a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/**
- * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
- */
-package scala.async
-package neg
-
-import org.junit.runner.RunWith
-import org.junit.runners.JUnit4
-import org.junit.{Ignore, Test}
-
-@RunWith(classOf[JUnit4])
-class AnfTransformNegSpec {
-
- @Test
- @Ignore
- def `inlining block produces duplicate definition`() {
- expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") {
- """
- | import scala.concurrent.ExecutionContext.Implicits.global
- | import scala.concurrent.Future
- | import scala.async.Async._
- |
- | async {
- | val f = Future { 12 }
- | val x = await(f)
- |
- | {
- | val x = 42
- | println(x)
- | }
- |
- | x
- | }
- """.stripMargin
- }
- }
-
- @Test
- @Ignore
- def `inlining block in tail position produces duplicate definition`() {
- expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") {
- """
- | import scala.concurrent.ExecutionContext.Implicits.global
- | import scala.concurrent.Future
- | import scala.async.Async._
- |
- | async {
- | val f = Future { 12 }
- | val x = await(f)
- |
- | {
- | val x = 42
- | x
- | }
- | }
- """.stripMargin
- }
- }
-}
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index f2fc2d7..0abb937 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -111,4 +111,36 @@ class AnfTransformSpec {
Await.result(fut, 2 seconds)
State.result mustBe (14)
}
+
+ @Test
+ def `inlining block produces duplicate definition`() {
+ import scala.async.AsyncId
+
+ AsyncId.async {
+ val f = 12
+ val x = AsyncId.await(f)
+
+ {
+ val x = 42
+ println(x)
+ }
+
+ x
+ }
+ }
+ @Test
+ def `inlining block in tail position produces duplicate definition`() {
+ import scala.async.AsyncId
+
+ AsyncId.async {
+ val f = 12
+ val x = AsyncId.await(f)
+
+ {
+ val x = 42 // TODO should we rename the symbols when we collapse them into the same scope?
+ x
+ }
+ } mustBe (42)
+
+ }
}