diff options
Diffstat (limited to 'src/main/scala/scala/async/AnfTransform.scala')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 68 |
1 files changed, 53 insertions, 15 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index f52bdad..a2d21f6 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -10,7 +10,9 @@ import scala.reflect.macros.Context private[async] final case class AnfTransform[C <: Context](c: C) { import c.universe._ + val utils = TransformUtils[c.type](c) + import utils._ def apply(tree: Tree): List[Tree] = { @@ -29,9 +31,21 @@ private[async] final case class AnfTransform[C <: Context](c: C) { * This step is needed to allow us to safely merge blocks during the `inline` transform below. */ private final class UniqueNames(tree: Tree) extends Transformer { - val repeatedNames: Set[Name] = tree.collect { - case dt: DefTree => dt.symbol.name - }.groupBy(x => x).filter(_._2.size > 1).keySet + val repeatedNames: Set[Symbol] = { + class DuplicateNameTraverser extends AsyncTraverser { + val result = collection.mutable.Buffer[Symbol]() + + override def traverse(tree: Tree) { + tree match { + case dt: DefTree => result += dt.symbol + case _ => super.traverse(tree) + } + } + } + val dupNameTraverser = new DuplicateNameTraverser + dupNameTraverser.traverse(tree) + dupNameTraverser.result.groupBy(x => x.name).filter(_._2.size > 1).values.flatten.toSet[Symbol] + } /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */ val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] @@ -40,7 +54,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) { override def transform(tree: Tree): Tree = { tree match { - case defTree: DefTree if repeatedNames(defTree.symbol.name) => + case defTree: DefTree if repeatedNames(defTree.symbol) => val trans = super.transform(defTree) val origName = defTree.symbol.name val sym = defTree.symbol.asInstanceOf[symtab.Symbol] @@ -54,6 +68,8 @@ private[async] final case class AnfTransform[C <: Context](c: C) { trans match { case ValDef(mods, name, tpt, rhs) => treeCopy.ValDef(trans, mods, newName, tpt, rhs) + case Bind(name, body) => + treeCopy.Bind(trans, newName, body) case DefDef(mods, name, tparams, vparamss, tpt, rhs) => treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) case TypeDef(mods, name, tparams, rhs) => @@ -79,12 +95,16 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private object trace { private var indent = -1 + def indentString = " " * indent + def apply[T](prefix: String, args: Any)(t: => T): T = { indent += 1 - def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) + def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) try { - AsyncUtils.trace(s"${indentString}$prefix(${oneLine(args)})") + AsyncUtils.trace(s"${ + indentString + }$prefix(${oneLine(args)})") val result = t AsyncUtils.trace(s"${indentString}= ${oneLine(result)}") result @@ -139,10 +159,7 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } - def transformToList(trees: List[Tree]): List[Tree] = trees match { - case fst :: rest => transformToList(fst) ++ transformToList(rest) - case Nil => Nil - } + def transformToList(trees: List[Tree]): List[Tree] = trees flatMap transformToList def transformToBlock(tree: Tree): Block = transformToList(tree) match { case stats :+ expr => Block(stats, expr) @@ -194,20 +211,40 @@ private[async] final case class AnfTransform[C <: Context](c: C) { stats :+ attachCopy(tree)(Assign(lhs, expr)) case If(cond, thenp, elsep) if containsAwait => - val stats :+ expr = inline.transformToList(cond) + val condStats :+ condExpr = inline.transformToList(cond) val thenBlock = inline.transformToBlock(thenp) val elseBlock = inline.transformToBlock(elsep) - stats :+ - c.typeCheck(attachCopy(tree)(If(expr, thenBlock, elseBlock))) + // Typechecking with `condExpr` as the condition fails if the condition + // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems + // we rely on this call to `typeCheck` descending into the branches. + // But, we can get away with typechecking a throwaway `If` tree with the + // original scrutinee and the new branches, and setting that type on + // the real `If` tree. + val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe + condStats :+ + attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType) case Match(scrut, cases) if containsAwait => val scrutStats :+ scrutExpr = inline.transformToList(scrut) val caseDefs = cases map { case CaseDef(pat, guard, body) => + // extract local variables for all names bound in `pat`, and rewrite `body` + // to refer to these. + // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. val block = inline.transformToBlock(body) - attachCopy(tree)(CaseDef(pat, guard, block)) + val (valDefs, mappings) = (pat collect { + case b@Bind(name, _) => + val newName = newTermName(utils.name.fresh(name.toTermName + utils.name.bindSuffix)) + val vd = ValDef(NoMods, newName, TypeTree(), Ident(b.symbol)) + (vd, (b.symbol, newName)) + }).unzip + val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] + attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) } - scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs))) + // Refer to comments the translation of `If` above. + val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe + val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe) + scrutStats :+ typedMatch case LabelDef(name, params, rhs) if containsAwait => List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) @@ -221,4 +258,5 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } } + } |