aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/ExprBuilder.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala78
1 files changed, 42 insertions, 36 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 255349f..7a9c98d 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -5,6 +5,7 @@ package scala.async
import scala.reflect.macros.Context
import scala.collection.mutable.ListBuffer
+import collection.mutable
/*
* @author Philipp Haller
@@ -266,7 +267,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
- case Apply(fun, _) if fun.symbol == Async_await => true
+ case Apply(fun, _) if isAwait(fun) => true
case _ => false
}) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException
@@ -281,7 +282,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
- case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await =>
+ case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) =>
val afterAwaitState = stateAssigner.nextState()
asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await
currState = afterAwaitState
@@ -390,21 +391,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
fun.symbol == defn.Async_await
}
- private[async] class LiftableVarTraverser extends Traverser {
- var blockId = 0
- var valDefBlockId = Map[Symbol, (ValDef, Int)]()
- val liftable = collection.mutable.Set[ValDef]()
-
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
+ * on whether or not they are accessed only from a single state.
+ */
+ private[async] class AsyncAnalyzer extends Traverser {
+ private var chunkId = 0
+ private def nextChunk() = chunkId += 1
+ private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
- def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
- val badAwaits = tree collect {
- case rt: RefTree if rt.symbol == Async_await => rt
- }
- badAwaits foreach {
- tree =>
- c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
- }
- }
+ val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
override def traverse(tree: Tree) = {
tree match {
@@ -416,22 +414,13 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
case _: Function =>
reportUnsupportedAwait(tree, "nested anonymous function")
case If(cond, thenp, elsep) if tree exists isAwait =>
- traverse(cond)
- blockId += 1
- traverse(thenp)
- blockId += 1
- traverse(elsep)
- blockId += 1
+ traverseChunks(List(cond, thenp, elsep))
case Match(selector, cases) if tree exists isAwait =>
- traverse(selector)
- blockId += 1
- cases foreach {
- c => traverse(c); blockId += 1
- }
+ traverseChunks(selector :: cases)
case Apply(fun, args) if isAwait(fun) =>
traverseTrees(args)
traverse(fun)
- blockId += 1
+ nextChunk()
case Apply(fun, args) =>
val isInByName = isByName(fun)
for ((arg, index) <- args.zipWithIndex) {
@@ -441,28 +430,45 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
traverse(fun)
case vd: ValDef =>
super.traverse(tree)
- valDefBlockId += (vd.symbol ->(vd, blockId))
- if (vd.rhs.symbol == Async_await) liftable += vd
+ valDefChunkId += (vd.symbol ->(vd, chunkId))
+ if (isAwait(vd.rhs)) valDefsToLift += vd
case as: Assign =>
- if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1
-
+ if (isAwait(as.rhs)) {
+ // TODO test the orElse case, try to remove the restriction.
+ val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))
+ valDefsToLift += vd
+ }
super.traverse(tree)
case rt: RefTree =>
- valDefBlockId.get(rt.symbol) match {
- case Some((vd, defBlockId)) if defBlockId != blockId =>
- liftable += vd
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) if defChunkId != chunkId =>
+ valDefsToLift += vd
case _ =>
}
super.traverse(tree)
case _ => super.traverse(tree)
}
}
+
+ private def traverseChunks(trees: List[Tree]) {
+ trees.foreach {t => traverse(t); nextChunk()}
+ }
+
+ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
+ val badAwaits = tree collect {
+ case rt: RefTree if isAwait(rt) => rt
+ }
+ badAwaits foreach {
+ tree =>
+ c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ }
}
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
private def methodSym(apply: c.Expr[Any]): Symbol = {
- val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed?
+ val tree2: Tree = c.typeCheck(apply.tree)
tree2.collect {
case s: SymTree if s.symbol.isMethod => s.symbol
}.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))