diff options
Diffstat (limited to 'src/main/scala')
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 250 | ||||
-rw-r--r-- | src/main/scala/scala/async/AsyncUtils.scala | 22 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 444 | ||||
-rw-r--r-- | src/main/scala/scala/async/FutureSystem.scala | 135 |
4 files changed, 434 insertions, 417 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index d4b950e..4f7fa01 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -7,131 +7,161 @@ import language.experimental.macros import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer -import scala.concurrent.{ Future, Promise, ExecutionContext, future } +import scala.concurrent.{Future, Promise, ExecutionContext, future} import ExecutionContext.Implicits.global import scala.util.control.NonFatal -import scala.util.continuations.{ shift, reset, cpsParam } +import AsyncUtils.vprintln -/* Extending `ControlThrowable`, by default, also avoids filling in the stack trace. */ -class FallbackToCpsException extends scala.util.control.ControlThrowable /* * @author Philipp Haller */ -object Async extends AsyncUtils { +object Async extends AsyncBase { + lazy val futureSystem = ScalaConcurrentFutureSystem + type FS = ScalaConcurrentFutureSystem.type - def async[T](body: T): Future[T] = macro asyncImpl[T] + def async[T](body: T) = macro asyncImpl[T] + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = super.asyncImpl[T](c)(body) +} + +object AsyncId extends AsyncBase { + lazy val futureSystem = IdentityFutureSystem + type FS = IdentityFutureSystem.type + + def async[T](body: T) = macro asyncImpl[T] + + override def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = super.asyncImpl[T](c)(body) +} + +/** + * A base class for the `async` macro. Subclasses must provide: + * + * - Concrete types for a given future system + * - Tree manipulations to create and complete the equivalent of Future and Promise + * in that system. + * - The `async` macro declaration itself, and a forwarder for the macro implementation. + * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) + * + * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. + */ +abstract class AsyncBase { + self => + + type FS <: FutureSystem + val futureSystem: FS + + /** + * A call to `await` must be nested in an enclosing `async` block. + * + * A call to await does not block the thread, rather it is a delimiter + * used by the enclosing `async` macro. Code following the `await` + * call. + * + * @param awaitable The future from which a value is awaited + * @tparam T The type of that value + * @return The value + */ // TODO Replace with `@compileTimeOnly when this is implemented SI-6539 @deprecated("`await` must be enclosed in an `async` block", "0.1") - def await[T](awaitable: Future[T]): T = ??? - - /* Fall back for `await` when it is called at an unsupported position. - */ - def awaitCps[T, U](awaitable: Future[T], p: Promise[U]): T @cpsParam[U, Unit] = - shift { - (k: (T => U)) => - awaitable onComplete { - case tr => p.success(k(tr.get)) - } - } - - def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[Future[T]] = { + def await[T](awaitable: futureSystem.Fut[T]): T = ??? + + def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ import Flag._ - - val builder = new ExprBuilder[c.type](c) - val awaitMethod = awaitSym(c) - - try { - body.tree match { - case Block(stats, expr) => - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map()) - - vprintln(s"states of current method:") - asyncBlockBuilder.asyncStates foreach vprintln - - val handlerExpr = asyncBlockBuilder.mkCombinedHandlerExpr() - - vprintln(s"GENERATED handler expr:") - vprintln(handlerExpr) - - val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { - val tree = Apply(Select(Ident("result"), newTermName("success")), - List(asyncBlockBuilder.asyncStates.last.body)) - builder.mkHandler(asyncBlockBuilder.asyncStates.last.state, c.Expr[Unit](tree)) - } - - vprintln("GENERATED handler for last state:") - vprintln(handlerForLastState) - - val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList - - /* - def resume(): Unit = { - try { - (handlerExpr.splice orElse handlerForLastState.splice)(state) - } catch { - case NonFatal(t) => result.failure(t) - } - } - */ - val nonFatalModule = c.mirror.staticModule("scala.util.control.NonFatal") - val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), Ident(definitions.UnitClass), - Try(Apply(Select( - Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)), - newTermName("apply")), List(Ident(newTermName("state")))), - List( - CaseDef( - Apply(Ident(nonFatalModule), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), - EmptyTree, - Block(List( - Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))), - Literal(Constant(()))))), EmptyTree)) - - reify { - val result = Promise[T]() - var state = 0 - future { - c.Expr(Block( - localVarTrees :+ resumeFunTree, - Apply(Ident(newTermName("resume")), List()))).splice - } - result.future - } - - case _ => - // issue error message - reify { - sys.error("expression not supported by async") - } - } - } catch { - case _: FallbackToCpsException => - // replace `await` invocations with `awaitCps` invocations - val awaitReplacer = new Transformer { - val awaitCpsMethod = awaitCpsSym(c) - override def transform(tree: Tree): Tree = tree match { - case Apply(fun @ TypeApply(_, List(futArgTpt)), args) if fun.symbol == awaitMethod => - val typeApp = treeCopy.TypeApply(fun, Ident(awaitCpsMethod), List(TypeTree(futArgTpt.tpe), TypeTree(body.tree.tpe))) - treeCopy.Apply(tree, typeApp, args.map(arg => c.resetAllAttrs(arg.duplicate)) :+ Ident(newTermName("p"))) - - case _ => - super.transform(tree) - } + + val builder = new ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) + + import builder.defn._ + import builder.name + import builder.futureSystemOps + val (stats, expr) = body.tree match { + case Block(stats, expr) => (stats, expr) + case tree => (Nil, tree) + } + + val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map()) + + asyncBlockBuilder.asyncStates foreach (s => vprintln(s)) + + val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]() + + val initStates = asyncBlockBuilder.asyncStates.init + val localVarTrees = initStates.flatMap(_.allVarDefs).toList + + /* + lazy val onCompleteHandler = (tr: Try[Any]) => state match { + case 0 => { + x11 = tr.get.asInstanceOf[Double]; + state = 1; + resume() } - - val newBody = awaitReplacer.transform(body.tree) - - reify { - val p = Promise[T]() - future { - reset { - c.Expr(c.resetAllAttrs(newBody.duplicate)).asInstanceOf[c.Expr[T]].splice - } - } - p.future + ... + */ + val onCompleteHandler = { + val onCompleteHandlers = initStates.flatMap(_.mkOnCompleteHandler).toList + Function( + List(ValDef(Modifiers(PARAM), name.tr, TypeTree(TryAnyType), EmptyTree)), + Match(Ident(name.state), onCompleteHandlers)) + } + + /* + def resume(): Unit = { + try { + state match { + case 0 => { + f11 = exprReturningFuture + f11.onComplete(onCompleteHandler)(context) + } + ... + } + } catch { + case NonFatal(t) => result.failure(t) } + } + */ + val resumeFunTree: c.Tree = DefDef(Modifiers(), name.resume, Nil, List(Nil), Ident(definitions.UnitClass), + Try( + Match(Ident(name.state), handlerCases), + List( + CaseDef( + Apply(Ident(NonFatalClass), List(Bind(name.tr, Ident(nme.WILDCARD)))), + EmptyTree, + Block(List({ + val t = c.Expr[Throwable](Ident(name.tr)) + futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree + }), c.literalUnit.tree))), EmptyTree)) + + + val prom: Expr[futureSystem.Prom[T]] = reify { + // Create the empty promise + val result$async = futureSystemOps.createProm[T].splice + // Initialize the state + var state$async = 0 + // Resolve the execution context + val execContext$async = futureSystemOps.execContext.splice + var onCompleteHandler$async: util.Try[Any] => Unit = null + + // Spawn a future to: + futureSystemOps.future[Unit] { + c.Expr[Unit](Block( + // define vars for all intermediate results + localVarTrees :+ + // define the resume() method + resumeFunTree :+ + // assign onComplete function. (The var breaks the circular dependency with resume)` + Assign(Ident(name.onCompleteHandler), onCompleteHandler), + // and get things started by calling resume() + Apply(Ident(name.resume), Nil))) + }(c.Expr[futureSystem.ExecContext](Ident(name.execContext))).splice + // Return the promise from this reify block... + result$async } + // ... and return its Future from the macro. + val result = futureSystemOps.promiseToFuture(prom) + + vprintln(s"${c.macroApplication} \nexpands to:\n ${result.tree}") + + result } } diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala index d288d34..77c155f 100644 --- a/src/main/scala/scala/async/AsyncUtils.scala +++ b/src/main/scala/scala/async/AsyncUtils.scala @@ -3,29 +3,13 @@ */ package scala.async -import scala.reflect.macros.Context - /* * @author Philipp Haller */ -trait AsyncUtils { +object AsyncUtils { - val verbose = false + private val verbose = false - protected def vprintln(s: Any): Unit = if (verbose) + private[async] def vprintln(s: => Any): Unit = if (verbose) println("[async] "+s) - - /* Symbol of the `Async.await` method in context `c`. - */ - protected def awaitSym(c: Context): c.universe.Symbol = { - val asyncMod = c.mirror.staticModule("scala.async.Async") - val tpe = asyncMod.moduleClass.asType.toType - tpe.member(c.universe.newTermName("await")) - } - - protected def awaitCpsSym(c: Context): c.universe.Symbol = { - val asyncMod = c.mirror.staticModule("scala.async.Async") - val tpe = asyncMod.moduleClass.asType.toType - tpe.member(c.universe.newTermName("awaitCps")) - } } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 0dcc074..56274ec 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -4,53 +4,51 @@ package scala.async import scala.reflect.macros.Context -import scala.collection.mutable.{ListBuffer, Builder} +import scala.collection.mutable.ListBuffer +import concurrent.Future +import AsyncUtils.vprintln /* * @author Philipp Haller */ -class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { +final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSystem: FS) { builder => + lazy val futureSystemOps = futureSystem.mkOps(c) + import c.universe._ import Flag._ + import defn._ - private val awaitMethod = awaitSym(c) + object name { + def suffix(string: String) = string + "$async" - /* Make a partial function literal handling case #num: - * - * { - * case any if any == num => rhs - * } - */ - def mkHandler(num: Int, rhs: c.Expr[Unit]): c.Expr[PartialFunction[Int, Unit]] = { - /* - val numLiteral = c.Expr[Int](Literal(Constant(num))) - - reify(new PartialFunction[Int, Unit] { - def isDefinedAt(`x$1`: Int) = - `x$1` == numLiteral.splice - def apply(`x$1`: Int) = `x$1` match { - case any: Int if any == numLiteral.splice => - rhs.splice - } - }) - */ - val rhsTree = c.resetAllAttrs(rhs.tree.duplicate) - val handlerTree = mkHandlerTree(num, rhsTree) - c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] - } + def suffixedName(prefix: String) = newTermName(suffix(prefix)) + + val state = suffixedName("state") + val result = suffixedName("result") + val resume = suffixedName("resume") + val execContext = suffixedName("execContext") + + // TODO do we need to freshen any of these? + val x1 = newTermName("x$1") + val tr = newTermName("tr") + val onCompleteHandler = suffixedName("onCompleteHandler") - def mkIncrStateTree(): c.Tree = { - Assign( - Ident(newTermName("state")), - Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1))))) + def fresh(name: TermName) = newTermName(c.fresh("" + name + "$")) } + private val execContext = futureSystemOps.execContext + + private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) + + private def mkResumeApply = Apply(Ident(name.resume), Nil) + def mkStateTree(nextState: Int): c.Tree = - Assign( - Ident(newTermName("state")), - Literal(Constant(nextState))) + mkStateTree(c.literal(nextState).tree) + + def mkStateTree(nextState: Tree): c.Tree = + Assign(Ident(name.state), nextState) def defaultValue(tpe: Type): Literal = { val defaultValue: Any = @@ -64,62 +62,37 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) } - def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = - CaseDef( - // pattern - Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))), - // guard - Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))), - rhs - ) - - def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = { - val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction")) - val intIdent = Ident(definitions.IntClass) - val unitIdent = Ident(definitions.UnitClass) - - val caseCheck = - Apply(Select(Apply(Ident(definitions.List_apply), - cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1")))) - - Block(List( - // anonymous subclass of PartialFunction[Int, Unit] - // TODO subclass AbstractPartialFunction - ClassDef(Modifiers(FINAL), newTypeName("$anon"), List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), - emptyValDef, List( - DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(), - Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))), - - DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), - caseCheck), - - DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), - Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match - ) - )) - )), - Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List()) - ) - } + def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = + mkHandlerCase(num, Block(rhs: _*)) - def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = - mkHandlerTreeFor(List(mkHandlerCase(num, rhs) -> num)) + def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = + CaseDef(c.literal(num).tree, EmptyTree, rhs) class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { - val body: c.Tree = - if (stats.size == 1) stats.head - else Block(stats: _*) + val body: c.Tree = stats match { + case stat :: Nil => stat + case _ => Block(stats: _*) + } - val varDefs: List[(TermName, Type)] = List() + val varDefs: List[(TermName, Type)] = Nil def mkHandlerCaseForState(): CaseDef = - mkHandlerCase(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) - - def mkHandlerTreeForState(): c.Tree = - mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) - - def mkHandlerTreeForState(nextState: Int): c.Tree = - mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) + + def mkOnCompleteHandler(): Option[CaseDef] = { + this match { + case aw: AsyncStateWithAwait => + val tryGetTree = + Assign( + Ident(aw.resultName), + TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(aw.resultType))) + ) + val updateState = mkStateTree(nextState) // or increment? + Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply))) + case _ => + None + } + } def varDefForResult: Option[c.Tree] = None @@ -135,13 +108,8 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches - override def mkHandlerTreeForState(): c.Tree = - mkHandlerTree(state, Block(stats: _*)) - - //TODO mkHandlerTreeForState(nextState: Int) - override def mkHandlerCaseForState(): CaseDef = - mkHandlerCase(state, Block(stats: _*)) + mkHandlerCase(state, stats) override val toString: String = s"AsyncStateWithIf #$state, next = $nextState" @@ -153,141 +121,18 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val resultName: TermName val resultType: Type + protected def tryType = appliedType(TryClass.toType, List(resultType)) + override val toString: String = s"AsyncStateWithAwait #$state, next = $nextState" - /* Make an `onComplete` invocation: - * - * awaitable.onComplete { - * case tr => - * resultName = tr.get - * resume() - * } - */ - def mkOnCompleteTree: c.Tree = { - val assignTree = - Assign( - Ident(resultName.toString), - Select(Ident("tr"), newTermName("get")) - ) - val handlerTree = - Match( - EmptyTree, - List( - CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, - Block(assignTree, Apply(Ident("resume"), List())) // rhs of case - ) - ) - ) - Apply( - Select(awaitable, newTermName("onComplete")), - List(handlerTree) - ) - } - - /* Make an `onComplete` invocation which increments the state upon resuming: - * - * awaitable.onComplete { - * case tr => - * resultName = tr.get - * state += 1 - * resume() - * } - */ - def mkOnCompleteIncrStateTree: c.Tree = { - val tryGetTree = - Assign( - Ident(resultName.toString), - Select(Ident("tr"), newTermName("get")) - ) - val handlerTree = - Match( - EmptyTree, - List( - CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, - Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case - ) - ) - ) - Apply( - Select(awaitable, newTermName("onComplete")), - List(handlerTree) - ) - } - - /* Make an `onComplete` invocation which sets the state to `nextState` upon resuming: - * - * awaitable.onComplete { - * case tr => - * resultName = tr.get - * state = `nextState` - * resume() - * } - */ - def mkOnCompleteStateTree(nextState: Int): c.Tree = { - val tryGetTree = - Assign( - Ident(resultName.toString), - Select(Ident("tr"), newTermName("get")) - ) - val handlerTree = - Match( - EmptyTree, - List( - CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, - Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case - ) - ) - ) - Apply( - Select(awaitable, newTermName("onComplete")), - List(handlerTree) - ) - } - - /* Make a partial function literal handling case #num: - * - * { - * case any if any == num => - * stats - * awaitable.onComplete { - * case tr => - * resultName = tr.get - * resume() - * } - * } - */ - def mkHandlerForState(num: Int): c.Expr[PartialFunction[Int, Unit]] = { - assert(awaitable != null) - builder.mkHandler(num, c.Expr[Unit](Block((stats :+ mkOnCompleteTree): _*))) - } - - /* Make a partial function literal handling case #num: - * - * { - * case any if any == num => - * stats - * awaitable.onComplete { - * case tr => - * resultName = tr.get - * state += 1 - * resume() - * } - * } - */ - override def mkHandlerTreeForState(): c.Tree = { - assert(awaitable != null) - mkHandlerTree(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) - } - - override def mkHandlerTreeForState(nextState: Int): c.Tree = { - assert(awaitable != null) - mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*)) + private def mkOnCompleteTree: c.Tree = { + futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(Ident(name.onCompleteHandler)), c.Expr(Ident(name.execContext))).tree } override def mkHandlerCaseForState(): CaseDef = { assert(awaitable != null) - mkHandlerCase(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + mkHandlerCase(state, stats :+ mkOnCompleteTree) } override def varDefForResult: Option[c.Tree] = @@ -297,7 +142,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { /* * Builder for a single state of an async method. */ - class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) extends Builder[c.Tree, AsyncState] { + class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) { self => /* Statements preceding an await call. */ @@ -326,7 +171,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { } def +=(stat: c.Tree): this.type = { - stats += c.resetAllAttrs(renamer.transform(stat).duplicate) + stats += resetDuplicate(renamer.transform(stat)) this } @@ -334,7 +179,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { varDefs += (name -> tpt.tpe) nameMap ++= extNameMap // update name map - this += Assign(Ident(name), c.resetAllAttrs(renamer.transform(rhs).duplicate)) + this += Assign(Ident(name), rhs) this } @@ -351,13 +196,6 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { override val varDefs = self.varDefs.toList } - def clear(): Unit = { - stats.clear() - awaitable = null - resultName = null - resultType = null - } - /* Result needs to be created as a var at the beginning of the transformed method body, so that * it is visible in subsequent states of the state machine. * @@ -368,7 +206,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { nameMap ++= extNameMap - awaitable = c.resetAllAttrs(renamer.transform(awaitArg).duplicate) + awaitable = resetDuplicate(renamer.transform(awaitArg)) resultName = awaitResultName resultType = awaitResultType.tpe this.nextState = nextState @@ -383,20 +221,20 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { // 1. build changed if-else tree // 2. insert that tree at the end of the current state - val cond = c.resetAllAttrs(condTree.duplicate) + val cond = resetDuplicate(condTree) this += If(cond, - Block(mkStateTree(thenState), Apply(Ident("resume"), List())), - Block(mkStateTree(elseState), Apply(Ident("resume"), List()))) + Block(mkStateTree(thenState), mkResumeApply), + Block(mkStateTree(elseState), mkResumeApply)) new AsyncStateWithoutAwait(stats.toList, state) { override val varDefs = self.varDefs.toList } } - + /** * Build `AsyncState` ending with a match expression. - * + * * The cases of the match simply resume at the state of their corresponding right-hand side. - * + * * @param scrutTree tree of the scrutinee * @param cases list of case definitions * @param stateFirstCase state of the right-hand side of the first case @@ -406,15 +244,15 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], stateFirstCase: Int, perCasebudget: Int): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), Apply(Ident("resume"), List()))) + case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), mkResumeApply)) } // 2. insert changed match tree at the end of the current state - this += Match(c.resetAllAttrs(scrutTree.duplicate), newCases) + this += Match(resetDuplicate(scrutTree), newCases) new AsyncStateWithoutAwait(stats.toList, state) { override val varDefs = self.varDefs.toList } } - + override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" @@ -423,7 +261,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). - * + * * @param stats a list of expressions * @param expr the last expression of the block * @param startState the start state @@ -441,28 +279,28 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { private var remainingBudget = budget - /* Fall back to CPS plug-in if tree contains an `await` call. */ + /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { - case Apply(fun, _) if fun.symbol == awaitMethod => true + case Apply(fun, _) if fun.symbol == Async_await => true case _ => false - }) throw new FallbackToCpsException - + }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException + def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = { val (branchStats, branchExpr) = tree match { case Block(s, e) => (s, e) - case _ => (List(tree), Literal(Constant(()))) + case _ => (List(tree), c.literalUnit.tree) } new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap) } - + // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => - val newName = c.fresh(name) + case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == Async_await => + val newName = builder.name.fresh(name) toRename += (stat.symbol -> newName) - asyncStates += stateBuilder.complete(args(0), newName, tpt, toRename).result // complete with await + asyncStates += stateBuilder.complete(args.head, newName, tpt, toRename).result // complete with await if (remainingBudget > 0) remainingBudget -= 1 else @@ -473,7 +311,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { case ValDef(mods, name, tpt, rhs) => checkForUnsupportedAwait(rhs) - val newName = c.fresh(name) + val newName = builder.name.fresh(name) toRename += (stat.symbol -> newName) // when adding assignment need to take `toRename` into account stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) @@ -491,52 +329,53 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget) - - List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) => - val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename) - asyncStates ++= builder.asyncStates - toRename ++= builder.toRename + + List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { + case (tree, state, branchBudget) => + val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename) + asyncStates ++= builder.asyncStates + toRename ++= builder.toRename } - + // create new state builder for state `currState + ifBudget` currState = currState + ifBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - + case Match(scrutinee, cases) => vprintln("transforming match expr: " + stat) checkForUnsupportedAwait(scrutinee) - + val matchBudget: Int = remainingBudget / 2 remainingBudget -= matchBudget //TODO test if budget > 0 // state that we continue with after match: currState + matchBudget - + val perCaseBudget: Int = matchBudget / cases.size asyncStates += // the two Int arguments are the start state of the first case and the per-case state budget, respectively stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget) - + for ((cas, num) <- cases.zipWithIndex) { val (casStats, casExpr) = cas match { case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(()))) + case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) } val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename) asyncStates ++= builder.asyncStates toRename ++= builder.toRename } - + // create new state builder for state `currState + matchBudget` currState = currState + matchBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - + case ClassDef(_, name, _, _) => // do not allow local class definitions, because of SI-5467 (specific to case classes, though) c.error(stat.pos, s"Local class ${name.decoded} illegal within `async` block") - + case ModuleDef(_, name, _) => // local object definitions lead to spurious type errors (because of resetAllAttrs?) c.error(stat.pos, s"Local object ${name.decoded} illegal within `async` block") - + case _ => checkForUnsupportedAwait(stat) stateBuilder += stat @@ -546,36 +385,65 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val lastState = stateBuilder.complete(endState).result asyncStates += lastState - def mkCombinedHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { - assert(asyncStates.size > 1) + def mkCombinedHandlerCases[T](): List[CaseDef] = { + val caseForLastState: CaseDef = { + val lastState = asyncStates.last + val lastStateBody = c.Expr[T](lastState.body) + val rhs = futureSystemOps.completeProm(c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) + mkHandlerCase(lastState.state, rhs.tree) + } + asyncStates.toList match { + case s :: Nil => + List(caseForLastState) + case _ => + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState() + initCases :+ caseForLastState + } + } + } - val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState() - c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ + def methodSym(apply: c.Expr[Any]): Symbol = { + val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? + tree2.collect { + case s: SymTree if s.symbol.isMethod => s.symbol + }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) + } + + object defn { + def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { + c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) } - /* Builds the handler expression for a sequence of async states. - */ - def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { - assert(asyncStates.size > 1) - - var handlerExpr = - c.Expr(asyncStates(0).mkHandlerTreeForState()).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] - - if (asyncStates.size == 2) - handlerExpr - else { - for (asyncState <- asyncStates.tail.init) { - // do not traverse first or last state - val handlerTreeForNextState = asyncState.mkHandlerTreeForState() - val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) - handlerExpr = c.Expr( - Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), - List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] - } - handlerExpr - } + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) + + def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { + self.splice.apply(arg.splice) + } + + def mkInt_+(self: Expr[Int])(other: Expr[Int]) = reify { + self.splice + other.splice + } + + def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { + self.splice == other.splice } + def mkTry_get[A](self: Expr[util.Try[A]]) = reify { + self.splice.get + } + + val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) + + val TryClass = c.mirror.staticClass("scala.util.Try") + val TryAnyType = appliedType(TryClass.toType, List(definitions.AnyTpe)) + val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") + + val Async_await = { + val asyncMod = c.mirror.staticModule("scala.async.Async") + val tpe = asyncMod.moduleClass.asType.toType + tpe.member(c.universe.newTermName("await")) + } } } diff --git a/src/main/scala/scala/async/FutureSystem.scala b/src/main/scala/scala/async/FutureSystem.scala new file mode 100644 index 0000000..738de34 --- /dev/null +++ b/src/main/scala/scala/async/FutureSystem.scala @@ -0,0 +1,135 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ +package scala.async + +import reflect.macros.Context + +/** + * An abstraction over a future system. + * + * Used by the macro implementations in [[scala.async.AsyncBase]] to + * customize the code generation. + * + * The API mirrors that of `scala.concurrent.Future`, see the instance + * [[scala.async.ScalaConcurrentFutureSystem]] for an example of how + * to implement this. + */ +trait FutureSystem { + /** A container to receive the final value of the computation */ + type Prom[A] + /** A (potentially in-progress) computation */ + type Fut[A] + /** An execution context, required to create or register an on completion callback on a Future. */ + type ExecContext + + trait Ops { + val context: reflect.macros.Context + + import context.universe._ + + /** Lookup the execution context, typically with an implicit search */ + def execContext: Expr[ExecContext] + + /** Create an empty promise */ + def createProm[A: WeakTypeTag]: Expr[Prom[A]] + + /** Extract a future from the given promise. */ + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]] + + /** Construct a future to asynchrously compute the given expression */ + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]] + + /** Register an call back to run on completion of the given future */ + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] + + /** Complete a promise with a value */ + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] + } + + def mkOps(c: Context): Ops {val context: c.type} +} + +object ScalaConcurrentFutureSystem extends FutureSystem { + + import scala.concurrent._ + + type Prom[A] = Promise[A] + type Fut[A] = Future[A] + type ExecContext = ExecutionContext + + def mkOps(c: Context): Ops {val context: c.type} = new Ops { + val context: c.type = c + + import context.universe._ + + def execContext: Expr[ExecContext] = c.Expr(c.inferImplicitValue(c.weakTypeOf[ExecutionContext]) match { + case EmptyTree => c.abort(c.macroApplication.pos, "Unable to resolve implicit ExecutionContext") + case context => context + }) + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + Promise[A]() + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.future + } + + def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify { + Future(a.splice)(execContext.splice) + } + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + future.splice.onComplete(fun.splice)(execContext.splice) + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.complete(value.splice) + context.literalUnit.splice + } + } +} + +/** + * A trivial implentation of [[scala.async.FutureSystem]] that performs computations + * on the current thread. Useful for testing. + */ +object IdentityFutureSystem extends FutureSystem { + + class Prom[A](var a: A) + + type Fut[A] = A + type ExecContext = Unit + + def mkOps(c: Context): Ops {val context: c.type} = new Ops { + val context: c.type = c + + import context.universe._ + + def execContext: Expr[ExecContext] = c.literalUnit + + def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { + new Prom(null.asInstanceOf[A]) + } + + def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { + prom.splice.a + } + + def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t + + def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], + execContext: Expr[ExecContext]): Expr[Unit] = reify { + fun.splice.apply(util.Success(future.splice)) + context.literalUnit.splice + } + + def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { + prom.splice.a = value.splice.get + context.literalUnit.splice + } + } +} |