/* * Copyright (C) 2012 Typesafe Inc. */ package scala.async import scala.reflect.macros.Context import scala.collection.mutable private[async] final case class AsyncAnalysis[C <: Context](c: C, asyncBase: AsyncBase) { import c.universe._ val utils = TransformUtils[c.type](c) import utils._ /** * Analyze the contents of an `async` block in order to: * - Report unsupported `await` calls under nested templates, functions, by-name arguments. * * Must be called on the original tree, not on the ANF transformed tree. */ def reportUnsupportedAwaits(tree: Tree): Boolean = { val analyzer = new UnsupportedAwaitAnalyzer analyzer.traverse(tree) analyzer.hasUnsupportedAwaits } /** * Analyze the contents of an `async` block in order to: * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based * on whether or not they are accessed only from a single state. * * Must be called on the ANF transformed tree. */ def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = { val analyzer = new AsyncDefinitionUseAnalyzer analyzer.traverse(tree) val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct liftable } private class UnsupportedAwaitAnalyzer extends AsyncTraverser { var hasUnsupportedAwaits = false override def nestedClass(classDef: ClassDef) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" if (!reportUnsupportedAwait(classDef, s"nested $kind")) { // do not allow local class definitions, because of SI-5467 (specific to case classes, though) if (classDef.symbol.asClass.isCaseClass) c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") } } override def nestedModule(module: ModuleDef) { if (!reportUnsupportedAwait(module, "nested object")) { // local object definitions lead to spurious type errors (because of resetAllAttrs?) c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block") } } override def nestedMethod(module: DefDef) { reportUnsupportedAwait(module, "nested method") } override def byNameArgument(arg: Tree) { reportUnsupportedAwait(arg, "by-name argument") } override def function(function: Function) { reportUnsupportedAwait(function, "nested function") } override def patMatFunction(tree: Match) { reportUnsupportedAwait(tree, "nested function") } override def traverse(tree: Tree) { def containsAwait = tree exists isAwait tree match { case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) case Return(_) => c.abort(tree.pos, "return is illegal within a async block") case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) => c.abort(tree.pos, "lazy vals are illegal within an async block") case _ => super.traverse(tree) } } /** * @return true, if the tree contained an unsupported await. */ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { val badAwaits: List[RefTree] = tree collect { case rt: RefTree if isAwait(rt) => rt } badAwaits foreach { tree => reportError(tree.pos, s"await must not be used under a $whyUnsupported.") } badAwaits.nonEmpty } private def reportError(pos: Position, msg: String) { hasUnsupportedAwaits = true if (!asyncBase.fallbackEnabled) c.error(pos, msg) } } private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { private var chunkId = 0 private def nextChunk() = chunkId += 1 private var valDefChunkId = Map[Symbol, (ValDef, Int)]() val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set() val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set() override def nestedMethod(defDef: DefDef) { nestedMethodsToLift += defDef markReferencedVals(defDef) } override def function(function: Function) { markReferencedVals(function) } override def patMatFunction(tree: Match) { markReferencedVals(tree) } private def markReferencedVals(tree: Tree) { tree foreach { case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) => valDefsToLift += vd // lift all vals referred to by nested functions. case _ => } case _ => } } override def traverse(tree: Tree) = { tree match { case If(cond, thenp, elsep) if tree exists isAwait => traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) case LabelDef(name, params, rhs) if rhs exists isAwait => traverseChunks(rhs :: Nil) case Apply(fun, args) if isAwait(fun) => super.traverse(tree) nextChunk() case vd: ValDef => super.traverse(tree) valDefChunkId += (vd.symbol -> (vd -> chunkId)) val isPatternBinder = vd.name.toString.contains(name.bindSuffix) if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd case as: Assign => if (isAwait(as.rhs)) { assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) // TODO test the orElse case, try to remove the restriction. val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) valDefsToLift += vd } super.traverse(tree) case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) if defChunkId != chunkId => valDefsToLift += vd case _ => } super.traverse(tree) case _ => super.traverse(tree) } } private def traverseChunks(trees: List[Tree]) { trees.foreach { t => traverse(t); nextChunk() } } } }