diff options
Diffstat (limited to 'src/main/scala/scala/async/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 157 |
1 files changed, 92 insertions, 65 deletions
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index c5c192d..4beaa34 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -5,15 +5,22 @@ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.{ListBuffer, Builder} +import concurrent.Future /* * @author Philipp Haller */ -class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { +final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSystem: FS) extends AsyncUtils { builder => + lazy val futureSystemOps = futureSystem.mkOps(c) + import c.universe._ import Flag._ + import defn._ + + val execContextType = c.weakTypeOf[futureSystem.ExecContext] + val execContext = futureSystemOps.execContext private val awaitMethod = awaitSym(c) @@ -23,7 +30,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { * case any if any == num => rhs * } */ - def mkHandler(num: Int, rhs: c.Expr[Unit]): c.Expr[PartialFunction[Int, Unit]] = { + def mkHandler(num: Int, rhs: c.Expr[Any]): c.Expr[PartialFunction[Int, Unit]] = { /* val numLiteral = c.Expr[Int](Literal(Constant(num))) @@ -44,7 +51,8 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { def mkIncrStateTree(): c.Tree = { Assign( Ident(newTermName("state")), - Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1))))) + mkInt_+(c.Expr[Int](Ident(newTermName("state"))))(c.literal(1)).tree + ) } def mkStateTree(nextState: Int): c.Tree = @@ -69,7 +77,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { // pattern Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))), // guard - Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))), + mkAny_==(c.Expr(Ident(newTermName("any"))))(c.literal(num)).tree, rhs ) @@ -79,8 +87,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val unitIdent = Ident(definitions.UnitClass) val caseCheck = - Apply(Select(Apply(Ident(definitions.List_apply), - cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1")))) + defn.mkList_contains(defn.mkList_apply(cases.map(p => c.literal(p._2))))(c.Expr(Ident(newTermName("x$1")))) Block(List( // anonymous subclass of PartialFunction[Int, Unit] @@ -91,7 +98,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))), DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), - caseCheck), + caseCheck.tree), DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match @@ -168,7 +175,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val assignTree = Assign( Ident(resultName.toString), - Select(Ident("tr"), newTermName("get")) + mkTry_get(c.Expr(Ident("tr"))).tree ) val handlerTree = Match( @@ -179,10 +186,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { ) ) ) - Apply( - Select(awaitable, newTermName("onComplete")), - List(handlerTree) - ) + futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree } /* Make an `onComplete` invocation which increments the state upon resuming: @@ -198,23 +202,17 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val tryGetTree = Assign( Ident(resultName.toString), - Select(Ident("tr"), newTermName("get")) + Select(Ident("tr"), Try_get) ) + val handlerTree = - Match( - EmptyTree, - List( - CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, - Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case - ) - ) - ) - Apply( - Select(awaitable, newTermName("onComplete")), - List(handlerTree) - ) + Function(List(ValDef(Modifiers(PARAM), newTermName("tr"), TypeTree(tryType), EmptyTree)), Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List()))) + + futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree } + def tryType = appliedType(c.mirror.staticClass("scala.util.Try").toType, List(resultType)) + /* Make an `onComplete` invocation which sets the state to `nextState` upon resuming: * * awaitable.onComplete { @@ -228,21 +226,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val tryGetTree = Assign( Ident(resultName.toString), - Select(Ident("tr"), newTermName("get")) + Select(Ident("tr"), Try_get) ) val handlerTree = - Match( - EmptyTree, - List( - CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, - Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case - ) - ) - ) - Apply( - Select(awaitable, newTermName("onComplete")), - List(handlerTree) - ) + Function(List(ValDef(Modifiers(PARAM), newTermName("tr"), TypeTree(tryType), EmptyTree)), Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List()))) + + futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree } /* Make a partial function literal handling case #num: @@ -391,12 +380,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { override val varDefs = self.varDefs.toList } } - + /** * Build `AsyncState` ending with a match expression. - * + * * The cases of the match simply resume at the state of their corresponding right-hand side. - * + * * @param scrutTree tree of the scrutinee * @param cases list of case definitions * @param stateFirstCase state of the right-hand side of the first case @@ -414,7 +403,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { override val varDefs = self.varDefs.toList } } - + override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" @@ -423,7 +412,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). - * + * * @param stats a list of expressions * @param expr the last expression of the block * @param startState the start state @@ -441,20 +430,20 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { private var remainingBudget = budget - /* Fall back to CPS plug-in if tree contains an `await` call. */ + /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { case Apply(fun, _) if fun.symbol == awaitMethod => true case _ => false - }) throw new FallbackToCpsException - + }) c.abort(tree.pos, "await unsupported in this position") //throw new FallbackToCpsException + def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = { val (branchStats, branchExpr) = tree match { case Block(s, e) => (s, e) - case _ => (List(tree), Literal(Constant(()))) + case _ => (List(tree), Literal(Constant(()))) } new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap) } - + // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern @@ -491,44 +480,45 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget) - - List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { case (tree, state, branchBudget) => - val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename) - asyncStates ++= builder.asyncStates - toRename ++= builder.toRename + + List((thenp, currState + 1, thenBudget), (elsep, currState + thenBudget, elseBudget)) foreach { + case (tree, state, branchBudget) => + val builder = builderForBranch(tree, state, currState + ifBudget, branchBudget, toRename) + asyncStates ++= builder.asyncStates + toRename ++= builder.toRename } - + // create new state builder for state `currState + ifBudget` currState = currState + ifBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - + case Match(scrutinee, cases) => vprintln("transforming match expr: " + stat) checkForUnsupportedAwait(scrutinee) - + val matchBudget: Int = remainingBudget / 2 remainingBudget -= matchBudget //TODO test if budget > 0 // state that we continue with after match: currState + matchBudget - + val perCaseBudget: Int = matchBudget / cases.size asyncStates += // the two Int arguments are the start state of the first case and the per-case state budget, respectively stateBuilder.resultWithMatch(scrutinee, cases, currState + 1, perCaseBudget) - + for ((cas, num) <- cases.zipWithIndex) { val (casStats, casExpr) = cas match { case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(()))) + case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(()))) } val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename) asyncStates ++= builder.asyncStates toRename ++= builder.toRename } - + // create new state builder for state `currState + matchBudget` currState = currState + matchBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - + case _ => checkForUnsupportedAwait(stat) stateBuilder += stat @@ -542,7 +532,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { assert(asyncStates.size > 1) val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState() - c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + reify { + c.Expr[PartialFunction[Int, Unit]](mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).splice: PartialFunction[Int, Unit] + } } /* Builds the handler expression for a sequence of async states. @@ -560,14 +552,49 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { // do not traverse first or last state val handlerTreeForNextState = asyncState.mkHandlerTreeForState() val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) - handlerExpr = c.Expr( - Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), - List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + handlerExpr = mkPartialFunction_orElse(c.Expr(currentHandlerTreeNaked))(c.Expr(handlerTreeForNextState)) } handlerExpr } } + } + + /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ + def methodSym(apply: c.Expr[Any]): Symbol = { + val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? + tree2.collect { + case s: SymTree if s.symbol.isMethod => s.symbol + }.headOption.getOrElse(sys.error(s"Unable to find a method symbol in ${apply.tree}")) + } + + object defn { + def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { + c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) + } + + def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify(self.splice.contains(elem.splice)) + + def mkPartialFunction_orElse[A, B](self: Expr[PartialFunction[A, B]])(other: Expr[PartialFunction[A, B]]) = reify { + self.splice.orElse(other.splice) + } + + def mkFunction_apply[A, B](self: Expr[Function1[A, B]])(arg: Expr[A]) = reify { + self.splice.apply(arg.splice) + } + + def mkInt_+(self: Expr[Int])(other: Expr[Int]) = reify { + self.splice + other.splice + } + + def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { + self.splice == other.splice + } + + def mkTry_get[A](self: Expr[util.Try[A]]) = reify { + self.splice.get + } + val Try_get = methodSym(reify((null.asInstanceOf[scala.util.Try[Any]]).get)) } } |