From 5a2acf110233669e1cf1124ac54b28a526d37858 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 23 Nov 2012 12:08:59 +0100 Subject: Ensure unique names for definitions in the async block. - transform the provided tree using reflect.internal.Symbols#Symbol.name_= and treeCopy.{Ident, Select}. - not sure if this is possible within the public Symbol API. - move checking for unsupported nested module/class to AsyncAnalysis. - make block merging selective (only do so if there are nested await calls.) --- src/main/scala/scala/async/AnfTransform.scala | 61 +++++++++++++++++++++++++- src/main/scala/scala/async/Async.scala | 8 +++- src/main/scala/scala/async/AsyncAnalysis.scala | 18 ++++++-- src/main/scala/scala/async/ExprBuilder.scala | 10 +---- 4 files changed, 80 insertions(+), 17 deletions(-) (limited to 'src/main') diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 24f37e7..ba8d36c 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -7,6 +7,64 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { import c.universe._ + def uniqueNames(tree: Tree): Tree = { + new UniqueNames(tree).transform(tree) + } + + /** Assigns unique names to all definitions in a tree, and adjusts references to use the new name. + * Only modifies names that appear more than once in the tree. + * + * This step is needed to allow us to safely merge blocks during the `inline` transform below. + */ + final class UniqueNames(tree: Tree) extends Transformer { + val repeatedNames: Set[Name] = tree.collect { + case dt: DefTree => dt.symbol.name + }.groupBy(x => x).filter(_._2.size > 1).keySet + + /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols# S y m b o l.n a m e _ =]] */ + val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + + val renamed = collection.mutable.Set[Symbol]() + + override def transform(tree: Tree): Tree = { + tree match { + case defTree: DefTree if repeatedNames(defTree.symbol.name) => + val trans = super.transform(defTree) + val origName = defTree.symbol.name + val sym = defTree.symbol.asInstanceOf[symtab.Symbol] + val fresh = c.fresh("" + sym.name + "$") + sym.name = defTree.symbol.name match { + case _: TermName => symtab.newTermName(fresh) + case _: TypeName => symtab.newTypeName(fresh) + } + renamed += trans.symbol + val newName = trans.symbol.name + trans match { + case ValDef(mods, name, tpt, rhs) => + treeCopy.ValDef(trans, mods, newName, tpt, rhs) + case DefDef(mods, name, tparams, vparamss, tpt, rhs) => + treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) + case TypeDef(mods, name, tparams, rhs) => + treeCopy.TypeDef(tree, mods, newName, tparams, transform(rhs)) + // If we were to allow local classes / objects, we would need to rename here. + case ClassDef(mods, name, tparams, impl) => + treeCopy.ClassDef(tree, mods, newName, tparams, transform(impl).asInstanceOf[Template]) + case ModuleDef(mods, name, impl) => + treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) + case x => super.transform(x) + } + case Ident(name) => + if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) + else tree + case Select(fun, name) => + if (renamed(tree.symbol)) { + treeCopy.Select(tree, transform(fun), tree.symbol.name) + } else super.transform(tree) + case _ => super.transform(tree) + } + } + } + object inline { def transformToList(tree: Tree): List[Tree] = { val stats :+ expr = anf.transformToList(tree) @@ -86,7 +144,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) - case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec. + case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) case ValDef(mods, name, tpt, rhs) if containsAwait => @@ -123,5 +181,4 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { } } } - } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 546445a..0b77c78 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -79,8 +79,12 @@ abstract class AsyncBase { // - if/match only used in statement position. val anfTree: Block = { val transform = new AnfTransform[c.type](c) - val stats1 :+ expr1 = transform.anf.transformToList(body.tree) - c.typeCheck(Block(stats1, expr1)).asInstanceOf[Block] + val unique = transform.uniqueNames(body.tree) + val stats1 :+ expr1 = transform.anf.transformToList(unique) + + val block = Block(stats1, expr1) + + c.typeCheck(block).asInstanceOf[Block] } // Analyze the block to find locals that will be accessed from multiple diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 1b00620..c7db44d 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -32,11 +32,17 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { override def nestedClass(classDef: ClassDef) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(classDef, s"nested $kind") + if (!reportUnsupportedAwait(classDef, s"nested $kind")) { + // do not allow local class definitions, because of SI-5467 (specific to case classes, though) + c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block") + } } override def nestedModule(module: ModuleDef) { - reportUnsupportedAwait(module, "nested object") + if (!reportUnsupportedAwait(module, "nested object")) { + // local object definitions lead to spurious type errors (because of resetAllAttrs?) + c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block") + } } override def byNameArgument(arg: Tree) { @@ -47,14 +53,18 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend reportUnsupportedAwait(function, "nested function") } - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { + /** + * @return true, if the tree contained an unsupported await. + */ + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { + val badAwaits: List[RefTree] = tree collect { case rt: RefTree if isAwait(rt) => rt } badAwaits foreach { tree => c.error(tree.pos, s"await must not be used under a $whyUnsupported.") } + badAwaits.nonEmpty } } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 573af16..b7cfc8f 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -32,7 +32,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val val tr = newTermName("tr") val onCompleteHandler = suffixedName("onCompleteHandler") - def fresh(name: TermName) = newTermName(c.fresh("" + name + "$")) + def fresh(name: TermName) = if (name.toString.contains("$")) name else newTermName(c.fresh("" + name + "$")) } private[async] lazy val futureSystemOps = futureSystem.mkOps(c) @@ -335,14 +335,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val currState = afterMatchState stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - case ClassDef(_, name, _, _) => - // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - c.error(stat.pos, s"Local class ${name.decoded} illegal within `async` block") - - case ModuleDef(_, name, _) => - // local object definitions lead to spurious type errors (because of resetAllAttrs?) - c.error(stat.pos, s"Local object ${name.decoded} illegal within `async` block") - case _ => checkForUnsupportedAwait(stat) stateBuilder += stat -- cgit v1.2.3