diff options
Diffstat (limited to 'src/main/scala/scala/async/Async.scala')
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 4fe7c3f..09e002d 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -100,9 +100,9 @@ abstract class AsyncBase { // Important to retain the original declaration order here! val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => + case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) - case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) } @@ -113,10 +113,12 @@ abstract class AsyncBase { } val resumeFunTree = asyncBlock.resumeFunTree[T] - val stateMachine: ModuleDef = { + val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + val stateMachine: ClassDef = { val body: List[Tree] = { val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) - val result = ValDef(NoMods, name.result, TypeTree(), futureSystemOps.createProm[T].tree) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) val applyDefDef: DefDef = { val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) @@ -133,17 +135,16 @@ abstract class AsyncBase { List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) } val template = { - val `Try[Any] => Unit` = utils.applied("scala.runtime.AbstractFunction1", List(defn.TryAnyType, definitions.UnitTpe)) - val `() => Unit` = utils.applied("scala.Function0", List(definitions.UnitTpe)) - Template(List(`Try[Any] => Unit`, `() => Unit`), emptyValDef, body) + Template(List(stateMachineType), emptyValDef, body) } - ModuleDef(NoMods, name.stateMachine, template) + ClassDef(NoMods, name.stateMachineT, Nil, template) } def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) val code = c.Expr[futureSystem.Fut[T]](Block(List[Tree]( stateMachine, + ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)), futureSystemOps.future(c.Expr[Unit](Apply(selectStateMachine(name.apply), Nil))) (c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree ), @@ -152,7 +153,6 @@ abstract class AsyncBase { AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") code - } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { @@ -169,3 +169,9 @@ abstract class AsyncBase { states foreach (s => AsyncUtils.vprintln(s)) } } + +/** Internal class used by the `async` macro; should not be manually extended by client code */ +abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { + def result$async: Result + def execContext$async: EC +} |