aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/AsyncAnalysis.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/AsyncAnalysis.scala')
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala70
1 files changed, 49 insertions, 21 deletions
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 4f5bf8d..8bb5bcd 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -5,12 +5,13 @@
package scala.async
import scala.reflect.macros.Context
-import collection.mutable
+import scala.collection.mutable
-private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
+private[async] final case class AsyncAnalysis[C <: Context](c: C) {
import c.universe._
val utils = TransformUtils[c.type](c)
+
import utils._
/**
@@ -30,10 +31,11 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
*
* Must be called on the ANF transformed tree.
*/
- def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = {
+ def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
val analyzer = new AsyncDefinitionUseAnalyzer
analyzer.traverse(tree)
- analyzer.valDefsToLift.toList
+ val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
+ liftable
}
private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
@@ -41,7 +43,8 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
if (!reportUnsupportedAwait(classDef, s"nested $kind")) {
// do not allow local class definitions, because of SI-5467 (specific to case classes, though)
- c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block")
+ if (classDef.symbol.asClass.isCaseClass)
+ c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block")
}
}
@@ -70,12 +73,9 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
- case If(cond, _, _) if containsAwait =>
- reportUnsupportedAwait(cond, "condition")
- super.traverse(tree)
- case Return(_) =>
+ case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
- case _ =>
+ case _ =>
super.traverse(tree)
}
}
@@ -92,7 +92,7 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
}
badAwaits.nonEmpty
- }
+ }
}
private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
@@ -102,40 +102,67 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
- val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
+ val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set()
+ val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set()
+
+ override def nestedMethod(defDef: DefDef) {
+ nestedMethodsToLift += defDef
+ defDef.rhs foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested methods.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
+
+ override def function(function: Function) {
+ function foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested functions.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
override def traverse(tree: Tree) = {
tree match {
- case If(cond, thenp, elsep) if tree exists isAwait =>
+ case If(cond, thenp, elsep) if tree exists isAwait =>
traverseChunks(List(cond, thenp, elsep))
- case Match(selector, cases) if tree exists isAwait =>
+ case Match(selector, cases) if tree exists isAwait =>
traverseChunks(selector :: cases)
case LabelDef(name, params, rhs) if rhs exists isAwait =>
traverseChunks(rhs :: Nil)
- case Apply(fun, args) if isAwait(fun) =>
+ case Apply(fun, args) if isAwait(fun) =>
super.traverse(tree)
nextChunk()
- case vd: ValDef =>
+ case vd: ValDef =>
super.traverse(tree)
valDefChunkId += (vd.symbol ->(vd, chunkId))
- if (isAwait(vd.rhs)) valDefsToLift += vd
- case as: Assign =>
+ val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
+ if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
+ case as: Assign =>
if (isAwait(as.rhs)) {
- assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
+ assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
// TODO test the orElse case, try to remove the restriction.
val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
valDefsToLift += vd
}
super.traverse(tree)
- case rt: RefTree =>
+ case rt: RefTree =>
valDefChunkId.get(rt.symbol) match {
case Some((vd, defChunkId)) if defChunkId != chunkId =>
valDefsToLift += vd
case _ =>
}
super.traverse(tree)
- case _ => super.traverse(tree)
+ case _ => super.traverse(tree)
}
}
@@ -145,4 +172,5 @@ private[async] final case class AsyncAnalysis[C <: Context](val c: C) {
}
}
}
+
}