aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/scala/async/ExprBuilder.scala
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-07 20:08:33 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-09 15:44:16 +0100
commitd434c20cfb8623a243cd30f187907bb4b199dc99 (patch)
tree3d127564874feb2ffe1633018b9afcc9bedc8d9b /src/main/scala/scala/async/ExprBuilder.scala
parent7dbf0a0da4987e8fd5b223437d8f5316ff33616e (diff)
downloadscala-async-d434c20cfb8623a243cd30f187907bb4b199dc99.tar.gz
scala-async-d434c20cfb8623a243cd30f187907bb4b199dc99.tar.bz2
scala-async-d434c20cfb8623a243cd30f187907bb4b199dc99.zip
Abstract over the future implementation.
- Refactor the base macro implementation to be parameterized by a FutureSystem, which is defines the triple of types (Future, Promise, ExecutionContext) and the operations on those types (at the AST level) - Cleanup generation of ASTs, in particular, use reify more widely.
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala157
1 files changed, 92 insertions, 65 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index c5c192d..4beaa34 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -5,15 +5,22 @@ package scala.async
import scala.reflect.macros.Context
import scala.collection.mutable.{ListBuffer, Builder}
+import concurrent.Future
/*
* @author Philipp Haller
*/
-class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
+final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSystem: FS) extends AsyncUtils {
builder =>
+ lazy val futureSystemOps = futureSystem.mkOps(c)
+
import c.universe._
import Flag._
+ import defn._
+
+ val execContextType = c.weakTypeOf[futureSystem.ExecContext]
+ val execContext = futureSystemOps.execContext
private val awaitMethod = awaitSym(c)
@@ -23,7 +30,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
* case any if any == num => rhs
* }
*/
- def mkHandler(num: Int, rhs: c.Expr[Unit]): c.Expr[PartialFunction[Int, Unit]] = {
+ def mkHandler(num: Int, rhs: c.Expr[Any]): c.Expr[PartialFunction[Int, Unit]] = {
/*
val numLiteral = c.Expr[Int](Literal(Constant(num)))
@@ -44,7 +51,8 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
def mkIncrStateTree(): c.Tree = {
Assign(
Ident(newTermName("state")),
- Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1)))))
+ mkInt_+(c.Expr[Int](Ident(newTermName("state"))))(c.literal(1)).tree
+ )
}
def mkStateTree(nextState: Int): c.Tree =
@@ -69,7 +77,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
// pattern
Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))),
// guard
- Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))),
+ mkAny_==(c.Expr(Ident(newTermName("any"))))(c.literal(num)).tree,
rhs
)
@@ -79,8 +87,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val unitIdent = Ident(definitions.UnitClass)
val caseCheck =
- Apply(Select(Apply(Ident(definitions.List_apply),
- cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1"))))
+ defn.mkList_contains(defn.mkList_apply(cases.map(p => c.literal(p._2))))(c.Expr(Ident(newTermName("x$1"))))
Block(List(
// anonymous subclass of PartialFunction[Int, Unit]
@@ -91,7 +98,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))),
DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(),
- caseCheck),
+ caseCheck.tree),
DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(),
Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match
@@ -168,7 +175,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val assignTree =
Assign(
Ident(resultName.toString),
- Select(Ident("tr"), newTermName("get"))
+ mkTry_get(c.Expr(Ident("tr"))).tree
)
val handlerTree =
Match(
@@ -179,10 +186,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
)
)
)
- Apply(
- Select(awaitable, newTermName("onComplete")),
- List(handlerTree)
- )
+ futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree
}
/* Make an `onComplete` invocation which increments the state upon resuming:
@@ -198,23 +202,17 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val tryGetTree =
Assign(
Ident(resultName.toString),
- Select(Ident("tr"), newTermName("get"))
+ Select(Ident("tr"), Try_get)
)
+
val handlerTree =
- Match(
- EmptyTree,
- List(
- CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree,
- Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case
- )
- )
- )
- Apply(
- Select(awaitable, newTermName("onComplete")),
- List(handlerTree)
- )
+ Function(List(ValDef(Modifiers(PARAM), newTermName("tr"), TypeTree(tryType), EmptyTree)), Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())))
+
+ futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree
}
+ def tryType = appliedType(c.mirror.staticClass("scala.util.Try").toType, List(resultType))
+
/* Make an `onComplete` invocation which sets the state to `nextState` upon resuming:
*
* awaitable.onComplete {
@@ -228,21 +226,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val tryGetTree =
Assign(
Ident(resultName.toString),
- Select(Ident("tr"), newTermName("get"))
+ Select(Ident("tr"), Try_get)
)
val handlerTree =
- Match(
- EmptyTree,
- List(
- CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree,
- Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case
- )
- )
- )
- Apply(
- Select(awaitable, newTermName("onComplete")),
- List(handlerTree)
- )
+ Function(List(ValDef(Modifiers(PARAM), newTermName("tr"), TypeTree(tryType), EmptyTree)), Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())))
+
+ futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree
}
/* Make a partial function literal handling case #num:
@@ -391,12 +380,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
override val varDefs = self.varDefs.toList
}
}
-
+
/**
* Build `AsyncState` ending with a match expression.
- *
+ *
* The cases of the match simply resume at the state of their corresponding right-hand side.
- *
+ *
* @param scrutTree tree of the scrutinee
* @param cases list of case definitions
* @param stateFirstCase state of the right-hand side of the first case
@@ -414,7 +403,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
override val varDefs = self.varDefs.toList
}
}
-
+
override def toString: String = {
val statsBeforeAwait = stats.mkString("\n")
s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
@@ -423,7 +412,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
/**
* An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`).
- *
+ *
* @param stats a list of expressions
* @param expr the last expression of the block
* @param startState the start state
@@ -441,20 +430,20 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
private var remainingBudget = budget
- /* Fall back to CPS plug-in if tree contains an `await` call. */
+ /* TODO Fall back to CPS plug-in if tree contains an `await` call. */
def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists {
case Apply(fun, _) if fun.symbol == awaitMethod => true
case _ => false
- }) throw new FallbackToCpsException
-
+ }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException
+
def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = {
val (branchStats, branchExpr) = tree match {
case Block(s, e) => (s, e)
- case _ => (List(tree), Literal(Constant(())))
+ case _ => (List(tree), Literal(Constant(())))
}
new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap)
}
-
+
// populate asyncStates
for (stat <- stats) stat match {
// the val name = await(..) pattern
@@ -491,44 +480,45 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
asyncStates +=
// the two Int arguments are the start state of the then branch and the else branch, respectively
stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget)
-
- List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) =>
- val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename)
- asyncStates ++= builder.asyncStates
- toRename ++= builder.toRename
+
+ List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach {
+ case (tree, state, branchBudget) =>
+ val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename)
+ asyncStates ++= builder.asyncStates
+ toRename ++= builder.toRename
}
-
+
// create new state builder for state `currState + ifBudget`
currState = currState + ifBudget
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
-
+
case Match(scrutinee, cases) =>
vprintln("transforming match expr: " + stat)
checkForUnsupportedAwait(scrutinee)
-
+
val matchBudget: Int = remainingBudget / 2
remainingBudget -= matchBudget //TODO test if budget > 0
// state that we continue with after match: currState + matchBudget
-
+
val perCaseBudget: Int = matchBudget / cases.size
asyncStates +=
// the two Int arguments are the start state of the first case and the per-case state budget, respectively
stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget)
-
+
for ((cas, num) <- cases.zipWithIndex) {
val (casStats, casExpr) = cas match {
case CaseDef(_, _, Block(s, e)) => (s, e)
- case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(())))
+ case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(())))
}
val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename)
asyncStates ++= builder.asyncStates
toRename ++= builder.toRename
}
-
+
// create new state builder for state `currState + matchBudget`
currState = currState + matchBudget
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
-
+
case _ =>
checkForUnsupportedAwait(stat)
stateBuilder += stat
@@ -542,7 +532,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
assert(asyncStates.size > 1)
val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState()
- c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
+ reify {
+ c.Expr[PartialFunction[Int, Unit]](mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).splice: PartialFunction[Int, Unit]
+ }
}
/* Builds the handler expression for a sequence of async states.
@@ -560,14 +552,49 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
// do not traverse first or last state
val handlerTreeForNextState = asyncState.mkHandlerTreeForState()
val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate)
- handlerExpr = c.Expr(
- Apply(Select(currentHandlerTreeNaked, newTermName("orElse")),
- List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
+ handlerExpr = mkPartialFunction_orElse(c.Expr(currentHandlerTreeNaked))(c.Expr(handlerTreeForNextState))
}
handlerExpr
}
}
+ }
+
+ /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
+ def methodSym(apply: c.Expr[Any]): Symbol = {
+ val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed?
+ tree2.collect {
+ case s: SymTree if s.symbol.isMethod => s.symbol
+ }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}"))
+ }
+
+ object defn {
+ def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = {
+ c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree)))
+ }
+
+ def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice))
+
+ def mkPartialFunction_orElse[A, B](self: Expr[PartialFunction[A, B]])(other: Expr[PartialFunction[A, B]]) = reify {
+ self.splice.orElse(other.splice)
+ }
+
+ def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify {
+ self.splice.apply(arg.splice)
+ }
+
+ def mkInt_+(self: Expr[Int])(other: Expr[Int]) = reify {
+ self.splice + other.splice
+ }
+
+ def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify {
+ self.splice == other.splice
+ }
+
+ def mkTry_get[A](self: Expr[util.Try[A]]) = reify {
+ self.splice.get
+ }
+ val Try_get = methodSym(reify((null.asInstanceOf[scala.util.Try[Any]]).get))
}
}