From be275dcf295f0addf8d41c9a3b4cfe2acaadfaa4 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 26 Nov 2012 14:05:04 +0100 Subject: Rewrite the state machine to a class, rather than an object. To avoid suprises in tree retyping, the instance of this class is immediately upcase to StateMachine[Promise[T], ExecContext]. Allow nested non-case classes. These pop up when we use nested async calls. Only look for duplicate names in the subtrees traversed by AsyncTraverser. --- build.sbt | 2 +- src/main/scala/scala/async/AnfTransform.scala | 23 ++++++- src/main/scala/scala/async/Async.scala | 24 +++++--- src/main/scala/scala/async/AsyncAnalysis.scala | 3 +- src/main/scala/scala/async/FutureSystem.scala | 9 +++ src/main/scala/scala/async/TransformUtils.scala | 23 +++---- src/test/scala/scala/async/TreeInterrogation.scala | 14 ++++- .../scala/scala/async/neg/LocalClasses0Spec.scala | 12 ++-- .../scala/scala/async/run/hygiene/Hygiene.scala | 70 +++++----------------- 9 files changed, 92 insertions(+), 88 deletions(-) diff --git a/build.sbt b/build.sbt index 9b0a6bd..4a3f200 100644 --- a/build.sbt +++ b/build.sbt @@ -1,4 +1,4 @@ -scalaVersion := "2.10.0-RC1" +scalaVersion := "2.10.0-RC3" organization := "org.typesafe.async" diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 055676d..d216e44 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -10,7 +10,9 @@ import scala.reflect.macros.Context private[async] final case class AnfTransform[C <: Context](c: C) { import c.universe._ + val utils = TransformUtils[c.type](c) + import utils._ def apply(tree: Tree): List[Tree] = { @@ -29,9 +31,21 @@ private[async] final case class AnfTransform[C <: Context](c: C) { * This step is needed to allow us to safely merge blocks during the `inline` transform below. */ private final class UniqueNames(tree: Tree) extends Transformer { - val repeatedNames: Set[Name] = tree.collect { - case dt: DefTree => dt.symbol.name - }.groupBy(x => x).filter(_._2.size > 1).keySet + class DuplicateNameTraverser extends AsyncTraverser { + val result = collection.mutable.Buffer[Name]() + + override def traverse(tree: Tree) { + tree match { + case dt: DefTree => result += dt.symbol.name + case _ => super.traverse(tree) + } + } + } + val repeatedNames: Set[Name] = { + val dupNameTraverser = new DuplicateNameTraverser + dupNameTraverser.traverse(tree) + dupNameTraverser.result.groupBy(x => x).filter(_._2.size > 1).keySet + } /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */ val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] @@ -81,7 +95,9 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private object trace { private var indent = -1 + def indentString = " " * indent + def apply[T](prefix: String, args: Any)(t: => T): T = { indent += 1 def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) @@ -242,4 +258,5 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } } + } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 4fe7c3f..09e002d 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -100,9 +100,9 @@ abstract class AsyncBase { // Important to retain the original declaration order here! val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => + case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) - case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) } @@ -113,10 +113,12 @@ abstract class AsyncBase { } val resumeFunTree = asyncBlock.resumeFunTree[T] - val stateMachine: ModuleDef = { + val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + val stateMachine: ClassDef = { val body: List[Tree] = { val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) - val result = ValDef(NoMods, name.result, TypeTree(), futureSystemOps.createProm[T].tree) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) val applyDefDef: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) @@ -133,17 +135,16 @@ abstract class AsyncBase { List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) } val template = { - val `Try[Any] => Unit` = utils.applied("scala.runtime.AbstractFunction1", List(defn.TryAnyType, definitions.UnitTpe)) - val `() => Unit` = utils.applied("scala.Function0", List(definitions.UnitTpe)) - Template(List(`Try[Any] => Unit`, `() => Unit`), emptyValDef, body) + Template(List(stateMachineType), emptyValDef, body) } - ModuleDef(NoMods, name.stateMachine, template) + ClassDef(NoMods, name.stateMachineT, Nil, template) } def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) val code = c.Expr[futureSystem.Fut[T]](Block(List[Tree]( stateMachine, + ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)), futureSystemOps.future(c.Expr[Unit](Apply(selectStateMachine(name.apply), Nil))) (c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree ), @@ -152,7 +153,6 @@ abstract class AsyncBase { AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") code - } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { @@ -169,3 +169,9 @@ abstract class AsyncBase { states foreach (s => AsyncUtils.vprintln(s)) } } + +/** Internal class used by the `async` macro; should not be manually extended by client code */ +abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { + def result$async: Result + def execContext$async: EC +} diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 6e281e4..8bb5bcd 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -43,7 +43,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" if (!reportUnsupportedAwait(classDef, s"nested $kind")) { // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block") + if (classDef.symbol.asClass.isCaseClass) + c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") } } diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index 20bbea3..e9373b3 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -33,6 +33,9 @@ trait FutureSystem { /** Lookup the execution context, typically with an implicit search */ def execContext: Expr[ExecContext] + def promType[A: WeakTypeTag]: Type + def execContextType: Type + /** Create an empty promise */ def createProm[A: WeakTypeTag]: Expr[Prom[A]] @@ -71,6 +74,9 @@ object ScalaConcurrentFutureSystem extends FutureSystem { case context => context }) + def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]] + def execContextType: Type = c.weakTypeOf[ExecutionContext] + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { Promise[A]() } @@ -113,6 +119,9 @@ object IdentityFutureSystem extends FutureSystem { def execContext: Expr[ExecContext] = c.literalUnit + def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]] + def execContextType: Type = c.weakTypeOf[Unit] + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { new Prom(null.asInstanceOf[A]) } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 553211a..5b1fcbe 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -18,17 +18,18 @@ private[async] final case class TransformUtils[C <: Context](c: C) { def suffixedName(prefix: String) = newTermName(suffix(prefix)) - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") - val execContext = suffixedName("execContext") - val stateMachine = suffixedName("stateMachine") - val apply = newTermName("apply") - val tr = newTermName("tr") - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" - val bindSuffix = "$bind" + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") + val execContext = suffixedName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val apply = newTermName("apply") + val tr = newTermName("tr") + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 14749ca..9a31337 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -57,11 +57,21 @@ object TreeInterrogation extends App { val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") val tree = tb.parse( """ import _root_.scala.async.AsyncId._ - | async { val a = 0; val x = await(a) - 1; def foo(z: Any) = (a.toDouble, x.toDouble, z); foo(await(2)) } + | val state = 23 + | val result: Any = "result" + | def resume(): Any = "resume" + | val res = async { + | val f1 = async { state + 2 } + | val x = await(f1) + | val y = await(async { result }) + | val z = await(async { resume() }) + | (x, y, z) + | } + | () | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) println(cm.universe.show(tree1)) println(tb.eval(tree)) } -} \ No newline at end of file +} diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala index 06a0e71..2569303 100644 --- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala +++ b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala @@ -18,7 +18,7 @@ class LocalClasses0Spec { @Test def `reject a local class`() { - expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("Local case class Person illegal within `async` block") { """ | import scala.concurrent.ExecutionContext.Implicits.global | import scala.async.Async._ @@ -32,7 +32,7 @@ class LocalClasses0Spec { @Test def `reject a local class 2`() { - expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("Local case class Person illegal within `async` block") { """ | import scala.concurrent.{Future, ExecutionContext} | import ExecutionContext.Implicits.global @@ -50,7 +50,7 @@ class LocalClasses0Spec { @Test def `reject a local class 3`() { - expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("Local case class Person illegal within `async` block") { """ | import scala.concurrent.{Future, ExecutionContext} | import ExecutionContext.Implicits.global @@ -68,7 +68,7 @@ class LocalClasses0Spec { @Test def `reject a local class with symbols in its name`() { - expectError("Local class :: illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("Local case class :: illegal within `async` block") { """ | import scala.concurrent.{Future, ExecutionContext} | import ExecutionContext.Implicits.global @@ -86,7 +86,7 @@ class LocalClasses0Spec { @Test def `reject a nested local class`() { - expectError("Local class Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("Local case class Person illegal within `async` block") { """ | import scala.concurrent.{Future, ExecutionContext} | import ExecutionContext.Implicits.global @@ -110,7 +110,7 @@ class LocalClasses0Spec { @Test def `reject a local singleton object`() { - expectError("Local object Person illegal within `async` block", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("Local object Person illegal within `async` block") { """ | import scala.concurrent.ExecutionContext.Implicits.global | import scala.async.Async._ diff --git a/src/test/scala/scala/async/run/hygiene/Hygiene.scala b/src/test/scala/scala/async/run/hygiene/Hygiene.scala index bb28d5b..5306ecc 100644 --- a/src/test/scala/scala/async/run/hygiene/Hygiene.scala +++ b/src/test/scala/scala/async/run/hygiene/Hygiene.scala @@ -30,28 +30,6 @@ class HygieneSpec { res mustBe ((25, "result", "resume")) } -/* TODO: -[error] /Users/phaller/git/async/src/test/scala/scala/async/run/hygiene/Hygiene.scala:52: not found: value tr$1 -[error] val f1 = async { state + 2 } -[error] ^ - @Test - def `is hygenic`() { - val state = 23 - val result: Any = "result" - def resume(): Any = "resume" - val res = async { - val f1 = async { state + 2 } - val x = await(f1) - val y = await(async { result }) - val z = await(async { resume() }) - (x, y, z) - } - res._1 mustBe (25) - res._2 mustBe ("result") - res._3 mustBe ("resume") - } -*/ - @Test def `external var as result of await`() { var ext = 0 @@ -88,39 +66,21 @@ class HygieneSpec { ext mustBe (14) } - trait T1 { - def blerg = 0 - } - - object O1 extends T1 { - override def blerg = 1 - - def check() { - val blerg = 3 - AsyncId.async { - assert(this == O1, this.getClass) - assert(this.blerg == 1) - assert(super.blerg == 0) - assert(super[T1].blerg == 0) - } - } - } - - @Test def `this reference is maintained`() { - O1.check() - } - - @Test def `this reference is maintained to local class`() { - object O2 { - def blerg = 2 - - def check() { - AsyncId.async { - assert(this.blerg == 2) - assert(this == O2, this.getClass) - } - } + @Test + def `is hygenic nested`() { + val state = 23 + val result: Any = "result" + def resume(): Any = "resume" + import AsyncId.{await, async} + val res = async { + val f1 = async { state + 2 } + val x = await(f1) + val y = await(async { result }) + val z = await(async { resume() }) + (x, y, z) } - O2.check() + res._1 mustBe (25) + res._2 mustBe ("result") + res._3 mustBe ("resume") } } -- cgit v1.2.3