diff options
Diffstat (limited to 'src/main/scala/scala/async/AsyncAnalysis.scala')
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 70 |
1 files changed, 49 insertions, 21 deletions
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 4f5bf8d..8bb5bcd 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -5,12 +5,13 @@ package scala.async import scala.reflect.macros.Context -import collection.mutable +import scala.collection.mutable -private[async] final case class AsyncAnalysis[C <: Context](val c: C) { +private[async] final case class AsyncAnalysis[C <: Context](c: C) { import c.universe._ val utils = TransformUtils[c.type](c) + import utils._ /** @@ -30,10 +31,11 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) { * * Must be called on the ANF transformed tree. */ - def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = { val analyzer = new AsyncDefinitionUseAnalyzer analyzer.traverse(tree) - analyzer.valDefsToLift.toList + val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct + liftable } private class UnsupportedAwaitAnalyzer extends AsyncTraverser { @@ -41,7 +43,8 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" if (!reportUnsupportedAwait(classDef, s"nested $kind")) { // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block") + if (classDef.symbol.asClass.isCaseClass) + c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") } } @@ -70,12 +73,9 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) { case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) - case If(cond, _, _) if containsAwait => - reportUnsupportedAwait(cond, "condition") - super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case _ => + case _ => super.traverse(tree) } } @@ -92,7 +92,7 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) { c.error(tree.pos, s"await must not be used under a $whyUnsupported.") } badAwaits.nonEmpty - } + } } private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { @@ -102,40 +102,67 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) { private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() + val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set() + val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set() + + override def nestedMethod(defDef: DefDef) { + nestedMethodsToLift += defDef + defDef.rhs foreach { + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) => + valDefsToLift += vd // lift all vals referred to by nested methods. + case _ => + } + case _ => + } + } + + override def function(function: Function) { + function foreach { + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) => + valDefsToLift += vd // lift all vals referred to by nested functions. + case _ => + } + case _ => + } + } override def traverse(tree: Tree) = { tree match { - case If(cond, thenp, elsep) if tree exists isAwait => + case If(cond, thenp, elsep) if tree exists isAwait => traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => + case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) case LabelDef(name, params, rhs) if rhs exists isAwait => traverseChunks(rhs :: Nil) - case Apply(fun, args) if isAwait(fun) => + case Apply(fun, args) if isAwait(fun) => super.traverse(tree) nextChunk() - case vd: ValDef => + case vd: ValDef => super.traverse(tree) valDefChunkId += (vd.symbol ->(vd, chunkId)) - if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => + val isPatternBinder = vd.name.toString.contains(name.bindSuffix) + if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd + case as: Assign => if (isAwait(as.rhs)) { - assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) + assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) // TODO test the orElse case, try to remove the restriction. val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) valDefsToLift += vd } super.traverse(tree) - case rt: RefTree => + case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) if defChunkId != chunkId => valDefsToLift += vd case _ => } super.traverse(tree) - case _ => super.traverse(tree) + case _ => super.traverse(tree) } } @@ -145,4 +172,5 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) { } } } + } |