diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-11-22 17:50:50 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-11-22 17:50:50 +0100 |
commit | 087d1e4e138eccf4b2d420298affb4289632bf73 (patch) | |
tree | fd0fc1c034f4cbc2d92fa7958c6b03c59e23aa92 /src/main/scala/scala/async/ExprBuilder.scala | |
parent | 1c91fec998d09e31c2c52760452af1771a092182 (diff) | |
download | scala-async-087d1e4e138eccf4b2d420298affb4289632bf73.tar.gz scala-async-087d1e4e138eccf4b2d420298affb4289632bf73.tar.bz2 scala-async-087d1e4e138eccf4b2d420298affb4289632bf73.zip |
Support match as an expression.
- corrects detection of await calls in the ANF transform.
- Split AsyncAnalyzer into two parts. Unsupported await
detection must happen prior to the async transform to
prevent the ANF lifting out by-name arguments to
vals and hence changing the semantics.
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 184 |
1 files changed, 74 insertions, 110 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 7a9c98d..735db76 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -22,14 +22,14 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val def suffixedName(prefix: String) = newTermName(suffix(prefix)) - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") val execContext = suffixedName("execContext") // TODO do we need to freshen any of these? - val x1 = newTermName("x$1") - val tr = newTermName("tr") + val x1 = newTermName("x$1") + val tr = newTermName("tr") val onCompleteHandler = suffixedName("onCompleteHandler") def fresh(name: TermName) = newTermName(c.fresh("" + name + "$")) @@ -60,7 +60,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = stats match { case stat :: Nil => stat - case _ => Block(stats: _*) + case _ => Block(stats: _*) } val varDefs: List[(TermName, Type)] = Nil @@ -78,7 +78,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val ) val updateState = mkStateTree(nextState) Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply))) - case _ => + case _ => None } } @@ -106,7 +106,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int) extends AsyncState(stats, state, nextState) { - val awaitable: c.Tree + val awaitable : c.Tree val resultName: TermName val resultType: Type @@ -154,7 +154,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val override def transform(tree: Tree) = tree match { case Ident(_) if nameMap.keySet contains tree.symbol => Ident(nameMap(tree.symbol)) - case _ => + case _ => super.transform(tree) } } @@ -178,7 +178,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } else new AsyncStateWithAwait(stats.toList, state, nextState) { - val awaitable = self.awaitable + val awaitable = self.awaitable val resultName = self.resultName val resultType = self.resultType override val varDefs = self.varDefs.toList @@ -263,18 +263,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder - private var currState = startState + private var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { case Apply(fun, _) if isAwait(fun) => true - case _ => false + case _ => false }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = { val (branchStats, branchExpr) = tree match { case Block(s, e) => (s, e) - case _ => (List(tree), c.literalUnit.tree) + case _ => (List(tree), c.literalUnit.tree) } new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename) } @@ -326,7 +326,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val for ((cas, num) <- cases.zipWithIndex) { val (casStats, casExpr) = cas match { case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) + case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) } val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename) asyncStates ++= builder.asyncStates @@ -362,147 +362,111 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val asyncStates.toList match { case s :: Nil => List(caseForLastState) - case _ => + case _ => val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState() initCases :+ caseForLastState } } } - private val Boolean_ShortCircuits: Set[Symbol] = { - import definitions.BooleanClass - def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) - val Boolean_&& = BooleanTermMember("&&") - val Boolean_|| = BooleanTermMember("||") - Set(Boolean_&&, Boolean_||) + /** + * Analyze the contents of an `async` block in order to: + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. + */ + def reportUnsupportedAwaits(tree: Tree) { + new UnsupportedAwaitAnalyzer().traverse(tree) } - def isByName(fun: Tree): (Int => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) i => true - else fun.tpe match { - case MethodType(params, _) => - val isByNameParams = params.map(_.asTerm.isByNameParam) - (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) - case _ => Map() + private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { + override def nestedClass(classDef: ClassDef) { + val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(classDef, s"nested $kind") } - } - private def isAwait(fun: Tree) = { - fun.symbol == defn.Async_await + override def nestedModule(module: ModuleDef) { + reportUnsupportedAwait(module, "nested object") + } + + override def byNameArgument(arg: Tree) { + reportUnsupportedAwait(arg, "by-name argument") + } + + override def function(function: Function) { + reportUnsupportedAwait(function, "nested function") + } + + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if isAwait(rt) => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } } /** * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based - * on whether or not they are accessed only from a single state. + * - Report unsupported `await` calls under nested templates, functions, by-name arguments. + * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based + * on whether or not they are accessed only from a single state. */ - private[async] class AsyncAnalyzer extends Traverser { + def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + val analyzer = new AsyncDefinitionUseAnalyzer + analyzer.traverse(tree) + analyzer.valDefsToLift.toList + } + + private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser { private var chunkId = 0 + private def nextChunk() = chunkId += 1 + private var valDefChunkId = Map[Symbol, (ValDef, Int)]() val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]() override def traverse(tree: Tree) = { tree match { - case cd: ClassDef => - val kind = if (cd.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(tree, s"nested $kind") - case md: ModuleDef => - reportUnsupportedAwait(tree, "nested object") - case _: Function => - reportUnsupportedAwait(tree, "nested anonymous function") case If(cond, thenp, elsep) if tree exists isAwait => traverseChunks(List(cond, thenp, elsep)) case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) - case Apply(fun, args) if isAwait(fun) => - traverseTrees(args) - traverse(fun) + case Apply(fun, args) if isAwait(fun) => + super.traverse(tree) nextChunk() - case Apply(fun, args) => - val isInByName = isByName(fun) - for ((arg, index) <- args.zipWithIndex) { - if (!isInByName(index)) traverse(arg) - else reportUnsupportedAwait(arg, "by-name argument") - } - traverse(fun) - case vd: ValDef => + case vd: ValDef => super.traverse(tree) valDefChunkId += (vd.symbol ->(vd, chunkId)) if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => + case as: Assign => if (isAwait(as.rhs)) { // TODO test the orElse case, try to remove the restriction. - val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block.")) - valDefsToLift += vd + if (as.symbol != null) { + // synthetic added by the ANF transfor + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) + valDefsToLift += vd + } } super.traverse(tree) - case rt: RefTree => + case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) if defChunkId != chunkId => valDefsToLift += vd - case _ => + case _ => } super.traverse(tree) - case _ => super.traverse(tree) + case _ => super.traverse(tree) } } private def traverseChunks(trees: List[Tree]) { - trees.foreach {t => traverse(t); nextChunk()} - } - - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { - case rt: RefTree if isAwait(rt) => rt + trees.foreach { + t => traverse(t); nextChunk() } - badAwaits foreach { - tree => - c.error(tree.pos, s"await must not be used under a $whyUnsupported.") - } - } - } - - - /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ - private def methodSym(apply: c.Expr[Any]): Symbol = { - val tree2: Tree = c.typeCheck(apply.tree) - tree2.collect { - case s: SymTree if s.symbol.isMethod => s.symbol - }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) - } - - private[async] object defn { - def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) - } - - def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) - - def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { - self.splice.apply(arg.splice) - } - - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { - self.splice == other.splice - } - - def mkTry_get[A](self: Expr[util.Try[A]]) = reify { - self.splice.get - } - - val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - - val TryClass = c.mirror.staticClass("scala.util.Try") - val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) - val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") - - val Async_await = { - val asyncMod = c.mirror.staticModule("scala.async.Async") - val tpe = asyncMod.moduleClass.asType.toType - tpe.member(c.universe.newTermName("await")) } } } |