diff options
Diffstat (limited to 'src/main')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 23 | ||||
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 24 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncAnalysis.scala | 3 | ||||
-rw-r--r-- | src/main/scala/scala/async/FutureSystem.scala | 9 | ||||
-rw-r--r-- | src/main/scala/scala/async/TransformUtils.scala | 23 |
5 files changed, 58 insertions, 24 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 055676d..d216e44 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -10,7 +10,9 @@ import scala.reflect.macros.Context private[async] final case class AnfTransform[C <: Context](c: C) { import c.universe._ + val utils = TransformUtils[c.type](c) + import utils._ def apply(tree: Tree): List[Tree] = { @@ -29,9 +31,21 @@ private[async] final case class AnfTransform[C <: Context](c: C) { * This step is needed to allow us to safely merge blocks during the `inline` transform below. */ private final class UniqueNames(tree: Tree) extends Transformer { - val repeatedNames: Set[Name] = tree.collect { - case dt: DefTree => dt.symbol.name - }.groupBy(x => x).filter(_._2.size > 1).keySet + class DuplicateNameTraverser extends AsyncTraverser { + val result = collection.mutable.Buffer[Name]() + + override def traverse(tree: Tree) { + tree match { + case dt: DefTree => result += dt.symbol.name + case _ => super.traverse(tree) + } + } + } + val repeatedNames: Set[Name] = { + val dupNameTraverser = new DuplicateNameTraverser + dupNameTraverser.traverse(tree) + dupNameTraverser.result.groupBy(x => x).filter(_._2.size > 1).keySet + } /** Stepping outside of the public Macro API to call [[scala.reflect.internal.Symbols.Symbol.name_=]] */ val symtab = c.universe.asInstanceOf[reflect.internal.SymbolTable] @@ -81,7 +95,9 @@ private[async] final case class AnfTransform[C <: Context](c: C) { private object trace { private var indent = -1 + def indentString = " " * indent + def apply[T](prefix: String, args: Any)(t: => T): T = { indent += 1 def oneLine(s: Any) = s.toString.replaceAll( """\n""", "\\\\n").take(127) @@ -242,4 +258,5 @@ private[async] final case class AnfTransform[C <: Context](c: C) { } } } + } diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 4fe7c3f..09e002d 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -100,9 +100,9 @@ abstract class AsyncBase { // Important to retain the original declaration order here! val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => + case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) - case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) } @@ -113,10 +113,12 @@ abstract class AsyncBase { } val resumeFunTree = asyncBlock.resumeFunTree[T] - val stateMachine: ModuleDef = { + val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + val stateMachine: ClassDef = { val body: List[Tree] = { val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) - val result = ValDef(NoMods, name.result, TypeTree(), futureSystemOps.createProm[T].tree) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) val applyDefDef: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) @@ -133,17 +135,16 @@ abstract class AsyncBase { List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) } val template = { - val `Try[Any] => Unit` = utils.applied("scala.runtime.AbstractFunction1", List(defn.TryAnyType, definitions.UnitTpe)) - val `() => Unit` = utils.applied("scala.Function0", List(definitions.UnitTpe)) - Template(List(`Try[Any] => Unit`, `() => Unit`), emptyValDef, body) + Template(List(stateMachineType), emptyValDef, body) } - ModuleDef(NoMods, name.stateMachine, template) + ClassDef(NoMods, name.stateMachineT, Nil, template) } def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) val code = c.Expr[futureSystem.Fut[T]](Block(List[Tree]( stateMachine, + ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)), futureSystemOps.future(c.Expr[Unit](Apply(selectStateMachine(name.apply), Nil))) (c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree ), @@ -152,7 +153,6 @@ abstract class AsyncBase { AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") code - } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { @@ -169,3 +169,9 @@ abstract class AsyncBase { states foreach (s => AsyncUtils.vprintln(s)) } } + +/** Internal class used by the `async` macro; should not be manually extended by client code */ +abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { + def result$async: Result + def execContext$async: EC +} diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index 6e281e4..8bb5bcd 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -43,7 +43,8 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" if (!reportUnsupportedAwait(classDef, s"nested $kind")) { // do not allow local class definitions, because of SI-5467 (specific to case classes, though) - c.error(classDef.pos, s"Local class ${classDef.name.decoded} illegal within `async` block") + if (classDef.symbol.asClass.isCaseClass) + c.error(classDef.pos, s"Local case class ${classDef.name.decoded} illegal within `async` block") } } diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala index 20bbea3..e9373b3 100644 --- a/src/main/scala/scala/async/FutureSystem.scala +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -33,6 +33,9 @@ trait FutureSystem { /** Lookup the execution context, typically with an implicit search */ def execContext: Expr[ExecContext] + def promType[A: WeakTypeTag]: Type + def execContextType: Type + /** Create an empty promise */ def createProm[A: WeakTypeTag]: Expr[Prom[A]] @@ -71,6 +74,9 @@ object ScalaConcurrentFutureSystem extends FutureSystem { case context => context }) + def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Promise[A]] + def execContextType: Type = c.weakTypeOf[ExecutionContext] + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { Promise[A]() } @@ -113,6 +119,9 @@ object IdentityFutureSystem extends FutureSystem { def execContext: Expr[ExecContext] = c.literalUnit + def promType[A: WeakTypeTag]: Type = c.weakTypeOf[Prom[A]] + def execContextType: Type = c.weakTypeOf[Unit] + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { new Prom(null.asInstanceOf[A]) } diff --git a/src/main/scala/scala/async/TransformUtils.scala b/src/main/scala/scala/async/TransformUtils.scala index 553211a..5b1fcbe 100644 --- a/src/main/scala/scala/async/TransformUtils.scala +++ b/src/main/scala/scala/async/TransformUtils.scala @@ -18,17 +18,18 @@ private[async] final case class TransformUtils[C <: Context](c: C) { def suffixedName(prefix: String) = newTermName(suffix(prefix)) - val state = suffixedName("state") - val result = suffixedName("result") - val resume = suffixedName("resume") - val execContext = suffixedName("execContext") - val stateMachine = suffixedName("stateMachine") - val apply = newTermName("apply") - val tr = newTermName("tr") - val matchRes = "matchres" - val ifRes = "ifres" - val await = "await" - val bindSuffix = "$bind" + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") + val execContext = suffixedName("execContext") + val stateMachine = newTermName(fresh("stateMachine")) + val stateMachineT = stateMachine.toTypeName + val apply = newTermName("apply") + val tr = newTermName("tr") + val matchRes = "matchres" + val ifRes = "ifres" + val await = "await" + val bindSuffix = "$bind" def fresh(name: TermName): TermName = newTermName(fresh(name.toString)) |