aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/AnfTransform.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/scala/async/AnfTransform.scala')
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala68
1 files changed, 53 insertions, 15 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index f52bdad..a2d21f6 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -10,7 +10,9 @@ import scala.reflect.macros.Context
private[async] final case class AnfTransform[C <: Context](c: C) {
import c.universe._
+
val utils = TransformUtils[c.type](c)
+
import utils._
def apply(tree: Tree): List[Tree] = {
@@ -29,9 +31,21 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
* This step is needed to allow us to safely merge blocks during the `inline` transform below.
*/
private final class UniqueNames(tree: Tree) extends Transformer {
- val repeatedNames: Set[Name] = tree.collect {
- case dt: DefTree => dt.symbol.name
- }.groupBy(x => x).filter(_._2.size > 1).keySet
+ val repeatedNames: Set[Symbol] = {
+ class DuplicateNameTraverser extends AsyncTraverser {
+ val result = collection.mutable.Buffer[Symbol]()
+
+ override def traverse(tree: Tree) {
+ tree match {
+ case dt: DefTree => result += dt.symbol
+ case _ => super.traverse(tree)
+ }
+ }
+ }
+ val dupNameTraverser = new DuplicateNameTraverser
+ dupNameTraverser.traverse(tree)
+ dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol]
+ }
/** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */
val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable]
@@ -40,7 +54,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
override def transform(tree: Tree): Tree = {
tree match {
- case defTree: DefTree if repeatedNames(defTree.symbol.name) =>
+ case defTree: DefTree if repeatedNames(defTree.symbol) =>
val trans = super.transform(defTree)
val origName = defTree.symbol.name
val sym = defTree.symbol.asInstanceOf[symtab.Symbol]
@@ -54,6 +68,8 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
trans match {
case ValDef(mods, name, tpt, rhs) =>
treeCopy.ValDef(trans, mods, newName, tpt, rhs)
+ case Bind(name, body) =>
+ treeCopy.Bind(trans, newName, body)
case DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs)
case TypeDef(mods, name, tparams, rhs) =>
@@ -79,12 +95,16 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
private object trace {
private var indent = -1
+
def indentString = " " * indent
+
def apply[T](prefix: String, args: Any)(t: => T): T = {
indent += 1
- def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127)
+ def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127)
try {
- AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})")
+ AsyncUtils.trace(s"${
+ indentString
+ }$prefix(${oneLine(args)})")
val result = t
AsyncUtils.trace(s"${indentString}= ${oneLine(result)}")
result
@@ -139,10 +159,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
}
}
- def transformToList(trees: List[Tree]): List[Tree] = trees match {
- case fst :: rest => transformToList(fst) ++ transformToList(rest)
- case Nil => Nil
- }
+ def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList
def transformToBlock(tree: Tree): Block = transformToList(tree) match {
case stats :+ expr => Block(stats, expr)
@@ -194,20 +211,40 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
stats :+ attachCopy(tree)(Assign(lhs, expr))
case If(cond, thenp, elsep) if containsAwait =>
- val stats :+ expr = inline.transformToList(cond)
+ val condStats :+ condExpr = inline.transformToList(cond)
val thenBlock = inline.transformToBlock(thenp)
val elseBlock = inline.transformToBlock(elsep)
- stats :+
- c.typeCheck(attachCopy(tree)(If(expr, thenBlock, elseBlock)))
+ // Typechecking with `condExpr` as the condition fails if the condition
+ // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
+ // we rely on this call to `typeCheck` descending into the branches.
+ // But, we can get away with typechecking a throwaway `If` tree with the
+ // original scrutinee and the new branches, and setting that type on
+ // the real `If` tree.
+ val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe
+ condStats :+
+ attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType)
case Match(scrut, cases) if containsAwait =>
val scrutStats :+ scrutExpr = inline.transformToList(scrut)
val caseDefs = cases map {
case CaseDef(pat, guard, body) =>
+ // extract local variables for all names bound in `pat`, and rewrite `body`
+ // to refer to these.
+ // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
val block = inline.transformToBlock(body)
- attachCopy(tree)(CaseDef(pat, guard, block))
+ val (valDefs, mappings) = (pat collect {
+ case b@Bind(name, _) =>
+ val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix))
+ val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol))
+ (vd, (b.symbol, newName))
+ }).unzip
+ val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block]
+ attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1)))
}
- scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs)))
+ // Refer to comments the translation of `If` above.
+ val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe
+ val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe)
+ scrutStats :+ typedMatch
case LabelDef(name, params, rhs) if containsAwait =>
List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
@@ -221,4 +258,5 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
}
}
}
+
}