diff options
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 3 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 268 |
2 files changed, 135 insertions, 136 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 2c81bc3..0bf4362 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -74,13 +74,14 @@ object Async extends AsyncUtils { } } */ + val nonFatalModule = c.mirror.staticModule("scala.util.control.NonFatal") val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), Ident(definitions.UnitClass), Try(Apply(Select( Apply(Select(handlerExpr.tree, newTermName("orElse")), List(handlerForLastState.tree)), newTermName("apply")), List(Ident(newTermName("state")))), List( CaseDef( - Apply(Select(Select(Select(Ident(newTermName("scala")), newTermName("util")), newTermName("control")), newTermName("NonFatal")), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), + Apply(Ident(nonFatalModule), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), EmptyTree, Block(List( Apply(Select(Ident(newTermName("result")), newTermName("failure")), List(Ident(newTermName("t"))))), diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index b7d6446..655c26f 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -4,19 +4,19 @@ package scala.async import scala.reflect.macros.Context -import scala.collection.mutable.{ ListBuffer, Builder } +import scala.collection.mutable.{ListBuffer, Builder} /* * @author Philipp Haller */ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { builder => - + import c.universe._ import Flag._ - + private val awaitMethod = awaitSym(c) - + /* Make a partial function literal handling case #num: * * { @@ -24,42 +24,46 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { * } */ def mkHandler(num: Int, rhs: c.Expr[Unit]): c.Expr[PartialFunction[Int, Unit]] = { -/* - val numLiteral = c.Expr[Int](Literal(Constant(num))) - - reify(new PartialFunction[Int, Unit] { - def isDefinedAt(`x$1`: Int) = - `x$1` == numLiteral.splice - def apply(`x$1`: Int) = `x$1` match { - case any: Int if any == numLiteral.splice => - rhs.splice - } - }) -*/ + /* + val numLiteral = c.Expr[Int](Literal(Constant(num))) + + reify(new PartialFunction[Int, Unit] { + def isDefinedAt(`x$1`: Int) = + `x$1` == numLiteral.splice + def apply(`x$1`: Int) = `x$1` match { + case any: Int if any == numLiteral.splice => + rhs.splice + } + }) + */ val rhsTree = c.resetAllAttrs(rhs.tree.duplicate) val handlerTree = mkHandlerTree(num, rhsTree) c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] } - - def mkIncrStateTree(): c.Tree = + + def mkIncrStateTree(): c.Tree = { Assign( Ident(newTermName("state")), Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1))))) - + } + def mkStateTree(nextState: Int): c.Tree = Assign( Ident(newTermName("state")), Literal(Constant(nextState))) - - def mkVarDefTree(resultType: c.universe.Type, resultName: c.universe.TermName): c.Tree = { - val rhs = - if (resultType <:< definitions.IntTpe) Literal(Constant(0)) - else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) - else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) - else Literal(Constant(null)) - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) + + def defaultValue(tpe: Type): Literal = { + val defaultValue: Any = + if (tpe <:< definitions.BooleanTpe) false + else if (definitions.ScalaNumericValueClasses.exists(tpe <:< _.toType)) 0 + else null + Literal(Constant(defaultValue)) } - + + def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) + } + def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = CaseDef( // pattern @@ -68,26 +72,27 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))), rhs ) - + def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = { val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction")) val intIdent = Ident(definitions.IntClass) val unitIdent = Ident(definitions.UnitClass) - + val caseCheck = - Apply(Select(Apply(Select(Ident(newTermName("List")), newTermName("apply")), - cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1")))) - + Apply(Select(Apply(Ident(definitions.List_apply), + cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1")))) + Block(List( // anonymous subclass of PartialFunction[Int, Unit] + // TODO subclass AbstractPartialFunction ClassDef(Modifiers(FINAL), newTypeName("$anon"), List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), emptyValDef, List( DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))), - + DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), - caseCheck), - + caseCheck), + DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match ) @@ -96,60 +101,61 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List()) ) } - + def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = mkHandlerTreeFor(List(mkHandlerCase(num, rhs) -> num)) - + class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = if (stats.size == 1) stats.head else Block(stats: _*) - - val varDefs: List[(c.universe.TermName, c.universe.Type)] = List() - + + val varDefs: List[(TermName, Type)] = List() + def mkHandlerCaseForState(): CaseDef = mkHandlerCase(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) - + def mkHandlerTreeForState(): c.Tree = mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) - + def mkHandlerTreeForState(nextState: Int): c.Tree = mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) - + def varDefForResult: Option[c.Tree] = None - + def allVarDefs: List[c.Tree] = varDefForResult.toList ++ varDefs.map(p => mkVarDefTree(p._2, p._1)) - + override val toString: String = s"AsyncState #$state, next = $nextState" } - + class AsyncStateWithIf(stats: List[c.Tree], state: Int) - extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches - + extends AsyncState(stats, state, 0) { + // nextState unused, since encoded in then and else branches + override def mkHandlerTreeForState(): c.Tree = mkHandlerTree(state, Block(stats: _*)) - + //TODO mkHandlerTreeForState(nextState: Int) - + override def mkHandlerCaseForState(): CaseDef = mkHandlerCase(state, Block(stats: _*)) - + override val toString: String = s"AsyncStateWithIf #$state, next = $nextState" } - + abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int) - extends AsyncState(stats, state, nextState) { + extends AsyncState(stats, state, nextState) { val awaitable: c.Tree - val resultName: c.universe.TermName - val resultType: c.universe.Type - + val resultName: TermName + val resultType: Type + override val toString: String = s"AsyncStateWithAwait #$state, next = $nextState" - + /* Make an `onComplete` invocation: * * awaitable.onComplete { @@ -162,23 +168,23 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val assignTree = Assign( Ident(resultName.toString), - Select(Ident("tr"), c.universe.newTermName("get")) + Select(Ident("tr"), newTermName("get")) ) val handlerTree = Match( EmptyTree, List( - CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, Block(assignTree, Apply(Ident("resume"), List())) // rhs of case ) ) ) Apply( - Select(awaitable, c.universe.newTermName("onComplete")), + Select(awaitable, newTermName("onComplete")), List(handlerTree) ) } - + /* Make an `onComplete` invocation which increments the state upon resuming: * * awaitable.onComplete { @@ -192,19 +198,19 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val tryGetTree = Assign( Ident(resultName.toString), - Select(Ident("tr"), c.universe.newTermName("get")) + Select(Ident("tr"), newTermName("get")) ) val handlerTree = Match( EmptyTree, List( - CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case ) ) ) Apply( - Select(awaitable, c.universe.newTermName("onComplete")), + Select(awaitable, newTermName("onComplete")), List(handlerTree) ) } @@ -222,19 +228,19 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val tryGetTree = Assign( Ident(resultName.toString), - Select(Ident("tr"), c.universe.newTermName("get")) + Select(Ident("tr"), newTermName("get")) ) val handlerTree = Match( EmptyTree, List( - CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case ) ) ) Apply( - Select(awaitable, c.universe.newTermName("onComplete")), + Select(awaitable, newTermName("onComplete")), List(handlerTree) ) } @@ -255,7 +261,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { assert(awaitable != null) builder.mkHandler(num, c.Expr[Unit](Block((stats :+ mkOnCompleteTree): _*))) } - + /* Make a partial function literal handling case #num: * * { @@ -273,56 +279,43 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { assert(awaitable != null) mkHandlerTree(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) } - + override def mkHandlerTreeForState(nextState: Int): c.Tree = { assert(awaitable != null) mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*)) } - + override def mkHandlerCaseForState(): CaseDef = { assert(awaitable != null) mkHandlerCase(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) } - - override def varDefForResult: Option[c.Tree] = { - val rhs = - if (resultType <:< definitions.IntTpe) Literal(Constant(0)) - else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) - else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) - else if (resultType <:< definitions.FloatTpe) Literal(Constant(0.0f)) - else if (resultType <:< definitions.DoubleTpe) Literal(Constant(0.0d)) - else if (resultType <:< definitions.CharTpe) Literal(Constant(0.toChar)) - else if (resultType <:< definitions.ShortTpe) Literal(Constant(0.toShort)) - else if (resultType <:< definitions.ByteTpe) Literal(Constant(0.toByte)) - else Literal(Constant(null)) - Some( - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) - ) - } + + override def varDefForResult: Option[c.Tree] = + Some(mkVarDefTree(resultType, resultName)) } - + /* * Builder for a single state of an async method. */ class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) extends Builder[c.Tree, AsyncState] { self => - + /* Statements preceding an await call. */ private val stats = ListBuffer[c.Tree]() - + /* Argument of an await call. */ var awaitable: c.Tree = null - + /* Result name of an await call. */ - var resultName: c.universe.TermName = null - + var resultName: TermName = null + /* Result type of an await call. */ - var resultType: c.universe.Type = null - + var resultType: Type = null + var nextState: Int = state + 1 - - private val varDefs = ListBuffer[(c.universe.TermName, c.universe.Type)]() - + + private val varDefs = ListBuffer[(TermName, Type)]() + private val renamer = new Transformer { override def transform(tree: Tree) = tree match { case Ident(_) if nameMap.keySet contains tree.symbol => @@ -331,20 +324,20 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { super.transform(tree) } } - - def += (stat: c.Tree): this.type = { + + def +=(stat: c.Tree): this.type = { stats += c.resetAllAttrs(renamer.transform(stat).duplicate) this } - + //TODO do not ignore `mods` - def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { + def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { varDefs += (name -> tpt.tpe) nameMap ++= extNameMap // update name map this += Assign(Ident(name), c.resetAllAttrs(renamer.transform(rhs).duplicate)) this } - + def result(): AsyncState = if (awaitable == null) new AsyncState(stats.toList, state, nextState) { @@ -357,14 +350,14 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val resultType = self.resultType override val varDefs = self.varDefs.toList } - + def clear(): Unit = { stats.clear() awaitable = null resultName = null resultType = null } - + /* Result needs to be created as a var at the beginning of the transformed method body, so that * it is visible in subsequent states of the state machine. * @@ -372,7 +365,8 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { * @param awaitResultName the name of the variable that the result of await is assigned to * @param awaitResultType the type of the result of await */ - def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { + def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, + extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { nameMap ++= extNameMap awaitable = c.resetAllAttrs(renamer.transform(awaitArg).duplicate) resultName = awaitResultName @@ -380,51 +374,53 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { this.nextState = nextState this } - + def complete(nextState: Int): this.type = { this.nextState = nextState this } - + def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { // 1. build changed if-else tree // 2. insert that tree at the end of the current state val cond = c.resetAllAttrs(condTree.duplicate) this += If(cond, - Block(mkStateTree(thenState), Apply(Ident("resume"), List())), - Block(mkStateTree(elseState), Apply(Ident("resume"), List()))) + Block(mkStateTree(thenState), Apply(Ident("resume"), List())), + Block(mkStateTree(elseState), Apply(Ident("resume"), List()))) new AsyncStateWithIf(stats.toList, state) { override val varDefs = self.varDefs.toList } } - + override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" } } - class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int, private var toRename: Map[c.Symbol, c.Name]) { + class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, + budget: Int, private var toRename: Map[c.Symbol, c.Name]) { val asyncStates = ListBuffer[builder.AsyncState]() - - private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder + + private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) + // current state builder private var currState = startState - + private var remainingBudget = budget - + /* Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { case Apply(fun, _) if fun.symbol == awaitMethod => true case _ => false }) throw new FallbackToCpsException - + // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => val newName = newTermName(Async.freshString(name.toString())) toRename += (stat.symbol -> newName) - + asyncStates += stateBuilder.complete(args(0), newName, tpt, toRename).result // complete with await if (remainingBudget > 0) remainingBudget -= 1 @@ -432,29 +428,29 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { assert(false, "too many invocations of `await` in current method") currState += 1 stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - + case ValDef(mods, name, tpt, rhs) => checkForUnsupportedAwait(rhs) - + val newName = newTermName(Async.freshString(name.toString())) toRename += (stat.symbol -> newName) // when adding assignment need to take `toRename` into account stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) - + case If(cond, thenp, elsep) => checkForUnsupportedAwait(cond) - + val ifBudget: Int = remainingBudget / 2 remainingBudget -= ifBudget //TODO test if budget > 0 // state that we continue with after if-else: currState + ifBudget - + val thenBudget: Int = ifBudget / 2 val elseBudget = ifBudget - thenBudget - + asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget) - + val thenBuilder = thenp match { case Block(thenStats, thenExpr) => new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename) @@ -463,7 +459,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { } asyncStates ++= thenBuilder.asyncStates toRename ++= thenBuilder.toRename - + val elseBuilder = elsep match { case Block(elseStats, elseExpr) => new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename) @@ -472,11 +468,11 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { } asyncStates ++= elseBuilder.asyncStates toRename ++= elseBuilder.toRename - + // create new state builder for state `currState + ifBudget` currState = currState + ifBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - + case _ => checkForUnsupportedAwait(stat) stateBuilder += stat @@ -485,35 +481,37 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { stateBuilder += expr val lastState = stateBuilder.complete(endState).result asyncStates += lastState - + def mkCombinedHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { assert(asyncStates.size > 1) - + val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState() c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] } - + /* Builds the handler expression for a sequence of async states. */ def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { assert(asyncStates.size > 1) - + var handlerExpr = c.Expr(asyncStates(0).mkHandlerTreeForState()).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] - + if (asyncStates.size == 2) handlerExpr else { - for (asyncState <- asyncStates.tail.init) { // do not traverse first or last state + for (asyncState <- asyncStates.tail.init) { + // do not traverse first or last state val handlerTreeForNextState = asyncState.mkHandlerTreeForState() val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) handlerExpr = c.Expr( Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), - List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] } handlerExpr } } } + } |