From 5a2acf110233669e1cf1124ac54b28a526d37858 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 23 Nov 2012 12:08:59 +0100 Subject: Ensure unique names for definitions in the async block. - transform the provided tree using reflect.internal.Symbols#Symbol.name_= and treeCopy.{Ident, Select}. - not sure if this is possible within the public Symbol API. - move checking for unsupported nested module/class to AsyncAnalysis. - make block merging selective (only do so if there are nested await calls.) --- src/main/scala/scala/async/AnfTransform.scala | 61 ++++++++++++++++++++- src/main/scala/scala/async/Async.scala | 8 ++- src/main/scala/scala/async/AsyncAnalysis.scala | 18 +++++-- src/main/scala/scala/async/ExprBuilder.scala | 10 +--- src/test/scala/scala/async/TreeInterrogation.scala | 2 +- .../scala/async/run/anf/AnfTransformSpec.scala | 63 +++++++++++----------- 6 files changed, 113 insertions(+), 49 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 24f37e7..ba8d36c 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -7,6 +7,64 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { import c.universe._ + def uniqueNames(tree: Tree): Tree = { + new UniqueNames(tree).transform(tree) + } + + /** Assigns unique names to all definitions in a tree, and adjusts references to use the new name. + * Only modifies names that appear more than once in the tree. + * + * This step is needed to allow us to safely merge blocks during the `inline` transform below. + */ + final class UniqueNames(tree: Tree) extends Transformer { + val repeatedNames: Set[Name] = tree.collect { + case dt: DefTree => dt.symbol.name + }.groupBy(x => x).filter(_._2.size > 1).keySet + + /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols# S y m b o l.n a m e _ =]] */ + val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] + + val renamed = collection.mutable.Set[Symbol]() + + override def transform(tree: Tree): Tree = { + tree match { + case defTree: DefTree if repeatedNames(defTree.symbol.name) => + val trans = super.transform(defTree) + val origName = defTree.symbol.name + val sym = defTree.symbol.asInstanceOf[symtab.Symbol] + val fresh = c.fresh("" + sym.name + "$") + sym.name = defTree.symbol.name match { + case _: TermName => symtab.newTermName(fresh) + case _: TypeName => symtab.newTypeName(fresh) + } + renamed += trans.symbol + val newName = trans.symbol.name + trans match { + case ValDef(mods, name, tpt, rhs) => + treeCopy.ValDef(trans, mods, newName, tpt, rhs) + case DefDef(mods, name, tparams, vparamss, tpt, rhs) => + treeCopy.DefDef(trans, mods, newName, tparams, vparamss, tpt, rhs) + case TypeDef(mods, name, tparams, rhs) => + treeCopy.TypeDef(tree, mods, newName, tparams, transform(rhs)) + // If we were to allow local classes / objects, we would need to rename here. + case ClassDef(mods, name, tparams, impl) => + treeCopy.ClassDef(tree, mods, newName, tparams, transform(impl).asInstanceOf[Template]) + case ModuleDef(mods, name, impl) => + treeCopy.ModuleDef(tree, mods, newName, transform(impl).asInstanceOf[Template]) + case x => super.transform(x) + } + case Ident(name) => + if (renamed(tree.symbol)) treeCopy.Ident(tree, tree.symbol.name) + else tree + case Select(fun, name) => + if (renamed(tree.symbol)) { + treeCopy.Select(tree, transform(fun), tree.symbol.name) + } else super.transform(tree) + case _ => super.transform(tree) + } + } + } + object inline { def transformToList(tree: Tree): List[Tree] = { val stats :+ expr = anf.transformToList(tree) @@ -86,7 +144,7 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { val simpleArgs = argLists map (_.last) funStats ++ allArgStats :+ Apply(simpleFun, simpleArgs).setSymbol(tree.symbol) - case Block(stats, expr) => // TODO figure out why adding a guard `if containsAwait` breaks LocalClasses0Spec. + case Block(stats, expr) if containsAwait => inline.transformToList(stats :+ expr) case ValDef(mods, name, tpt, rhs) if containsAwait => @@ -123,5 +181,4 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { } } } - } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 546445a..0b77c78 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -79,8 +79,12 @@ abstract class AsyncBase { // - if/match only used in statement position. val anfTree: Block = { val transform = new AnfTransform[c.type](c) - val stats1 :+ expr1 = transform.anf.transformToList(body.tree) - c.typeCheck(Block(stats1, expr1)).asInstanceOf[Block] + val unique = transform.uniqueNames(body.tree) + val stats1 :+ expr1 = transform.anf.transformToList(unique) + + val block = Block(stats1, expr1) + + c.typeCheck(block).asInstanceOf[Block] } // Analyze the block to find locals that will be accessed from multiple diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 1b00620..c7db44d 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -32,11 +32,17 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend private class UnsupportedAwaitAnalyzer extends super.AsyncTraverser { override def nestedClass(classDef: ClassDef) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(classDef, s"nested $kind") + if (!reportUnsupportedAwait(classDef, s"nested $kind")) { + // do not allow local class definitions, because of SI-5467 (specific to case classes, though) + c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block") + } } override def nestedModule(module: ModuleDef) { - reportUnsupportedAwait(module, "nested object") + if (!reportUnsupportedAwait(module, "nested object")) { + // local object definitions lead to spurious type errors (because of resetAllAttrs?) + c.error(module.pos, s"Local object ${module.name.decoded} illegal within `async` block") + } } override def byNameArgument(arg: Tree) { @@ -47,14 +53,18 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend reportUnsupportedAwait(function, "nested function") } - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { - val badAwaits = tree collect { + /** + * @return true, if the tree contained an unsupported await. + */ + private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { + val badAwaits: List[RefTree] = tree collect { case rt: RefTree if isAwait(rt) => rt } badAwaits foreach { tree => c.error(tree.pos, s"await must not be used under a $whyUnsupported.") } + badAwaits.nonEmpty } } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 573af16..b7cfc8f 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -32,7 +32,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val val tr = newTermName("tr") val onCompleteHandler = suffixedName("onCompleteHandler") - def fresh(name: TermName) = newTermName(c.fresh("" + name + "$")) + def fresh(name: TermName) = if (name.toString.contains("$")) name else newTermName(c.fresh("" + name + "$")) } private[async] lazy val futureSystemOps = futureSystem.mkOps(c) @@ -335,14 +335,6 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val currState = afterMatchState stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - case ClassDef(_, name, _, _) => - // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - c.error(stat.pos, s"Local class ${name.decoded} illegal within `async` block") - - case ModuleDef(_, name, _) => - // local object definitions lead to spurious type errors (because of resetAllAttrs?) - c.error(stat.pos, s"Local object ${name.decoded} illegal within `async` block") - case _ => checkForUnsupportedAwait(stat) stateBuilder += stat diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 02f4b43..9e68005 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -32,6 +32,6 @@ class TreeInterrogation { val varDefs = tree1.collect { case ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) => name } - varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "await$1$1", "await$2$1")) + varDefs.map(_.decoded).toSet mustBe(Set("state$async", "onCompleteHandler$async", "await$1", "await$2")) } } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 1d6e09a..872e44d 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -112,37 +112,38 @@ class AnfTransformSpec { State.result mustBe (14) } - // TODO JZ -// @Test -// def `inlining block produces duplicate definition`() { -// import scala.async.AsyncId -// -// AsyncId.async { -// val f = 12 -// val x = AsyncId.await(f) -// -// { -// val x = 42 -// println(x) -// } -// -// x -// } -// } -// @Test -// def `inlining block in tail position produces duplicate definition`() { -// import scala.async.AsyncId -// -// AsyncId.async { -// val f = 12 -// val x = AsyncId.await(f) -// -// { -// val x = 42 // TODO should we rename the symbols when we collapse them into the same scope? -// x -// } -// } mustBe (42) -// } + @Test + def `inlining block does not produce duplicate definition`() { + import scala.async.AsyncId + + AsyncId.async { + val f = 12 + val x = AsyncId.await(f) + + { + type X = Int + val x: X = 42 + println(x) + } + type X = Int + x: X + } + } + + @Test + def `inlining block in tail position does not produce duplicate definition`() { + import scala.async.AsyncId + + AsyncId.async { + val f = 12 + val x = AsyncId.await(f) + + { + val x = 42 + x + } + } mustBe (42) + } @Test def `match as expression 1`() { -- cgit v1.2.3 From 848d98a89e4ff25f2d1373021bef23059ff04ab2 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 23 Nov 2012 12:15:05 +0100 Subject: Convert null check to an assert. We seem to be symful now. --- src/main/scala/scala/async/AsyncAnalysis.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index c7db44d..180e006 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -92,12 +92,11 @@ private[async] final class AsyncAnalysis[C <: Context](override val c: C) extend if (isAwait(vd.rhs)) valDefsToLift += vd case as: Assign => if (isAwait(as.rhs)) { + assert(as.symbol != null, "internal error: null symbol for Assign tree:" + as) + // TODO test the orElse case, try to remove the restriction. - if (as.symbol != null) { - // synthetic added by the ANF transfor - val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) - valDefsToLift += vd - } + val (vd, defBlockId) = valDefChunkId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block. " + as.symbol)) + valDefsToLift += vd } super.traverse(tree) case rt: RefTree => -- cgit v1.2.3