diff options
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 78 |
1 files changed, 42 insertions, 36 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 255349f..7a9c98d 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -5,6 +5,7 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer +import collection.mutable /* * @author Philipp Haller @@ -266,7 +267,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { - case Apply(fun, _) if fun.symbol == Async_await => true + case Apply(fun, _) if isAwait(fun) => true case _ => false }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException @@ -281,7 +282,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await => + case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) => val afterAwaitState = stateAssigner.nextState() asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await currState = afterAwaitState @@ -390,21 +391,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val fun.symbol == defn.Async_await } - private[async] class LiftableVarTraverser extends Traverser { - var blockId = 0 - var valDefBlockId = Map[Symbol, (ValDef, Int)]() - val liftable = collection.mutable.Set[ValDef]() - + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + */ + private[async] class AsyncAnalyzer extends Traverser { + private var chunkId = 0 + private def nextChunk() = chunkId += 1 + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { - case rt: RefTree if rt.symbol == Async_await => rt - } - badAwaits foreach { - tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") - } - } + val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() override def traverse(tree: Tree) = { tree match { @@ -416,22 +414,13 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val case _: Function => reportUnsupportedAwait(tree, "nested anonymous function") case If(cond, thenp, elsep) if tree exists isAwait => - traverse(cond) - blockId += 1 - traverse(thenp) - blockId += 1 - traverse(elsep) - blockId += 1 + traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => - traverse(selector) - blockId += 1 - cases foreach { - c => traverse(c); blockId += 1 - } + traverseChunks(selector :: cases) case Apply(fun, args) if isAwait(fun) => traverseTrees(args) traverse(fun) - blockId += 1 + nextChunk() case Apply(fun, args) => val isInByName = isByName(fun) for ((arg, index) <- args.zipWithIndex) { @@ -441,28 +430,45 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val traverse(fun) case vd: ValDef => super.traverse(tree) - valDefBlockId += (vd.symbol ->(vd, blockId)) - if (vd.rhs.symbol == Async_await) liftable += vd + valDefChunkId += (vd.symbol ->(vd, chunkId)) + if (isAwait(vd.rhs)) valDefsToLift += vd case as: Assign => - if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1 - + if (isAwait(as.rhs)) { + // TODO test the orElse case, try to remove the restriction. + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block.")) + valDefsToLift += vd + } super.traverse(tree) case rt: RefTree => - valDefBlockId.get(rt.symbol) match { - case Some((vd, defBlockId)) if defBlockId != blockId => - liftable += vd + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) if defChunkId != chunkId => + valDefsToLift += vd case _ => } super.traverse(tree) case _ => super.traverse(tree) } } + + private def traverseChunks(trees: List[Tree]) { + trees.foreach {t => traverse(t); nextChunk()} + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } } /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? + val tree2: Tree = c.typeCheck(apply.tree) tree2.collect { case s: SymTree if s.symbol.isMethod => s.symbol }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) |