aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/ExprBuilder.scala
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-22 17:50:50 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-22 17:50:50 +0100
commit087d1e4e138eccf4b2d420298affb4289632bf73 (patch)
treefd0fc1c034f4cbc2d92fa7958c6b03c59e23aa92 /src/main/scala/scala/async/ExprBuilder.scala
parent1c91fec998d09e31c2c52760452af1771a092182 (diff)
downloadscala-async-087d1e4e138eccf4b2d420298affb4289632bf73.tar.gz
scala-async-087d1e4e138eccf4b2d420298affb4289632bf73.tar.bz2
scala-async-087d1e4e138eccf4b2d420298affb4289632bf73.zip
Support match as an expression.
- corrects detection of await calls in the ANF transform. - Split AsyncAnalyzer into two parts. Unsupported await detection must happen prior to the async transform to prevent the ANF lifting out by-name arguments to vals and hence changing the semantics.
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala184
1 files changed, 74 insertions, 110 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 7a9c98d..735db76 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -22,14 +22,14 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
def suffixedName(prefix: String) = newTermName(suffix(prefix))
- val state = suffixedName("state")
- val result = suffixedName("result")
- val resume = suffixedName("resume")
+ val state = suffixedName("state")
+ val result = suffixedName("result")
+ val resume = suffixedName("resume")
val execContext = suffixedName("execContext")
// TODO do we need to freshen any of these?
- val x1 = newTermName("x$1")
- val tr = newTermName("tr")
+ val x1 = newTermName("x$1")
+ val tr = newTermName("tr")
val onCompleteHandler = suffixedName("onCompleteHandler")
def fresh(name: TermName) = newTermName(c.fresh("" + name + "$"))
@@ -60,7 +60,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) {
val body: c.Tree = stats match {
case stat :: Nil => stat
- case _ => Block(stats: _*)
+ case _ => Block(stats: _*)
}
val varDefs: List[(TermName, Type)] = Nil
@@ -78,7 +78,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
)
val updateState = mkStateTree(nextState)
Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply)))
- case _ =>
+ case _ =>
None
}
}
@@ -106,7 +106,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int)
extends AsyncState(stats, state, nextState) {
- val awaitable: c.Tree
+ val awaitable : c.Tree
val resultName: TermName
val resultType: Type
@@ -154,7 +154,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
override def transform(tree: Tree) = tree match {
case Ident(_) if nameMap.keySet contains tree.symbol =>
Ident(nameMap(tree.symbol))
- case _ =>
+ case _ =>
super.transform(tree)
}
}
@@ -178,7 +178,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
}
else
new AsyncStateWithAwait(stats.toList, state, nextState) {
- val awaitable = self.awaitable
+ val awaitable = self.awaitable
val resultName = self.resultName
val resultType = self.resultType
override val varDefs = self.varDefs.toList
@@ -263,18 +263,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
// current state builder
- private var currState = startState
+ private var currState = startState
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
case Apply(fun, _) if isAwait(fun) => true
- case _ => false
+ case _ => false
}) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException
def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = {
val (branchStats, branchExpr) = tree match {
case Block(s, e) => (s, e)
- case _ => (List(tree), c.literalUnit.tree)
+ case _ => (List(tree), c.literalUnit.tree)
}
new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename)
}
@@ -326,7 +326,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
for ((cas, num) <- cases.zipWithIndex) {
val (casStats, casExpr) = cas match {
case CaseDef(_, _, Block(s, e)) => (s, e)
- case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree)
+ case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree)
}
val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename)
asyncStates ++= builder.asyncStates
@@ -362,147 +362,111 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
asyncStates.toList match {
case s :: Nil =>
List(caseForLastState)
- case _ =>
+ case _ =>
val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState()
initCases :+ caseForLastState
}
}
}
- private val Boolean_ShortCircuits: Set[Symbol] = {
- import definitions.BooleanClass
- def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
- val Boolean_&& = BooleanTermMember("&&")
- val Boolean_|| = BooleanTermMember("||")
- Set(Boolean_&&, Boolean_||)
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
+ * on whether or not they are accessed only from a single state.
+ */
+ def reportUnsupportedAwaits(tree: Tree) {
+ new UnsupportedAwaitAnalyzer().traverse(tree)
}
- def isByName(fun: Tree): (Int => Boolean) = {
- if (Boolean_ShortCircuits contains fun.symbol) i => true
- else fun.tpe match {
- case MethodType(params, _) =>
- val isByNameParams = params.map(_.asTerm.isByNameParam)
- (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
- case _ => Map()
+ private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser {
+ override def nestedClass(classDef: ClassDef) {
+ val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
+ reportUnsupportedAwait(classDef, s"nested $kind")
}
- }
- private def isAwait(fun: Tree) = {
- fun.symbol == defn.Async_await
+ override def nestedModule(module: ModuleDef) {
+ reportUnsupportedAwait(module, "nested object")
+ }
+
+ override def byNameArgument(arg: Tree) {
+ reportUnsupportedAwait(arg, "by-name argument")
+ }
+
+ override def function(function: Function) {
+ reportUnsupportedAwait(function, "nested function")
+ }
+
+ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
+ val badAwaits = tree collect {
+ case rt: RefTree if isAwait(rt) => rt
+ }
+ badAwaits foreach {
+ tree =>
+ c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ }
}
/**
* Analyze the contents of an `async` block in order to:
- * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
- * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
- * on whether or not they are accessed only from a single state.
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
+ * on whether or not they are accessed only from a single state.
*/
- private[async] class AsyncAnalyzer extends Traverser {
+ def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = {
+ val analyzer = new AsyncDefinitionUseAnalyzer
+ analyzer.traverse(tree)
+ analyzer.valDefsToLift.toList
+ }
+
+ private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser {
private var chunkId = 0
+
private def nextChunk() = chunkId += 1
+
private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
override def traverse(tree: Tree) = {
tree match {
- case cd: ClassDef =>
- val kind = if (cd.symbol.asClass.isTrait) "trait" else "class"
- reportUnsupportedAwait(tree, s"nested $kind")
- case md: ModuleDef =>
- reportUnsupportedAwait(tree, "nested object")
- case _: Function =>
- reportUnsupportedAwait(tree, "nested anonymous function")
case If(cond, thenp, elsep) if tree exists isAwait =>
traverseChunks(List(cond, thenp, elsep))
case Match(selector, cases) if tree exists isAwait =>
traverseChunks(selector :: cases)
- case Apply(fun, args) if isAwait(fun) =>
- traverseTrees(args)
- traverse(fun)
+ case Apply(fun, args) if isAwait(fun) =>
+ super.traverse(tree)
nextChunk()
- case Apply(fun, args) =>
- val isInByName = isByName(fun)
- for ((arg, index) <- args.zipWithIndex) {
- if (!isInByName(index)) traverse(arg)
- else reportUnsupportedAwait(arg, "by-name argument")
- }
- traverse(fun)
- case vd: ValDef =>
+ case vd: ValDef =>
super.traverse(tree)
valDefChunkId += (vd.symbol ->(vd, chunkId))
if (isAwait(vd.rhs)) valDefsToLift += vd
- case as: Assign =>
+ case as: Assign =>
if (isAwait(as.rhs)) {
// TODO test the orElse case, try to remove the restriction.
- val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))
- valDefsToLift += vd
+ if (as.symbol != null) {
+ // synthetic added by the ANF transfor
+ val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol))
+ valDefsToLift += vd
+ }
}
super.traverse(tree)
- case rt: RefTree =>
+ case rt: RefTree =>
valDefChunkId.get(rt.symbol) match {
case Some((vd, defChunkId)) if defChunkId != chunkId =>
valDefsToLift += vd
- case _ =>
+ case _ =>
}
super.traverse(tree)
- case _ => super.traverse(tree)
+ case _ => super.traverse(tree)
}
}
private def traverseChunks(trees: List[Tree]) {
- trees.foreach {t => traverse(t); nextChunk()}
- }
-
- private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
- val badAwaits = tree collect {
- case rt: RefTree if isAwait(rt) => rt
+ trees.foreach {
+ t => traverse(t); nextChunk()
}
- badAwaits foreach {
- tree =>
- c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
- }
- }
- }
-
-
- /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
- private def methodSym(apply: c.Expr[Any]): Symbol = {
- val tree2: Tree = c.typeCheck(apply.tree)
- tree2.collect {
- case s: SymTree if s.symbol.isMethod => s.symbol
- }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
- }
-
- private[async] object defn {
- def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
- c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
- }
-
- def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice))
-
- def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
- self.splice.apply(arg.splice)
- }
-
- def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
- self.splice == other.splice
- }
-
- def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
- self.splice.get
- }
-
- val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
-
- val TryClass = c.mirror.staticClass("scala.util.Try")
- val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
- val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
-
- val Async_await = {
- val asyncMod = c.mirror.staticModule("scala.async.Async")
- val tpe = asyncMod.moduleClass.asType.toType
- tpe.member(c.universe.newTermName("await"))
}
}
}