aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/AsyncAnalysis.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/AsyncAnalysis.scala')
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala39
1 files changed, 33 insertions, 6 deletions
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index ecd5054..6e281e4 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -31,10 +31,11 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
*
* Must be called on the ANF transformed tree.
*/
- def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = {
+ def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
val analyzer = new AsyncDefinitionUseAnalyzer
analyzer.traverse(tree)
- analyzer.valDefsToLift.toList
+ val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
+ liftable
}
private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
@@ -68,12 +69,12 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
override def traverse(tree: Tree) {
def containsAwait = tree exists isAwait
tree match {
- case Try(_, _, _) if containsAwait =>
+ case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
- case Return(_) =>
+ case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
- case _ =>
+ case _ =>
super.traverse(tree)
}
}
@@ -100,7 +101,33 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
- val valDefsToLift = mutable.Set[ValDef]()
+ val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set()
+ val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set()
+
+ override def nestedMethod(defDef: DefDef) {
+ nestedMethodsToLift += defDef
+ defDef.rhs foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested methods.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
+
+ override def function(function: Function) {
+ function foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested functions.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
override def traverse(tree: Tree) = {
tree match {