aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilipp Haller <hallerp@gmail.com>2012-11-23 02:42:16 -0800
committerPhilipp Haller <hallerp@gmail.com>2012-11-23 02:42:16 -0800
commit6bc377c3757a2b5cb4b479c3fea910bcce4d7b8f (patch)
tree78187ae3d1c143f187b2071545ae8abb81269b4f
parent1c91fec998d09e31c2c52760452af1771a092182 (diff)
parent8ff80d52047360f3236fcbc8e7849d388c4aa744 (diff)
downloadscala-async-6bc377c3757a2b5cb4b479c3fea910bcce4d7b8f.tar.gz
scala-async-6bc377c3757a2b5cb4b479c3fea910bcce4d7b8f.tar.bz2
scala-async-6bc377c3757a2b5cb4b479c3fea910bcce4d7b8f.zip
Merge pull request #27 from phaller/ticket/26-match
Ticket/26 match
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala150
-rw-r--r--src/main/scala/scala/async/Async.scala7
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala110
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala167
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala98
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala4
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala3
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala74
-rw-r--r--src/test/scala/scala/async/run/block1/block1.scala1
9 files changed, 369 insertions, 245 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index e1d7cd5..24f37e7 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -1,41 +1,49 @@
+
package scala.async
import scala.reflect.macros.Context
class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) {
+
import c.universe._
- import AsyncUtils._
object inline {
def transformToList(tree: Tree): List[Tree] = {
val stats :+ expr = anf.transformToList(tree)
expr match {
-
- case Apply(fun, args) if fun.toString.startsWith("scala.async.Async.await") =>
- val liftedName = c.fresh("await$")
- stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName)
+ case Apply(fun, args) if isAwait(fun) =>
+ val valDef = defineVal("await", expr)
+ stats :+ valDef :+ Ident(valDef.name)
case If(cond, thenp, elsep) =>
// if type of if-else is Unit don't introduce assignment,
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
stats :+ expr :+ Literal(Constant(()))
+ } else {
+ val varDef = defineVar("ifres", expr.tpe)
+ def branchWithAssign(orig: Tree) = orig match {
+ case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.name), thenExpr))
+ case _ => Assign(Ident(varDef.name), orig)
+ }
+ val ifWithAssign = If(cond, branchWithAssign(thenp), branchWithAssign(elsep))
+ stats :+ varDef :+ ifWithAssign :+ Ident(varDef.name)
+ }
+
+ case Match(scrut, cases) =>
+ // if type of match is Unit don't introduce assignment,
+ // but add Unit value to bring it into form expected by async transform
+ if (expr.tpe =:= definitions.UnitTpe) {
+ stats :+ expr :+ Literal(Constant(()))
}
else {
- val liftedName = c.fresh("ifres$")
- val varDef =
- ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe))
- val thenWithAssign = thenp match {
- case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr))
- case _ => Assign(Ident(liftedName), thenp)
- }
- val elseWithAssign = elsep match {
- case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr))
- case _ => Assign(Ident(liftedName), elsep)
+ val varDef = defineVar("matchres", expr.tpe)
+ val casesWithAssign = cases map {
+ case CaseDef(pat, guard, Block(caseStats, caseExpr)) => CaseDef(pat, guard, Block(caseStats, Assign(Ident(varDef.name), caseExpr)))
+ case CaseDef(pat, guard, body) => CaseDef(pat, guard, Assign(Ident(varDef.name), body))
}
- val ifWithAssign =
- If(cond, thenWithAssign, elseWithAssign)
- stats :+ varDef :+ ifWithAssign :+ Ident(liftedName)
+ val matchWithAssign = Match(scrut, casesWithAssign)
+ stats :+ varDef :+ matchWithAssign :+ Ident(varDef.name)
}
case _ =>
stats :+ expr
@@ -44,58 +52,76 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) {
def transformToList(trees: List[Tree]): List[Tree] = trees match {
case fst :: rest => transformToList(fst) ++ transformToList(rest)
- case Nil => Nil
+ case Nil => Nil
}
- }
-
- object anf {
- def transformToList(tree: Tree): List[Tree] = tree match {
- case Select(qual, sel) =>
- val stats :+ expr = inline.transformToList(qual)
- stats :+ Select(expr, sel)
- case Apply(fun, args) =>
- val funStats :+ simpleFun = inline.transformToList(fun)
- val argLists = args map inline.transformToList
- val allArgStats = argLists flatMap (_.init)
- val simpleArgs = argLists map (_.last)
- funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs)
-
- case Block(stats, expr) =>
- inline.transformToList(stats) ++ inline.transformToList(expr)
-
- case ValDef(mods, name, tpt, rhs) =>
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)
-
- case Assign(name, rhs) =>
- val stats :+ expr = inline.transformToList(rhs)
- stats :+ Assign(name, expr)
-
- case If(cond, thenp, elsep) =>
- val stats :+ expr = inline.transformToList(cond)
- val thenStats :+ thenExpr = inline.transformToList(thenp)
- val elseStats :+ elseExpr = inline.transformToList(elsep)
- stats :+
- c.typeCheck(If(expr, Block(thenStats, thenExpr), Block(elseStats, elseExpr)))
+ def transformToBlock(tree: Tree): Block = transformToList(tree) match {
+ case stats :+ expr => Block(stats, expr)
+ }
- //TODO
- case Literal(_) | Ident(_) | This(_) | Match(_, _) | New(_) | Function(_, _) => List(tree)
+ def liftedName(prefix: String) = c.fresh(prefix + "$")
- case TypeApply(fun, targs) =>
- val funStats :+ simpleFun = inline.transformToList(fun)
- funStats :+ TypeApply(simpleFun, targs)
+ private def defineVar(prefix: String, tp: Type): ValDef =
+ ValDef(Modifiers(Flag.MUTABLE), liftedName(prefix), TypeTree(tp), defaultValue(tp))
- //TODO
- case DefDef(mods, name, tparams, vparamss, tpt, rhs) => List(tree)
+ private def defineVal(prefix: String, lhs: Tree): ValDef =
+ ValDef(NoMods, liftedName(prefix), TypeTree(), lhs)
+ }
- case ClassDef(mods, name, tparams, impl) => List(tree)
+ object anf {
+ def transformToList(tree: Tree): List[Tree] = {
+ def containsAwait = tree exists isAwait
+ tree match {
+ case Select(qual, sel) if containsAwait =>
+ val stats :+ expr = inline.transformToList(qual)
+ stats :+ Select(expr, sel).setSymbol(tree.symbol)
+
+ case Apply(fun, args) if containsAwait =>
+ // we an assume that no await call appears in a by-name argument position,
+ // this has already been checked.
+
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ val argLists = args map inline.transformToList
+ val allArgStats = argLists flatMap (_.init)
+ val simpleArgs = argLists map (_.last)
+ funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol)
+
+ case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec.
+ inline.transformToList(stats :+ expr)
+
+ case ValDef(mods, name, tpt, rhs) if containsAwait =>
+ if (rhs exists isAwait) {
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ ValDef(mods, name, tpt, expr).setSymbol(tree.symbol)
+ } else List(tree)
+ case Assign(lhs, rhs) if containsAwait =>
+ val stats :+ expr = inline.transformToList(rhs)
+ stats :+ Assign(lhs, expr)
+
+ case If(cond, thenp, elsep) if containsAwait =>
+ val stats :+ expr = inline.transformToList(cond)
+ val thenBlock = inline.transformToBlock(thenp)
+ val elseBlock = inline.transformToBlock(elsep)
+ stats :+
+ c.typeCheck(If(expr, thenBlock, elseBlock))
+
+ case Match(scrut, cases) if containsAwait =>
+ val scrutStats :+ scrutExpr = inline.transformToList(scrut)
+ val caseDefs = cases map {
+ case CaseDef(pat, guard, body) =>
+ val block = inline.transformToBlock(body)
+ CaseDef(pat, guard, block)
+ }
+ scrutStats :+ c.typeCheck(Match(scrutExpr, caseDefs))
- case ModuleDef(mods, name, impl) => List(tree)
+ case TypeApply(fun, targs) if containsAwait =>
+ val funStats :+ simpleFun = inline.transformToList(fun)
+ funStats :+ TypeApply(simpleFun, targs).setSymbol(tree.symbol)
- case _ =>
- c.error(tree.pos, "Internal error while compiling `async` block")
- ???
+ case _ =>
+ List(tree)
+ }
}
}
+
}
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index bd766f2..546445a 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -66,11 +66,14 @@ abstract class AsyncBase {
import Flag._
val builder = new ExprBuilder[c.type, futureSystem.type](c, self.futureSystem)
+ val anaylzer = new AsyncAnalysis[c.type](c)
import builder.defn._
import builder.name
import builder.futureSystemOps
+ anaylzer.reportUnsupportedAwaits(body.tree)
+
// Transform to A-normal form:
// - no await calls in qualifiers or arguments,
// - if/match only used in statement position.
@@ -84,9 +87,7 @@ abstract class AsyncBase {
// states of our generated state machine, e.g. a value assigned before
// an `await` and read afterwards.
val renameMap: Map[Symbol, TermName] = {
- val analyzer = new builder.AsyncAnalyzer
- analyzer.traverse(anfTree)
- analyzer.valDefsToLift.map {
+ anaylzer.valDefsUsedInSubsequentStates(anfTree).map {
vd =>
(vd.symbol, builder.name.fresh(vd.name))
}.toMap
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
new file mode 100644
index 0000000..1b00620
--- /dev/null
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -0,0 +1,110 @@
+package scala.async
+
+import scala.reflect.macros.Context
+import collection.mutable
+
+private[async] final class AsyncAnalysis[C <: Context](override val c: C) extends TransformUtils(c) {
+ import c.universe._
+
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
+ *
+ * Must be called on the original tree, not on the ANF transformed tree.
+ */
+ def reportUnsupportedAwaits(tree: Tree) {
+ new UnsupportedAwaitAnalyzer().traverse(tree)
+ }
+
+ /**
+ * Analyze the contents of an `async` block in order to:
+ * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
+ * on whether or not they are accessed only from a single state.
+ *
+ * Must be called on the ANF transformed tree.
+ */
+ def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = {
+ val analyzer = new AsyncDefinitionUseAnalyzer
+ analyzer.traverse(tree)
+ analyzer.valDefsToLift.toList
+ }
+
+ private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser {
+ override def nestedClass(classDef: ClassDef) {
+ val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class"
+ reportUnsupportedAwait(classDef, s"nested $kind")
+ }
+
+ override def nestedModule(module: ModuleDef) {
+ reportUnsupportedAwait(module, "nested object")
+ }
+
+ override def byNameArgument(arg: Tree) {
+ reportUnsupportedAwait(arg, "by-name argument")
+ }
+
+ override def function(function: Function) {
+ reportUnsupportedAwait(function, "nested function")
+ }
+
+ private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
+ val badAwaits = tree collect {
+ case rt: RefTree if isAwait(rt) => rt
+ }
+ badAwaits foreach {
+ tree =>
+ c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ }
+ }
+
+ private class AsyncDefinitionUseAnalyzer extends super.AsyncTraverser {
+ private var chunkId = 0
+
+ private def nextChunk() = chunkId += 1
+
+ private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
+
+ val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
+
+ override def traverse(tree: Tree) = {
+ tree match {
+ case If(cond, thenp, elsep) if tree exists isAwait =>
+ traverseChunks(List(cond, thenp, elsep))
+ case Match(selector, cases) if tree exists isAwait =>
+ traverseChunks(selector :: cases)
+ case Apply(fun, args) if isAwait(fun) =>
+ super.traverse(tree)
+ nextChunk()
+ case vd: ValDef =>
+ super.traverse(tree)
+ valDefChunkId += (vd.symbol ->(vd, chunkId))
+ if (isAwait(vd.rhs)) valDefsToLift += vd
+ case as: Assign =>
+ if (isAwait(as.rhs)) {
+ // TODO test the orElse case, try to remove the restriction.
+ if (as.symbol != null) {
+ // synthetic added by the ANF transfor
+ val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol))
+ valDefsToLift += vd
+ }
+ }
+ super.traverse(tree)
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) if defChunkId != chunkId =>
+ valDefsToLift += vd
+ case _ =>
+ }
+ super.traverse(tree)
+ case _ => super.traverse(tree)
+ }
+ }
+
+ private def traverseChunks(trees: List[Tree]) {
+ trees.foreach {
+ t => traverse(t); nextChunk()
+ }
+ }
+ }
+}
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 7a9c98d..573af16 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -22,14 +22,14 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
def suffixedName(prefix: String) = newTermName(suffix(prefix))
- val state = suffixedName("state")
- val result = suffixedName("result")
- val resume = suffixedName("resume")
+ val state = suffixedName("state")
+ val result = suffixedName("result")
+ val resume = suffixedName("resume")
val execContext = suffixedName("execContext")
// TODO do we need to freshen any of these?
- val x1 = newTermName("x$1")
- val tr = newTermName("tr")
+ val x1 = newTermName("x$1")
+ val tr = newTermName("tr")
val onCompleteHandler = suffixedName("onCompleteHandler")
def fresh(name: TermName) = newTermName(c.fresh("" + name + "$"))
@@ -60,7 +60,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) {
val body: c.Tree = stats match {
case stat :: Nil => stat
- case _ => Block(stats: _*)
+ case _ => Block(stats: _*)
}
val varDefs: List[(TermName, Type)] = Nil
@@ -78,7 +78,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
)
val updateState = mkStateTree(nextState)
Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply)))
- case _ =>
+ case _ =>
None
}
}
@@ -106,7 +106,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int)
extends AsyncState(stats, state, nextState) {
- val awaitable: c.Tree
+ val awaitable : c.Tree
val resultName: TermName
val resultType: Type
@@ -154,7 +154,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
override def transform(tree: Tree) = tree match {
case Ident(_) if nameMap.keySet contains tree.symbol =>
Ident(nameMap(tree.symbol))
- case _ =>
+ case _ =>
super.transform(tree)
}
}
@@ -178,7 +178,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
}
else
new AsyncStateWithAwait(stats.toList, state, nextState) {
- val awaitable = self.awaitable
+ val awaitable = self.awaitable
val resultName = self.resultName
val resultType = self.resultType
override val varDefs = self.varDefs.toList
@@ -263,18 +263,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename)
// current state builder
- private var currState = startState
+ private var currState = startState
/* TODO Fall back to CPS plug-in if tree contains an `await` call. */
def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
case Apply(fun, _) if isAwait(fun) => true
- case _ => false
+ case _ => false
}) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException
def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = {
val (branchStats, branchExpr) = tree match {
case Block(s, e) => (s, e)
- case _ => (List(tree), c.literalUnit.tree)
+ case _ => (List(tree), c.literalUnit.tree)
}
new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename)
}
@@ -326,7 +326,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
for ((cas, num) <- cases.zipWithIndex) {
val (casStats, casExpr) = cas match {
case CaseDef(_, _, Block(s, e)) => (s, e)
- case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree)
+ case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree)
}
val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename)
asyncStates ++= builder.asyncStates
@@ -362,147 +362,10 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
asyncStates.toList match {
case s :: Nil =>
List(caseForLastState)
- case _ =>
+ case _ =>
val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState()
initCases :+ caseForLastState
}
}
}
-
- private val Boolean_ShortCircuits: Set[Symbol] = {
- import definitions.BooleanClass
- def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
- val Boolean_&& = BooleanTermMember("&&")
- val Boolean_|| = BooleanTermMember("||")
- Set(Boolean_&&, Boolean_||)
- }
-
- def isByName(fun: Tree): (Int => Boolean) = {
- if (Boolean_ShortCircuits contains fun.symbol) i => true
- else fun.tpe match {
- case MethodType(params, _) =>
- val isByNameParams = params.map(_.asTerm.isByNameParam)
- (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
- case _ => Map()
- }
- }
-
- private def isAwait(fun: Tree) = {
- fun.symbol == defn.Async_await
- }
-
- /**
- * Analyze the contents of an `async` block in order to:
- * - Report unsupported `await` calls under nested templates, functions, by-name arguments.
- * - Find which local `ValDef`-s need to be lifted to fields of the state machine, based
- * on whether or not they are accessed only from a single state.
- */
- private[async] class AsyncAnalyzer extends Traverser {
- private var chunkId = 0
- private def nextChunk() = chunkId += 1
- private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
-
- val valDefsToLift: mutable.Set[ValDef] = collection.mutable.Set[ValDef]()
-
- override def traverse(tree: Tree) = {
- tree match {
- case cd: ClassDef =>
- val kind = if (cd.symbol.asClass.isTrait) "trait" else "class"
- reportUnsupportedAwait(tree, s"nested $kind")
- case md: ModuleDef =>
- reportUnsupportedAwait(tree, "nested object")
- case _: Function =>
- reportUnsupportedAwait(tree, "nested anonymous function")
- case If(cond, thenp, elsep) if tree exists isAwait =>
- traverseChunks(List(cond, thenp, elsep))
- case Match(selector, cases) if tree exists isAwait =>
- traverseChunks(selector :: cases)
- case Apply(fun, args) if isAwait(fun) =>
- traverseTrees(args)
- traverse(fun)
- nextChunk()
- case Apply(fun, args) =>
- val isInByName = isByName(fun)
- for ((arg, index) <- args.zipWithIndex) {
- if (!isInByName(index)) traverse(arg)
- else reportUnsupportedAwait(arg, "by-name argument")
- }
- traverse(fun)
- case vd: ValDef =>
- super.traverse(tree)
- valDefChunkId += (vd.symbol ->(vd, chunkId))
- if (isAwait(vd.rhs)) valDefsToLift += vd
- case as: Assign =>
- if (isAwait(as.rhs)) {
- // TODO test the orElse case, try to remove the restriction.
- val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))
- valDefsToLift += vd
- }
- super.traverse(tree)
- case rt: RefTree =>
- valDefChunkId.get(rt.symbol) match {
- case Some((vd, defChunkId)) if defChunkId != chunkId =>
- valDefsToLift += vd
- case _ =>
- }
- super.traverse(tree)
- case _ => super.traverse(tree)
- }
- }
-
- private def traverseChunks(trees: List[Tree]) {
- trees.foreach {t => traverse(t); nextChunk()}
- }
-
- private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
- val badAwaits = tree collect {
- case rt: RefTree if isAwait(rt) => rt
- }
- badAwaits foreach {
- tree =>
- c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
- }
- }
- }
-
-
- /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
- private def methodSym(apply: c.Expr[Any]): Symbol = {
- val tree2: Tree = c.typeCheck(apply.tree)
- tree2.collect {
- case s: SymTree if s.symbol.isMethod => s.symbol
- }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
- }
-
- private[async] object defn {
- def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
- c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
- }
-
- def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice))
-
- def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
- self.splice.apply(arg.splice)
- }
-
- def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
- self.splice == other.splice
- }
-
- def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
- self.splice.get
- }
-
- val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
-
- val TryClass = c.mirror.staticClass("scala.util.Try")
- val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
- val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
-
- val Async_await = {
- val asyncMod = c.mirror.staticModule("scala.async.Async")
- val tpe = asyncMod.moduleClass.asType.toType
- tpe.member(c.universe.newTermName("await"))
- }
- }
}
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index d36c277..103c8d2 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -9,6 +9,7 @@ import scala.reflect.macros.Context
* Utilities used in both `ExprBuilder` and `AnfTransform`.
*/
class TransformUtils[C <: Context](val c: C) {
+
import c.universe._
protected def defaultValue(tpe: Type): Literal = {
@@ -19,4 +20,101 @@ class TransformUtils[C <: Context](val c: C) {
else null
Literal(Constant(defaultValue))
}
+
+ protected def isAwait(fun: Tree) =
+ fun.symbol == defn.Async_await
+
+ /** Descends into the regions of the tree that are subject to the
+ * translation to a state machine by `async`. When a nested template,
+ * function, or by-name argument is encountered, the descend stops,
+ * and `nestedClass` etc are invoked.
+ */
+ trait AsyncTraverser extends Traverser {
+ def nestedClass(classDef: ClassDef) {
+ }
+
+ def nestedModule(module: ModuleDef) {
+ }
+
+ def byNameArgument(arg: Tree) {
+ }
+
+ def function(function: Function) {
+ }
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case cd: ClassDef => nestedClass(cd)
+ case md: ModuleDef => nestedModule(md)
+ case fun: Function => function(fun)
+ case Apply(fun, args) =>
+ val isInByName = isByName(fun)
+ for ((arg, index) <- args.zipWithIndex) {
+ if (!isInByName(index)) traverse(arg)
+ else byNameArgument(arg)
+ }
+ traverse(fun)
+ case _ => super.traverse(tree)
+ }
+ }
+ }
+
+ private lazy val Boolean_ShortCircuits: Set[Symbol] = {
+ import definitions.BooleanClass
+ def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
+ val Boolean_&& = BooleanTermMember("&&")
+ val Boolean_|| = BooleanTermMember("||")
+ Set(Boolean_&&, Boolean_||)
+ }
+
+ protected def isByName(fun: Tree): (Int => Boolean) = {
+ if (Boolean_ShortCircuits contains fun.symbol) i => true
+ else fun.tpe match {
+ case MethodType(params, _) =>
+ val isByNameParams = params.map(_.asTerm.isByNameParam)
+ (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
+ case _ => Map()
+ }
+ }
+
+ private[async] object defn {
+ def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
+ c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
+ }
+
+ def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice))
+
+ def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
+ self.splice.apply(arg.splice)
+ }
+
+ def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
+ self.splice == other.splice
+ }
+
+ def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
+ self.splice.get
+ }
+
+ val Try_get = methodSym(reify((null: scala.util.Try[Any]).get))
+
+ val TryClass = c.mirror.staticClass("scala.util.Try")
+ val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe))
+ val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal")
+
+ val Async_await = {
+ val asyncMod = c.mirror.staticClass("scala.async.AsyncBase")
+ val tpe = asyncMod.asType.toType
+ tpe.member(c.universe.newTermName("await")).ensuring(_ != NoSymbol)
+ }
+ }
+
+
+ /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
+ private def methodSym(apply: c.Expr[Any]): Symbol = {
+ val tree2: Tree = c.typeCheck(apply.tree)
+ tree2.collect {
+ case s: SymTree if s.symbol.isMethod => s.symbol
+ }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
+ }
}
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index 1ed9be2..02f4b43 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -21,7 +21,7 @@ class TreeInterrogation {
| }""".stripMargin)
val tree1 = tb.typeCheck(tree)
- // println(cm.universe.showRaw(tree1))
+ //println(cm.universe.show(tree1))
import tb.mirror.universe._
val functions = tree1.collect {
@@ -32,6 +32,6 @@ class TreeInterrogation {
val varDefs = tree1.collect {
case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name
}
- varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "x$1", "z$1"))
+ varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "await$1$1", "await$2$1"))
}
}
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index 66bc947..8b85977 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -17,7 +17,6 @@ class NakedAwait {
}
}
-
@Test
def `await not allowed in by-name argument`() {
expectError("await must not be used under a by-name argument.") {
@@ -81,7 +80,7 @@ class NakedAwait {
@Test
def nestedFunction() {
- expectError("await must not be used under a nested anonymous function.") {
+ expectError("await must not be used under a nested function.") {
"""
| import _root_.scala.async.AsyncId._
| async { () => { await(false) } }
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 0abb937..1d6e09a 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -112,35 +112,63 @@ class AnfTransformSpec {
State.result mustBe (14)
}
- @Test
- def `inlining block produces duplicate definition`() {
- import scala.async.AsyncId
-
- AsyncId.async {
- val f = 12
- val x = AsyncId.await(f)
+ // TODO JZ
+// @Test
+// def `inlining block produces duplicate definition`() {
+// import scala.async.AsyncId
+//
+// AsyncId.async {
+// val f = 12
+// val x = AsyncId.await(f)
+//
+// {
+// val x = 42
+// println(x)
+// }
+//
+// x
+// }
+// }
+// @Test
+// def `inlining block in tail position produces duplicate definition`() {
+// import scala.async.AsyncId
+//
+// AsyncId.async {
+// val f = 12
+// val x = AsyncId.await(f)
+//
+// {
+// val x = 42 // TODO should we rename the symbols when we collapse them into the same scope?
+// x
+// }
+// } mustBe (42)
+// }
- {
- val x = 42
- println(x)
+ @Test
+ def `match as expression 1`() {
+ import ExecutionContext.Implicits.global
+ val result = AsyncId.async {
+ val x = "" match {
+ case _ => AsyncId.await(1) + 1
}
-
x
}
+ result mustBe (2)
}
- @Test
- def `inlining block in tail position produces duplicate definition`() {
- import scala.async.AsyncId
- AsyncId.async {
- val f = 12
- val x = AsyncId.await(f)
-
- {
- val x = 42 // TODO should we rename the symbols when we collapse them into the same scope?
- x
+ @Test
+ def `match as expression 2`() {
+ import ExecutionContext.Implicits.global
+ val result = AsyncId.async {
+ val x = "" match {
+ case "" if false => AsyncId.await(1) + 1
+ case _ => 2 + AsyncId.await(1)
}
- } mustBe (42)
-
+ val y = x
+ "" match {
+ case _ => AsyncId.await(y) + 100
+ }
+ }
+ result mustBe (103)
}
}
diff --git a/src/test/scala/scala/async/run/block1/block1.scala b/src/test/scala/scala/async/run/block1/block1.scala
index a449805..0853498 100644
--- a/src/test/scala/scala/async/run/block1/block1.scala
+++ b/src/test/scala/scala/async/run/block1/block1.scala
@@ -27,7 +27,6 @@ class Test1Class {
val f1 = m1(y)
val f2 = m1(y + 2)
val x1 = await(f1)
- println("between two awaits")
val x2 = await(f2)
x1 + x2
}