From 4da04eee1893ead433a624f6b146d56aca46cb7e Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Sun, 25 Nov 2012 22:44:08 +0100 Subject: Preserve outer This() refs through resetAttrs. Adapt the compiler's standard ResetAttrs to keep This() nodes don't refer to a symbol defined in the current async block. --- src/main/scala/scala/async/Async.scala | 20 +++---- src/main/scala/scala/async/ExprBuilder.scala | 12 +++- src/main/scala/scala/async/TransformUtils.scala | 67 +++++++++++++++++++++- src/test/scala/scala/async/TreeInterrogation.scala | 10 +--- .../scala/async/run/anf/AnfTransformSpec.scala | 4 +- .../scala/scala/async/run/hygiene/Hygiene.scala | 47 +++++++++++---- 6 files changed, 120 insertions(+), 40 deletions(-) diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 049dba0..4fe7c3f 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -66,7 +66,6 @@ abstract class AsyncBase { def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) val anaylzer = AsyncAnalysis[c.type](c) val utils = TransformUtils[c.type](c) import utils.{name, defn} @@ -94,6 +93,7 @@ abstract class AsyncBase { }.toMap } + val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree) val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) import asyncBlock.asyncStates logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) @@ -115,7 +115,6 @@ abstract class AsyncBase { val stateMachine: ModuleDef = { val body: List[Tree] = { - val constr = DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), c.literalUnit.tree)) val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) val result = ValDef(NoMods, name.result, TypeTree(), futureSystemOps.createProm[T].tree) val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) @@ -126,16 +125,16 @@ abstract class AsyncBase { } val apply0DefDef: DefDef = { // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creation of another thunk class. + // See SI-1247 for the the optimization that avoids creatio val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) val applyBody = asyncBlock.onCompleteHandler DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) } - List(constr, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) + List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) } val template = { - val `Try[Any] => Unit` = AppliedTypeTree(Ident(c.mirror.staticClass("scala.runtime.AbstractFunction1")), List(TypeTree(defn.TryAnyType), TypeTree(definitions.UnitTpe))) - val `() => Unit` = AppliedTypeTree(Ident(c.mirror.staticClass("scala.Function0")), List(TypeTree(definitions.UnitTpe))) + val `Try[Any] => Unit` = utils.applied("scala.runtime.AbstractFunction1", List(defn.TryAnyType, definitions.UnitTpe)) + val `() => Unit` = utils.applied("scala.Function0", List(definitions.UnitTpe)) Template(List(`Try[Any] => Unit`, `() => Unit`), emptyValDef, body) } ModuleDef(NoMods, name.stateMachine, template) @@ -145,11 +144,10 @@ abstract class AsyncBase { val code = c.Expr[futureSystem.Fut[T]](Block(List[Tree]( stateMachine, - futureSystemOps.future( - c.Expr[Unit](Apply(selectStateMachine(name.apply), Nil))) - (c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree), - futureSystemOps.promiseToFuture( - c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree + futureSystemOps.future(c.Expr[Unit](Apply(selectStateMachine(name.apply), Nil))) + (c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree + ), + futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree )) AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 0655314..7b4ccb8 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -6,11 +6,12 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable +import language.existentials /* * @author Philipp Haller */ -private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS) { +private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: C, futureSystem: FS, origTree: C#Tree) { builder => val utils = TransformUtils[c.type](c) @@ -96,7 +97,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: /** The state of the target of a LabelDef application (while loop jump) */ private var nextJumpState: Option[Int] = None - private def renameReset(tree: Tree) = resetDuplicate(substituteNames(tree, nameMap)) + private def renameReset(tree: Tree) = resetInternalAttrs(substituteNames(tree, nameMap)) def +=(stat: c.Tree): this.type = { assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") @@ -320,6 +321,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: * } */ val onCompleteHandler: Tree = Match(Ident(name.state), initStates.flatMap(_.mkOnCompleteHandler).toList) + /** * def resume(): Unit = { * try { @@ -357,7 +359,11 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: private final case class Awaitable(expr: Tree, resultName: TermName, resultType: Type) - private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) + private val internalSyms = origTree.collect { + case dt: DefTree => dt.symbol + } + + private def resetInternalAttrs(tree: Tree) = utils.resetInternalAttrs(tree, internalSyms) private def mkResumeApply = Apply(Ident(name.resume), Nil) diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index c66f874..553211a 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -126,6 +126,14 @@ private[async] final case class TransformUtils[C <: Context](c: C) { ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) } + def emptyConstructor: DefDef = { + val emptySuperCall = Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil) + DefDef(NoMods, nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), c.literalUnit.tree)) + } + + def applied(className: String, types: List[Type]): AppliedTypeTree = + AppliedTypeTree(Ident(c.mirror.staticClass(className)), types.map(TypeTree(_))) + object defn { def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) @@ -145,8 +153,7 @@ private[async] final case class TransformUtils[C <: Context](c: C) { self.splice.get } - val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - + val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) val TryClass = c.mirror.staticClass("scala.util.Try") val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") @@ -158,7 +165,6 @@ private[async] final case class TransformUtils[C <: Context](c: C) { } } - /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { val tree2: Tree = c.typeCheck(apply.tree) @@ -180,4 +186,59 @@ private[async] final case class TransformUtils[C <: Context](c: C) { tree } + def resetInternalAttrs(tree: Tree, internalSyms: List[Symbol]) = + new ResetInternalAttrs(internalSyms.toSet).transform(tree) + + /** + * Adaptation of [[scala.reflect.internal.Trees.ResetAttrs]] + * + * A transformer which resets symbol and tpe fields of all nodes in a given tree, + * with special treatment of: + * `TypeTree` nodes: are replaced by their original if it exists, otherwise tpe field is reset + * to empty if it started out empty or refers to local symbols (which are erased). + * `TypeApply` nodes: are deleted if type arguments end up reverted to empty + * + * `This` and `Ident` nodes referring to an external symbol are ''not'' reset. + */ + private final class ResetInternalAttrs(internalSyms: Set[Symbol]) extends Transformer { + + import language.existentials + + override def transform(tree: Tree): Tree = super.transform { + def isExternal = tree.symbol != NoSymbol && !internalSyms(tree.symbol) + + tree match { + case tpt: TypeTree => resetTypeTree(tpt) + case TypeApply(fn, args) + if args map transform exists (_.isEmpty) => transform(fn) + case EmptyTree => tree + case (_: Ident | _: This) if isExternal => tree // #35 Don't reset the symbol of Ident/This bound outside of the async block + case _ => resetTree(tree) + } + } + + private def resetTypeTree(tpt: TypeTree): Tree = { + if (tpt.original != null) + transform(tpt.original) + else if (tpt.tpe != null && tpt.asInstanceOf[symtab.TypeTree forSome {val symtab: reflect.internal.SymbolTable}].wasEmpty) { + val dupl = tpt.duplicate + dupl.tpe = null + dupl + } + else tpt + } + + private def resetTree(tree: Tree): Tree = { + val hasSymbol: Boolean = { + val reflectInternalTree = tree.asInstanceOf[symtab.Tree forSome {val symtab: reflect.internal.SymbolTable}] + reflectInternalTree.hasSymbol + } + val dupl = tree.duplicate + if (hasSymbol) + dupl.symbol = NoSymbol + dupl.tpe = null + dupl + } + } + } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 4bdb84d..14749ca 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -57,15 +57,7 @@ object TreeInterrogation extends App { val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") val tree = tb.parse( """ import _root_.scala.async.AsyncId._ - | object Test { - | def blerg = 1 - | def check() { - | async { - | assert(this.blerg == 1) - | assert(this == Test, this.getClass) - | } - | } - | } + | async { val a = 0; val x = await(a) - 1; def foo(z: Any) = (a.toDouble, x.toDouble, z); foo(await(2)) } | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 41eeaa5..6dd4db7 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -24,8 +24,8 @@ class AnfTestClass { } def m(y: Int): Future[Int] = async { - val f = base(y) - await(f) + val blerg = base(y) + await(blerg) } def m2(y: Int): Future[Int] = async { diff --git a/src/test/scala/scala/async/run/hygiene/Hygiene.scala b/src/test/scala/scala/async/run/hygiene/Hygiene.scala index 2aaf515..bb28d5b 100644 --- a/src/test/scala/scala/async/run/hygiene/Hygiene.scala +++ b/src/test/scala/scala/async/run/hygiene/Hygiene.scala @@ -88,16 +88,39 @@ class HygieneSpec { ext mustBe (14) } -// @Test def `this reference is maintained`() { -// object Test { -// def blerg = 1 -// def check() { -// AsyncId.async { -// assert(this.blerg == 1) -// assert(this == Test, this.getClass) -// } -// } -// } -// Test.check() -// } + trait T1 { + def blerg = 0 + } + + object O1 extends T1 { + override def blerg = 1 + + def check() { + val blerg = 3 + AsyncId.async { + assert(this == O1, this.getClass) + assert(this.blerg == 1) + assert(super.blerg == 0) + assert(super[T1].blerg == 0) + } + } + } + + @Test def `this reference is maintained`() { + O1.check() + } + + @Test def `this reference is maintained to local class`() { + object O2 { + def blerg = 2 + + def check() { + AsyncId.async { + assert(this.blerg == 2) + assert(this == O2, this.getClass) + } + } + } + O2.check() + } } -- cgit v1.2.3