diff options
Diffstat (limited to 'src/main/scala/scala/async/Async.scala')
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 99 |
1 files changed, 67 insertions, 32 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index ef506a5..4a770ed 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -11,6 +11,7 @@ import scala.reflect.macros.Context * @author Philipp Haller */ object Async extends AsyncBase { + import scala.concurrent.Future lazy val futureSystem = ScalaConcurrentFutureSystem @@ -65,7 +66,6 @@ abstract class AsyncBase { def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) val anaylzer = AsyncAnalysis[c.type](c) val utils = TransformUtils[c.type](c) import utils.{name, defn} @@ -87,54 +87,82 @@ abstract class AsyncBase { // states of our generated state machine, e.g. a value assigned before // an `await` and read afterwards. val renameMap: Map[Symbol, TermName] = { - anaylzer.valDefsUsedInSubsequentStates(anfTree).map { + anaylzer.defTreesUsedInSubsequentStates(anfTree).map { vd => - (vd.symbol, name.fresh(vd.name)) + (vd.symbol, name.fresh(vd.name.toTermName)) }.toMap } + val builder = ExprBuilder[c.type, futureSystem.type](c, self.futureSystem, anfTree) val asyncBlock: builder.AsyncBlock = builder.build(anfTree, renameMap) import asyncBlock.asyncStates logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) + // Important to retain the original declaration order here! val localVarTrees = anfTree.collect { - case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => + case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) if renameMap contains dd.symbol => + DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) } - val onCompleteHandler = asyncBlock.onCompleteHandler + val onCompleteHandler = { + Function( + List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree)), + asyncBlock.onCompleteHandler) + } val resumeFunTree = asyncBlock.resumeFunTree[T] - val prom: Expr[futureSystem.Prom[T]] = reify { - // Create the empty promise - val result$async = futureSystemOps.createProm[T].splice - // Initialize the state - var state$async = 0 - // Resolve the execution context - val execContext$async = futureSystemOps.execContext.splice - var onCompleteHandler$async: util.Try[Any] => Unit = null - - // Spawn a future to: - futureSystemOps.future[Unit] { - c.Expr[Unit](Block( - // define vars for all intermediate results that are accessed from multiple states - localVarTrees :+ - // define the resume() method - resumeFunTree :+ - // assign onComplete function. (The var breaks the circular dependency with resume)` - Assign(Ident(name.onCompleteHandler), onCompleteHandler), - // and get things started by calling resume() - Apply(Ident(name.resume), Nil))) - }(c.Expr[futureSystem.ExecContext](Ident(name.execContext))).splice - // Return the promise from this reify block... - result$async + val stateMachineType = utils.applied("scala.async.StateMachine", List(futureSystemOps.promType[T], futureSystemOps.execContextType)) + + lazy val stateMachine: ClassDef = { + val body: List[Tree] = { + val stateVar = ValDef(Modifiers(Flag.MUTABLE), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) + val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T]), futureSystemOps.createProm[T].tree) + val execContext = ValDef(NoMods, name.execContext, TypeTree(), futureSystemOps.execContext.tree) + val applyDefDef: DefDef = { + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + val applyBody = asyncBlock.onCompleteHandler + DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), applyBody) + } + val apply0DefDef: DefDef = { + // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. + // See SI-1247 for the the optimization that avoids creatio + val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(defn.TryAnyType), EmptyTree))) + val applyBody = asyncBlock.onCompleteHandler + DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.resume), Nil)) + } + List(utils.emptyConstructor, stateVar, result, execContext) ++ localVarTrees ++ List(resumeFunTree, applyDefDef, apply0DefDef) + } + val template = { + Template(List(stateMachineType), emptyValDef, body) + } + ClassDef(NoMods, name.stateMachineT, Nil, template) } - // ... and return its Future from the macro. - val result = futureSystemOps.promiseToFuture(prom) - AsyncUtils.vprintln(s"async state machine transform expands to:\n ${result.tree}") + def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) + + def spawn(tree: Tree): Tree = + futureSystemOps.future(c.Expr[Unit](tree))(c.Expr[futureSystem.ExecContext](selectStateMachine(name.execContext))).tree + + val code: c.Expr[futureSystem.Fut[T]] = { + val isSimple = asyncStates.size == 1 + val tree = + if (isSimple) + Block(Nil, spawn(body.tree)) // generate lean code for the simple case of `async { 1 + 1 }` + else { + Block(List[Tree]( + stateMachine, + ValDef(NoMods, name.stateMachine, stateMachineType, New(Ident(name.stateMachineT), Nil)), + spawn(Apply(selectStateMachine(name.apply), Nil)) + ), + futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) + } + c.Expr[futureSystem.Fut[T]](tree) + } - result + AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code.tree}") + code } def logDiagnostics(c: Context)(anfTree: c.Tree, states: Seq[String]) { @@ -151,3 +179,10 @@ abstract class AsyncBase { states foreach (s => AsyncUtils.vprintln(s)) } } + +/** Internal class used by the `async` macro; should not be manually extended by client code */ +abstract class StateMachine[Result, EC] extends (scala.util.Try[Any] => Unit) with (() => Unit) { + def result$async: Result + + def execContext$async: EC +} |