diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2012-11-25 09:52:02 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2012-11-26 16:08:48 +0100 |
commit | a5cab2959067bc7f9d3884064fbf7bf7ec0b7285 (patch) | |
tree | 5acbcb90963da8e8528f1df299615ac6129fc4a6 /src/main | |
parent | f039ac8d61cc5ac43c7ea3683f60fe0a5ad15479 (diff) | |
download | scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.tar.gz scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.tar.bz2 scala-async-a5cab2959067bc7f9d3884064fbf7bf7ec0b7285.zip |
Extract vals for all names bound in a pattern.
These gives us something to lift to vars to be accessed
from multiple states of the state machine.
Fixes #35
Diffstat (limited to 'src/main')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 20 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 29 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 16 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 7 |
4 files changed, 51 insertions, 21 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 64bde3e..5080ecf 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -54,6 +54,8 @@ private[async] final case class AnfTransform[C <: Context](c: C) { trans match { case ValDef(mods, name, tpt, rhs) => treeCopy.ValDef(trans, mods, newName, tpt, rhs) + case Bind(name, body) => + treeCopy.Bind(trans, newName, body) case DefDef(mods, name, tparams, vparamss, tpt, rhs) => treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) case TypeDef(mods, name, tparams, rhs) => @@ -82,9 +84,11 @@ private[async] final case class AnfTransform[C <: Context](c: C) { def indentString = " " * indent def apply[T](prefix: String, args: Any)(t: => T): T = { indent += 1 - def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + AsyncUtils.trace(s"${ + indentString + }$prefix(${oneLine(args)})") val result = t AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") result @@ -201,8 +205,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val scrutStats :+ scrutExpr = inline.transformToList(scrut) val caseDefs = cases map { case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. val block = inline.transformToBlock(body) - attachCopy(tree)(CaseDef(pat, guard, block)) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix)) + val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol)) + (vd, (b.symbol, newName)) + }).unzip + val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] + attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) } scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs))) diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 645d9f5..f0d4511 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -11,6 +11,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { import c.universe._ val utils = TransformUtils[c.type](c) + import utils._ /** @@ -67,15 +68,15 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { override def traverse(tree: Tree) { def containsAwait = tree exists isAwait tree match { - case Try(_, _, _) if containsAwait => + case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) case If(cond, _, _) if containsAwait => reportUnsupportedAwait(cond, "condition") super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case _ => + case _ => super.traverse(tree) } } @@ -92,7 +93,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { c.error(tree.pos, s"await must not be used under a $whyUnsupported.") } badAwaits.nonEmpty - } + } } private class AsyncDefinitionUseAnalyzer extends AsyncTraverser { @@ -106,36 +107,37 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { override def traverse(tree: Tree) = { tree match { - case If(cond, thenp, elsep) if tree exists isAwait => + case If(cond, thenp, elsep) if tree exists isAwait => traverseChunks(List(cond, thenp, elsep)) - case Match(selector, cases) if tree exists isAwait => + case Match(selector, cases) if tree exists isAwait => traverseChunks(selector :: cases) case LabelDef(name, params, rhs) if rhs exists isAwait => traverseChunks(rhs :: Nil) - case Apply(fun, args) if isAwait(fun) => + case Apply(fun, args) if isAwait(fun) => super.traverse(tree) nextChunk() - case vd: ValDef => + case vd: ValDef => super.traverse(tree) valDefChunkId += (vd.symbol ->(vd, chunkId)) - if (isAwait(vd.rhs)) valDefsToLift += vd - case as: Assign => + val isPatternBinder = vd.name.toString.contains(name.bindSuffix) + if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd + case as: Assign => if (isAwait(as.rhs)) { - assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) + assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol) // TODO test the orElse case, try to remove the restriction. val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}")) valDefsToLift += vd } super.traverse(tree) - case rt: RefTree => + case rt: RefTree => valDefChunkId.get(rt.symbol) match { case Some((vd, defChunkId)) if defChunkId != chunkId => valDefsToLift += vd case _ => } super.traverse(tree) - case _ => super.traverse(tree) + case _ => super.traverse(tree) } } @@ -145,4 +147,5 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { } } } + } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index d9faad5..cc2cde5 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -146,7 +146,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply)) + case CaseDef(pat, guard, rhs) => + val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map { + case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs) + case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t") + } + CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply)) } // 2. insert changed match tree at the end of the current state this += Match(renameReset(scrutTree), newCases) @@ -237,7 +242,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { - val builder = nestedBlockBuilder(cas.body, caseStates(num), afterMatchState) + val (stats, expr) = statsAndExpr(cas.body) + val stats1 = stats.dropWhile(isSyntheticBindVal) + val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) asyncStates ++= builder.asyncStates } @@ -346,6 +353,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: } } + private def isSyntheticBindVal(tree: Tree) = tree match { + case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix) + case _ => false + } + private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index c5bbba1..c684ea7 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -27,9 +27,10 @@ private[async] final case class TransformUtils[C <: Context](c: C) { val tr = newTermName("tr") val onCompleteHandler = suffixedName("onCompleteHandler") - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) |