From d6ce00d65ade8c31b61091d65fe21ad480c6b20c Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Thu, 22 Nov 2012 18:01:49 +0100 Subject: Refactor the analyzers to a seprarate file. --- src/main/scala/scala/async/AnfTransform.scala | 1 + src/main/scala/scala/async/Async.scala | 5 +- src/main/scala/scala/async/AsyncAnalysis.scala | 110 +++++++++++++++++++++++++ src/main/scala/scala/async/ExprBuilder.scala | 101 ----------------------- 4 files changed, 114 insertions(+), 103 deletions(-) create mode 100644 src/main/scala/scala/async/AsyncAnalysis.scala diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index d06fb54..0146210 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -1,3 +1,4 @@ + package scala.async import scala.reflect.macros.Context diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 645f3f7..546445a 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -66,12 +66,13 @@ abstract class AsyncBase { import Flag._ val builder = new ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) + val anaylzer = new AsyncAnalysis[c.type](c) import builder.defn._ import builder.name import builder.futureSystemOps - builder.reportUnsupportedAwaits(body.tree) + anaylzer.reportUnsupportedAwaits(body.tree) // Transform to A-normal form: // - no await calls in qualifiers or arguments, @@ -86,7 +87,7 @@ abstract class AsyncBase { // states of our generated state machine, e.g. a value assigned before // an `await` and read afterwards. val renameMap: Map[Symbol, TermName] = { - builder.valDefsUsedInSubsequentStates(anfTree).map { + anaylzer.valDefsUsedInSubsequentStates(anfTree).map { vd => (vd.symbol, builder.name.fresh(vd.name)) }.toMap diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala new file mode 100644 index 0000000..1b00620 --- /dev/null +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -0,0 +1,110 @@ +package scala.async + +import scala.reflect.macros.Context +import collection.mutable + +private[async] final class AsyncAnalysis[C <: Context](override val c: C) extends TransformUtils(c) { + import c.universe._ + + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * + * Must be called on the original tree, not on the ANF transformed tree. + */ + def reportUnsupportedAwaits(tree: Tree) { + new UnsupportedAwaitAnalyzer().traverse(tree) + } + + /** + * Analyze the contents of an `async` block in order to: + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + * + * Must be called on the ANF transformed tree. + */ + def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + val analyzer = new AsyncDefinitionUseAnalyzer + analyzer.traverse(tree) + analyzer.valDefsToLift.toList + } + + private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested $kind") + } + + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } + } + + private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { + private var chunkId = 0 + + private def nextChunk() = chunkId += 1 + + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() + + val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() + + override def traverse(tree: Tree) = { + tree match { + case If(cond, thenp, elsep) if tree exists isAwait => + traverseChunks(List(cond, thenp, elsep)) + case Match(selector, cases) if tree exists isAwait => + traverseChunks(selector :: cases) + case Apply(fun, args) if isAwait(fun) => + super.traverse(tree) + nextChunk() + case vd: ValDef => + super.traverse(tree) + valDefChunkId += (vd.symbol ->(vd, chunkId)) + if (isAwait(vd.rhs)) valDefsToLift += vd + case as: Assign => + if (isAwait(as.rhs)) { + // TODO test the orElse case, try to remove the restriction. + if (as.symbol != null) { + // synthetic added by the ANF transfor + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) + valDefsToLift += vd + } + } + super.traverse(tree) + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) if defChunkId != chunkId => + valDefsToLift += vd + case _ => + } + super.traverse(tree) + case _ => super.traverse(tree) + } + } + + private def traverseChunks(trees: List[Tree]) { + trees.foreach { + t => traverse(t); nextChunk() + } + } + } +} diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 735db76..573af16 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -368,105 +368,4 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } } } - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. - */ - def reportUnsupportedAwaits(tree: Tree) { - new UnsupportedAwaitAnalyzer().traverse(tree) - } - - private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { - override def nestedClass(classDef: ClassDef) { - val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(classDef, s"nested $kind") - } - - override def nestedModule(module: ModuleDef) { - reportUnsupportedAwait(module, "nested object") - } - - override def byNameArgument(arg: Tree) { - reportUnsupportedAwait(arg, "by-name argument") - } - - override def function(function: Function) { - reportUnsupportedAwait(function, "nested function") - } - - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { - case rt: RefTree if isAwait(rt) => rt - } - badAwaits foreach { - tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") - } - } - } - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. - */ - def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { - val analyzer = new AsyncDefinitionUseAnalyzer - analyzer.traverse(tree) - analyzer.valDefsToLift.toList - } - - private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { - private var chunkId = 0 - - private def nextChunk() = chunkId += 1 - - private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - - val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() - - override def traverse(tree: Tree) = { - tree match { - case If(cond, thenp, elsep) if tree exists isAwait => - traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => - traverseChunks(selector :: cases) - case Apply(fun, args) if isAwait(fun) => - super.traverse(tree) - nextChunk() - case vd: ValDef => - super.traverse(tree) - valDefChunkId += (vd.symbol ->(vd, chunkId)) - if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => - if (isAwait(as.rhs)) { - // TODO test the orElse case, try to remove the restriction. - if (as.symbol != null) { - // synthetic added by the ANF transfor - val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) - valDefsToLift += vd - } - } - super.traverse(tree) - case rt: RefTree => - valDefChunkId.get(rt.symbol) match { - case Some((vd, defChunkId)) if defChunkId != chunkId => - valDefsToLift += vd - case _ => - } - super.traverse(tree) - case _ => super.traverse(tree) - } - } - - private def traverseChunks(trees: List[Tree]) { - trees.foreach { - t => traverse(t); nextChunk() - } - } - } } -- cgit v1.2.3