diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-11-25 09:52:02 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-11-26 16:08:48 +0100 |
commit | a5cab2959067bc7f9d3884064fbf7bf7ec0b7285 (patch) | |
tree | 5acbcb90963da8e8528f1df299615ac6129fc4a6 /src | |
parent | f039ac8d61cc5ac43c7ea3683f60fe0a5ad15479 (diff) | |
download | scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.tar.gz scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.tar.bz2 scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.zip |
Extract vals for all names bound in a pattern.
These gives us something to lift to vars to be accessed
from multiple states of the state machine.
Fixes #35
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 20 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 29 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 16 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 7 | ||||
-rw-r--r-- | src/test/scala/scala/async/TreeInterrogation.scala | 52 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/match0/Match0.scala | 31 |
6 files changed, 106 insertions, 49 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 64bde3e..5080ecf 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -54,6 +54,8 @@ private[async] final case class AnfTransform[C <: Context](c: C) { trans match { case ValDef(mods, name, tpt, rhs) => treeCopy.ValDef(trans, mods, newName, tpt, rhs) + case Bind(name, body) => + treeCopy.Bind(trans, newName, body) case DefDef(mods, name, tparams, vparamss, tpt, rhs) => treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) case TypeDef(mods, name, tparams, rhs) => @@ -82,9 +84,11 @@ private[async] final case class AnfTransform[C <: Context](c: C) { def indentString = " " * indent def apply[T](prefix: String, args: Any)(t: => T): T = { indent += 1 - def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + AsyncUtils.trace(s"${ + indentString + }$prefix(${oneLine(args)})") val result = t AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") result @@ -201,8 +205,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val scrutStats :+ scrutExpr = inline.transformToList(scrut) val caseDefs = cases map { case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. val block = inline.transformToBlock(body) - attachCopy(tree)(CaseDef(pat, guard, block)) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix)) + val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol)) + (vd, (b.symbol, newName)) + }).unzip + val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] + attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) } scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs))) diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 645d9f5..f0d4511 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -11,6 +11,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { import c.universe._ val utils = TransformUtils[c.type](c) + import utils._ /** @@ -67,15 +68,15 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { override def traverse(tree: Tree) { def containsAwait = tree exists isAwait tree match { - case Try(_, _, _) if containsAwait => + case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) case If(cond, _, _) if containsAwait => reportUnsupportedAwait(cond, "condition") super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case _ => + case _ => super.traverse(tree) } } @@ -92,7 +93,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { c.error(tree.pos, s"await must not be used under a $whyUnsupported.") } badAwaits.nonEmpty - } + } } private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { @@ -106,36 +107,37 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { override def traverse(tree: Tree) = { tree match { - case If(cond, thenp, elsep) if tree exists isAwait => + case If(cond, thenp, elsep) if tree exists isAwait => traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => + case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) case LabelDef(name, params, rhs) if rhs exists isAwait => traverseChunks(rhs :: Nil) - case Apply(fun, args) if isAwait(fun) => + case Apply(fun, args) if isAwait(fun) => super.traverse(tree) nextChunk() - case vd: ValDef => + case vd: ValDef => super.traverse(tree) valDefChunkId += (vd.symbol ->(vd, chunkId)) - if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => + val isPatternBinder = vd.name.toString.contains(name.bindSuffix) + if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd + case as: Assign => if (isAwait(as.rhs)) { - assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) + assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) // TODO test the orElse case, try to remove the restriction. val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) valDefsToLift += vd } super.traverse(tree) - case rt: RefTree => + case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) if defChunkId != chunkId => valDefsToLift += vd case _ => } super.traverse(tree) - case _ => super.traverse(tree) + case _ => super.traverse(tree) } } @@ -145,4 +147,5 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { } } } + } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index d9faad5..cc2cde5 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -146,7 +146,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply)) + case CaseDef(pat, guard, rhs) => + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map { + case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs) + case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t") + } + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply)) } // 2. insert changed match tree at the end of the current state this += Match(renameReset(scrutTree), newCases) @@ -237,7 +242,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { - val builder = nestedBlockBuilder(cas.body, caseStates(num), afterMatchState) + val (stats, expr) = statsAndExpr(cas.body) + val stats1 = stats.dropWhile(isSyntheticBindVal) + val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) asyncStates ++= builder.asyncStates } @@ -346,6 +353,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } } + private def isSyntheticBindVal(tree: Tree) = tree match { + case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case _ => false + } + private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index c5bbba1..c684ea7 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -27,9 +27,10 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val tr = newTermName("tr") val onCompleteHandler = suffixedName("onCompleteHandler") - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index dd239a3..f005b8a 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -38,33 +38,29 @@ class TreeInterrogation { } varDefs.map(_.decoded).toSet mustBe (Set("state$async", "onCompleteHandler$async", "await$1", "await$2")) } +} - //@Test - def sandbox() { - sys.props("scala.async.debug") = true.toString - sys.props("scala.async.trace") = false.toString +object TreeInterrogation extends App { + sys.props("scala.async.debug") = true.toString + sys.props("scala.async.trace") = true.toString - val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp target/scala-2.10/classes") - val tree = tb.parse( - """ import _root_.scala.async.AsyncId._ - | async { - | var sum = 0 - | var i = 0 - | while (i < 5) { - | var j = 0 - | while (j < 5) { - | sum += await(i) * await(j) - | j += 1 - | } - | i += 1 - | } - | sum - | } - | """.stripMargin) - println(tree) - val tree1 = tb.typeCheck(tree.duplicate) - println(cm.universe.show(tree1)) - println(tb.eval(tree)) - } -} + val cm = reflect.runtime.currentMirror + val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") + val tree = tb.parse( + """ import _root_.scala.async.AsyncId._ + | async { + | val x = 1 + | Option(x) match { + | case op @ Some(x) => + | assert(op != null) + | println((op, x)) + | x + await(x) + | case None => await(0) + | } + | } + | """.stripMargin) + println(tree) + val tree1 = tb.typeCheck(tree.duplicate) + println(cm.universe.show(tree1)) + println(tb.eval(tree)) +}
\ No newline at end of file diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala index f550a69..5237629 100644 --- a/src/test/scala/scala/async/run/match0/Match0.scala +++ b/src/test/scala/scala/async/run/match0/Match0.scala @@ -69,4 +69,35 @@ class MatchSpec { val res = Await.result(fut, 2 seconds) res mustBe (5) } + + @Test def `support await in a match expression with binds`() { + val result = AsyncId.async { + val x = 1 + Option(x) match { + case op @ Some(x) => + assert(op == Some(1)) + x + AsyncId.await(x) + case None => AsyncId.await(0) + } + } + result mustBe (2) + } + + @Test def `support await referring to pattern matching vals`() { + import AsyncId.{async, await} + val result = async { + val x = 1 + val opt = Some("") + await(0) + val o @ Some(y) = opt + + { + val o @ Some(y) = Some(".") + } + + await(0) + await((o, y.isEmpty)) + } + result mustBe ((Some(""), true)) + } } |