diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-11-22 18:01:49 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-11-22 18:01:49 +0100 |
commit | d6ce00d65ade8c31b61091d65fe21ad480c6b20c (patch) | |
tree | 423b0c6c3c764b35c50cdc973e7ec542318bb1ed /src/main/scala/scala/async/AsyncAnalysis.scala | |
parent | 93520f30d77af10c0b936da3f658ec644c7ecd4b (diff) | |
download | scala-async-d6ce00d65ade8c31b61091d65fe21ad480c6b20c.tar.gz scala-async-d6ce00d65ade8c31b61091d65fe21ad480c6b20c.tar.bz2 scala-async-d6ce00d65ade8c31b61091d65fe21ad480c6b20c.zip |
Refactor the analyzers to a seprarate file.
Diffstat (limited to 'src/main/scala/scala/async/AsyncAnalysis.scala')
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 110 |
1 files changed, 110 insertions, 0 deletions
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala new file mode 100644 index 0000000..1b00620 --- /dev/null +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -0,0 +1,110 @@ +package scala.async + +import scala.reflect.macros.Context +import collection.mutable + +private[async] final class AsyncAnalysis[C <: Context](override val c: C) extends TransformUtils(c) { + import c.universe._ + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * + * Must be called on the original tree, not on the ANF transformed tree. + */ + def reportUnsupportedAwaits(tree: Tree) { + new UnsupportedAwaitAnalyzer().traverse(tree) + } + + /** + * Analyze the contents of an `async` block in order to: + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + * + * Must be called on the ANF transformed tree. + */ + def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + val analyzer = new AsyncDefinitionUseAnalyzer + analyzer.traverse(tree) + analyzer.valDefsToLift.toList + } + + private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested $kind") + } + + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } + } + + private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { + private var chunkId = 0 + + private def nextChunk() = chunkId += 1 + + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() + + val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() + + override def traverse(tree: Tree) = { + tree match { + case If(cond, thenp, elsep) if tree exists isAwait => + traverseChunks(List(cond, thenp, elsep)) + case Match(selector, cases) if tree exists isAwait => + traverseChunks(selector :: cases) + case Apply(fun, args) if isAwait(fun) => + super.traverse(tree) + nextChunk() + case vd: ValDef => + super.traverse(tree) + valDefChunkId += (vd.symbol ->(vd, chunkId)) + if (isAwait(vd.rhs)) valDefsToLift += vd + case as: Assign => + if (isAwait(as.rhs)) { + // TODO test the orElse case, try to remove the restriction. + if (as.symbol != null) { + // synthetic added by the ANF transfor + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) + valDefsToLift += vd + } + } + super.traverse(tree) + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) if defChunkId != chunkId => + valDefsToLift += vd + case _ => + } + super.traverse(tree) + case _ => super.traverse(tree) + } + } + + private def traverseChunks(trees: List[Tree]) { + trees.foreach { + t => traverse(t); nextChunk() + } + } + } +} |