aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-25 09:52:02 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-26 16:08:48 +0100
commita5cab2959067bc7f9d3884064fbf7bf7ec0b7285 (patch)
tree5acbcb90963da8e8528f1df299615ac6129fc4a6
parentf039ac8d61cc5ac43c7ea3683f60fe0a5ad15479 (diff)
downloadscala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.tar.gz
scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.tar.bz2
scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.zip
Extract vals for all names bound in a pattern.
These gives us something to lift to vars to be accessed from multiple states of the state machine. Fixes #35
-rw-r--r--.gitignore1
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala20
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala29
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala16
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala7
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala52
-rw-r--r--src/test/scala/scala/async/run/match0/Match0.scala31
7 files changed, 107 insertions, 49 deletions
diff --git a/.gitignore b/.gitignore
index ad6e494..0c4d130 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,4 @@ classes
target
.idea
.idea_modules
+*.icode
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index 64bde3e..5080ecf 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -54,6 +54,8 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
trans match {
case ValDef(mods, name, tpt, rhs) =>
treeCopy.ValDef(trans, mods, newName, tpt, rhs)
+ case Bind(name, body) =>
+ treeCopy.Bind(trans, newName, body)
case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs)
case TypeDef(mods, name, tparams, rhs) =>
@@ -82,9 +84,11 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
def indentString = " " * indent
def apply[T](prefix: String, args: Any)(t: => T): T = {
indent += 1
- def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
+ def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
try {
- AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ AsyncUtils.trace(s"${
+ indentString
+ }$prefix(${oneLine(args)})")
val result = t
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
result
@@ -201,8 +205,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
val scrutStats :+ scrutExpr = inline.transformToList(scrut)
val caseDefs = cases map {
case CaseDef(pat, guard, body) =>
+ // extract local variables for all names bound in `pat`, and rewrite `body`
+ // to refer to these.
+ // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
val block = inline.transformToBlock(body)
- attachCopy(tree)(CaseDef(pat, guard, block))
+ val (valDefs, mappings) = (pat collect {
+ case b@Bind(name, _) =>
+ val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix))
+ val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol))
+ (vd, (b.symbol, newName))
+ }).unzip
+ val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block]
+ attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1)))
}
scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs)))
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 645d9f5..f0d4511 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -11,6 +11,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
import c.universe._
val utils = TransformUtils[c.type](c)
+
import utils._
/**
@@ -67,15 +68,15 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
override def traverse(tree: Tree) {
def containsAwait = tree exists isAwait
tree match {
- case Try(_, _, _) if containsAwait =>
+ case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
case If(cond, _, _) if containsAwait =>
reportUnsupportedAwait(cond, "condition")
super.traverse(tree)
- case Return(_) =>
+ case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
- case _ =>
+ case _ =>
super.traverse(tree)
}
}
@@ -92,7 +93,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
}
badAwaits.nonEmpty
- }
+ }
}
private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
@@ -106,36 +107,37 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
override def traverse(tree: Tree) = {
tree match {
- case If(cond, thenp, elsep) if tree exists isAwait =>
+ case If(cond, thenp, elsep) if tree exists isAwait =>
traverseChunks(List(cond, thenp, elsep))
- case Match(selector, cases) if tree exists isAwait =>
+ case Match(selector, cases) if tree exists isAwait =>
traverseChunks(selector :: cases)
case LabelDef(name, params, rhs) if rhs exists isAwait =>
traverseChunks(rhs :: Nil)
- case Apply(fun, args) if isAwait(fun) =>
+ case Apply(fun, args) if isAwait(fun) =>
super.traverse(tree)
nextChunk()
- case vd: ValDef =>
+ case vd: ValDef =>
super.traverse(tree)
valDefChunkId += (vd.symbol ->(vd, chunkId))
- if (isAwait(vd.rhs)) valDefsToLift += vd
- case as: Assign =>
+ val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
+ if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
+ case as: Assign =>
if (isAwait(as.rhs)) {
- assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
+ assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
// TODO test the orElse case, try to remove the restriction.
val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
valDefsToLift += vd
}
super.traverse(tree)
- case rt: RefTree =>
+ case rt: RefTree =>
valDefChunkId.get(rt.symbol) match {
case Some((vd, defChunkId)) if defChunkId != chunkId =>
valDefsToLift += vd
case _ =>
}
super.traverse(tree)
- case _ => super.traverse(tree)
+ case _ => super.traverse(tree)
}
}
@@ -145,4 +147,5 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
}
}
}
+
}
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index d9faad5..cc2cde5 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -146,7 +146,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = {
// 1. build list of changed cases
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
- case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply))
+ case CaseDef(pat, guard, rhs) =>
+ val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map {
+ case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs)
+ case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t")
+ }
+ CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply))
}
// 2. insert changed match tree at the end of the current state
this += Match(renameReset(scrutTree), newCases)
@@ -237,7 +242,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
for ((cas, num) <- cases.zipWithIndex) {
- val builder = nestedBlockBuilder(cas.body, caseStates(num), afterMatchState)
+ val (stats, expr) = statsAndExpr(cas.body)
+ val stats1 = stats.dropWhile(isSyntheticBindVal)
+ val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState)
asyncStates ++= builder.asyncStates
}
@@ -346,6 +353,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
}
+ private def isSyntheticBindVal(tree: Tree) = tree match {
+ case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix)
+ case _ => false
+ }
+
private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type)
private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate)
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index c5bbba1..c684ea7 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -27,9 +27,10 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
val tr = newTermName("tr")
val onCompleteHandler = suffixedName("onCompleteHandler")
- val matchRes = "matchres"
- val ifRes = "ifres"
- val await = "await"
+ val matchRes = "matchres"
+ val ifRes = "ifres"
+ val await = "await"
+ val bindSuffix = "$bind"
def fresh(name: TermName): TermName = newTermName(fresh(name.toString))
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index dd239a3..f005b8a 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -38,33 +38,29 @@ class TreeInterrogation {
}
varDefs.map(_.decoded).toSet mustBe (Set("state$async", "onCompleteHandler$async", "await$1", "await$2"))
}
+}
- //@Test
- def sandbox() {
- sys.props("scala.async.debug") = true.toString
- sys.props("scala.async.trace") = false.toString
+object TreeInterrogation extends App {
+ sys.props("scala.async.debug") = true.toString
+ sys.props("scala.async.trace") = true.toString
- val cm = reflect.runtime.currentMirror
- val tb = mkToolbox("-cp target/scala-2.10/classes")
- val tree = tb.parse(
- """ import _root_.scala.async.AsyncId._
- | async {
- | var sum = 0
- | var i = 0
- | while (i < 5) {
- | var j = 0
- | while (j < 5) {
- | sum += await(i) * await(j)
- | j += 1
- | }
- | i += 1
- | }
- | sum
- | }
- | """.stripMargin)
- println(tree)
- val tree1 = tb.typeCheck(tree.duplicate)
- println(cm.universe.show(tree1))
- println(tb.eval(tree))
- }
-}
+ val cm = reflect.runtime.currentMirror
+ val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all")
+ val tree = tb.parse(
+ """ import _root_.scala.async.AsyncId._
+ | async {
+ | val x = 1
+ | Option(x) match {
+ | case op @ Some(x) =>
+ | assert(op != null)
+ | println((op, x))
+ | x + await(x)
+ | case None => await(0)
+ | }
+ | }
+ | """.stripMargin)
+ println(tree)
+ val tree1 = tb.typeCheck(tree.duplicate)
+ println(cm.universe.show(tree1))
+ println(tb.eval(tree))
+} \ No newline at end of file
diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala
index f550a69..5237629 100644
--- a/src/test/scala/scala/async/run/match0/Match0.scala
+++ b/src/test/scala/scala/async/run/match0/Match0.scala
@@ -69,4 +69,35 @@ class MatchSpec {
val res = Await.result(fut, 2 seconds)
res mustBe (5)
}
+
+ @Test def `support await in a match expression with binds`() {
+ val result = AsyncId.async {
+ val x = 1
+ Option(x) match {
+ case op @ Some(x) =>
+ assert(op == Some(1))
+ x + AsyncId.await(x)
+ case None => AsyncId.await(0)
+ }
+ }
+ result mustBe (2)
+ }
+
+ @Test def `support await referring to pattern matching vals`() {
+ import AsyncId.{async, await}
+ val result = async {
+ val x = 1
+ val opt = Some("")
+ await(0)
+ val o @ Some(y) = opt
+
+ {
+ val o @ Some(y) = Some(".")
+ }
+
+ await(0)
+ await((o, y.isEmpty))
+ }
+ result mustBe ((Some(""), true))
+ }
}