aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala20
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala29
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala16
-rw-r--r--src/main/scala/scala/async/TransformUtils.scala7
4 files changed, 51 insertions, 21 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index 64bde3e..5080ecf 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -54,6 +54,8 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
trans match {
case ValDef(mods, name, tpt, rhs) =>
treeCopy.ValDef(trans, mods, newName, tpt, rhs)
+ case Bind(name, body) =>
+ treeCopy.Bind(trans, newName, body)
case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs)
case TypeDef(mods, name, tparams, rhs) =>
@@ -82,9 +84,11 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
def indentString = " " * indent
def apply[T](prefix: String, args: Any)(t: => T): T = {
indent += 1
- def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
+ def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
try {
- AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ AsyncUtils.trace(s"${
+ indentString
+ }$prefix(${oneLine(args)})")
val result = t
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
result
@@ -201,8 +205,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
val scrutStats :+ scrutExpr = inline.transformToList(scrut)
val caseDefs = cases map {
case CaseDef(pat, guard, body) =>
+ // extract local variables for all names bound in `pat`, and rewrite `body`
+ // to refer to these.
+ // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
val block = inline.transformToBlock(body)
- attachCopy(tree)(CaseDef(pat, guard, block))
+ val (valDefs, mappings) = (pat collect {
+ case b@Bind(name, _) =>
+ val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix))
+ val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol))
+ (vd, (b.symbol, newName))
+ }).unzip
+ val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block]
+ attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1)))
}
scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs)))
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index 645d9f5..f0d4511 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -11,6 +11,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
import c.universe._
val utils = TransformUtils[c.type](c)
+
import utils._
/**
@@ -67,15 +68,15 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
override def traverse(tree: Tree) {
def containsAwait = tree exists isAwait
tree match {
- case Try(_, _, _) if containsAwait =>
+ case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
case If(cond, _, _) if containsAwait =>
reportUnsupportedAwait(cond, "condition")
super.traverse(tree)
- case Return(_) =>
+ case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
- case _ =>
+ case _ =>
super.traverse(tree)
}
}
@@ -92,7 +93,7 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
}
badAwaits.nonEmpty
- }
+ }
}
private class AsyncDefinitionUseAnalyzer extends AsyncTraverser {
@@ -106,36 +107,37 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
override def traverse(tree: Tree) = {
tree match {
- case If(cond, thenp, elsep) if tree exists isAwait =>
+ case If(cond, thenp, elsep) if tree exists isAwait =>
traverseChunks(List(cond, thenp, elsep))
- case Match(selector, cases) if tree exists isAwait =>
+ case Match(selector, cases) if tree exists isAwait =>
traverseChunks(selector :: cases)
case LabelDef(name, params, rhs) if rhs exists isAwait =>
traverseChunks(rhs :: Nil)
- case Apply(fun, args) if isAwait(fun) =>
+ case Apply(fun, args) if isAwait(fun) =>
super.traverse(tree)
nextChunk()
- case vd: ValDef =>
+ case vd: ValDef =>
super.traverse(tree)
valDefChunkId += (vd.symbol ->(vd, chunkId))
- if (isAwait(vd.rhs)) valDefsToLift += vd
- case as: Assign =>
+ val isPatternBinder = vd.name.toString.contains(name.bindSuffix)
+ if (isAwait(vd.rhs) || isPatternBinder) valDefsToLift += vd
+ case as: Assign =>
if (isAwait(as.rhs)) {
- assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
+ assert(as.lhs.symbol != null, "internal error: null symbol for Assign tree:" + as + " " + as.lhs.symbol)
// TODO test the orElse case, try to remove the restriction.
val (vd, defBlockId) = valDefChunkId.getOrElse(as.lhs.symbol, c.abort(as.pos, s"await may only be assigned to a var/val defined in the async block. ${as.lhs} ${as.lhs.symbol}"))
valDefsToLift += vd
}
super.traverse(tree)
- case rt: RefTree =>
+ case rt: RefTree =>
valDefChunkId.get(rt.symbol) match {
case Some((vd, defChunkId)) if defChunkId != chunkId =>
valDefsToLift += vd
case _ =>
}
super.traverse(tree)
- case _ => super.traverse(tree)
+ case _ => super.traverse(tree)
}
}
@@ -145,4 +147,5 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
}
}
}
+
}
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index d9faad5..cc2cde5 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -146,7 +146,12 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = {
// 1. build list of changed cases
val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match {
- case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply))
+ case CaseDef(pat, guard, rhs) =>
+ val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal).map {
+ case ValDef(_, name, _, rhs) => Assign(Ident(name), rhs)
+ case t => sys.error(s"Unexpected tree. Expected ValDef, found: $t")
+ }
+ CaseDef(pat, guard, Block(bindAssigns :+ mkStateTree(caseStates(num)), mkResumeApply))
}
// 2. insert changed match tree at the end of the current state
this += Match(renameReset(scrutTree), newCases)
@@ -237,7 +242,9 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
stateBuilder.resultWithMatch(scrutinee, cases, caseStates)
for ((cas, num) <- cases.zipWithIndex) {
- val builder = nestedBlockBuilder(cas.body, caseStates(num), afterMatchState)
+ val (stats, expr) = statsAndExpr(cas.body)
+ val stats1 = stats.dropWhile(isSyntheticBindVal)
+ val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState)
asyncStates ++= builder.asyncStates
}
@@ -346,6 +353,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
}
}
+ private def isSyntheticBindVal(tree: Tree) = tree match {
+ case vd@ValDef(_, lname, _, Ident(rname)) => lname.toString.contains(name.bindSuffix)
+ case _ => false
+ }
+
private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type)
private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate)
diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala
index c5bbba1..c684ea7 100644
--- a/src/main/scala/scala/async/TransformUtils.scala
+++ b/src/main/scala/scala/async/TransformUtils.scala
@@ -27,9 +27,10 @@ private[async] final case class TransformUtils[C <: Context](c: C) {
val tr = newTermName("tr")
val onCompleteHandler = suffixedName("onCompleteHandler")
- val matchRes = "matchres"
- val ifRes = "ifres"
- val await = "await"
+ val matchRes = "matchres"
+ val ifRes = "ifres"
+ val await = "await"
+ val bindSuffix = "$bind"
def fresh(name: TermName): TermName = newTermName(fresh(name.toString))