diff options
author | phaller <hallerp@gmail.com> | 2012-10-26 16:51:29 +0200 |
---|---|---|
committer | phaller <hallerp@gmail.com> | 2012-10-26 16:51:29 +0200 |
commit | 0aa0110cdbb303531436d580c7b2c588c7dd1057 (patch) | |
tree | 6d66c0a3e750c5485f9a2d52f063e0a1fcfe4ad5 /src | |
parent | a3978cd531915920597889845c40096395d5b8d8 (diff) | |
download | scala-async-0aa0110cdbb303531436d580c7b2c588c7dd1057.tar.gz scala-async-0aa0110cdbb303531436d580c7b2c588c7dd1057.tar.bz2 scala-async-0aa0110cdbb303531436d580c7b2c588c7dd1057.zip |
Introduce immutable AsyncState class
- Refactor AsyncStateBuilder to extend collection.mutable.Builder
- Reset attributes of duplicated trees only once inside the builder
Diffstat (limited to 'src')
-rw-r--r-- | src/async/library/scala/async/Async.scala | 180 |
1 files changed, 103 insertions, 77 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala index e480607..d6ebc47 100644 --- a/src/async/library/scala/async/Async.scala +++ b/src/async/library/scala/async/Async.scala @@ -9,7 +9,7 @@ import scala.reflect.runtime.universe import scala.concurrent.{ Future, Promise } import scala.util.control.NonFatal -import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{ ListBuffer, Builder } /* * @author Philipp Haller @@ -43,7 +43,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) { val handlerTree = mkHandlerTree(num, rhsTree) c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] } - + + def mkIncrStateTree(): c.Tree = + Assign( + Ident(newTermName("state")), + Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1))))) + def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = { val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction")) val intIdent = Ident(definitions.IntClass) @@ -78,42 +83,22 @@ class ExprBuilder[C <: Context with Singleton](val c: C) { ) } - /* - * Builder for a single state of an async method. - */ - class AsyncStateBuilder { - /* Statements preceding an await call. */ - private val stats = ListBuffer[c.Tree]() - - /* Argument of an await call. */ - var awaitable: c.Tree = null + class AsyncState(stats: List[c.Tree]) { + val body: c.Tree = + if (stats.size == 1) stats.head + else Block(stats: _*) - /* Result name of an await call. */ - var resultName: c.universe.TermName = null - - /* Result type of an await call. */ - var resultType: c.universe.Type = null - - def += (stat: c.Tree): Unit = - stats += stat - - /* Result needs to be created as a var at the beginning of the transformed method body, so that - * it is visible in subsequent states of the state machine. - * - * @param awaitArg the argument of await - * @param awaitResultName the name of the variable that the result of await is assigned to - * @param awaitResultType the type of the result of await - */ - def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): Unit = { - awaitable = c.resetAllAttrs(awaitArg.duplicate) - resultName = awaitResultName - resultType = awaitResultType.tpe - } + def mkHandlerTreeForState(num: Int): c.Tree = + mkHandlerTree(num, Block((stats :+ mkIncrStateTree()): _*)) - override def toString: String = { - val statsBeforeAwait = stats.mkString("\n") - s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" - } + def varDefForResult: Option[c.Tree] = + None + } + + abstract class AsyncStateWithAwait(stats: List[c.Tree]) extends AsyncState(stats) { + val awaitable: c.Tree + val resultName: c.universe.TermName + val resultType: c.universe.Type /* Make an `onComplete` invocation: * @@ -153,23 +138,18 @@ class ExprBuilder[C <: Context with Singleton](val c: C) { * resume() * } */ - def mkOnCompleteTreeIncrState: c.Tree = { + def mkOnCompleteIncrStateTree: c.Tree = { val tryGetTree = Assign( Ident(resultName.toString), Select(Ident("tr"), c.universe.newTermName("get")) ) - val incrementStateTree = - Assign( - Ident(newTermName("state")), - Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1)))) - ) val handlerTree = Match( EmptyTree, List( CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, - Block(tryGetTree, incrementStateTree, Apply(Ident("resume"), List())) // rhs of case + Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case ) ) ) @@ -193,8 +173,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) { */ def mkHandlerForState(num: Int): c.Expr[PartialFunction[Int, Unit]] = { assert(awaitable != null) - val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate)) - builder.mkHandler(num, c.Expr[Unit](Block((nakedStats :+ mkOnCompleteTree): _*))) + builder.mkHandler(num, c.Expr[Unit](Block((stats :+ mkOnCompleteTree): _*))) } /* Make a partial function literal handling case #num: @@ -210,30 +189,81 @@ class ExprBuilder[C <: Context with Singleton](val c: C) { * } * } */ - def mkHandlerTreeForState(num: Int): c.Tree = { + override def mkHandlerTreeForState(num: Int): c.Tree = { assert(awaitable != null) - val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate)) - builder.mkHandlerTree(num, Block((nakedStats :+ mkOnCompleteTreeIncrState): _*)) - } - - def lastExprTree: c.Tree = { - assert(awaitable == null) - if (stats.size == 1) - c.resetAllAttrs(stats(0).duplicate) - else { - val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate)) - Block(nakedStats: _*) - } + builder.mkHandlerTree(num, Block((stats :+ mkOnCompleteIncrStateTree): _*)) } //TODO: complete for other primitive types, how to handle value classes? - def varDefForResult: c.Tree = { + override def varDefForResult: Option[c.Tree] = { val rhs = if (resultType <:< definitions.IntTpe) Literal(Constant(0)) else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) else Literal(Constant(null)) - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) + Some( + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) + ) + } + } + + /* + * Builder for a single state of an async method. + */ + class AsyncStateBuilder extends Builder[c.Tree, AsyncState] { + self => + + /* Statements preceding an await call. */ + private val stats = ListBuffer[c.Tree]() + + /* Argument of an await call. */ + var awaitable: c.Tree = null + + /* Result name of an await call. */ + var resultName: c.universe.TermName = null + + /* Result type of an await call. */ + var resultType: c.universe.Type = null + + def += (stat: c.Tree): this.type = { + stats += c.resetAllAttrs(stat.duplicate) + this + } + + def result(): AsyncState = + if (awaitable == null) + new AsyncState(stats.toList) + else + new AsyncStateWithAwait(stats.toList) { + val awaitable = self.awaitable + val resultName = self.resultName + val resultType = self.resultType + } + + def clear(): Unit = { + stats.clear() + awaitable = null + resultName = null + resultType = null + } + + /* Result needs to be created as a var at the beginning of the transformed method body, so that + * it is visible in subsequent states of the state machine. + * + * @param awaitArg the argument of await + * @param awaitResultName the name of the variable that the result of await is assigned to + * @param awaitResultType the type of the result of await + */ + def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): this.type = { + awaitable = c.resetAllAttrs(awaitArg.duplicate) + resultName = awaitResultName + resultType = awaitResultType.tpe + this + } + + override def toString: String = { + val statsBeforeAwait = stats.mkString("\n") + s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" } } @@ -257,24 +287,20 @@ object Async extends AsyncUtils { body.tree match { case Block(stats, expr) => - val asyncStates = ListBuffer[builder.AsyncStateBuilder]() + val asyncStates = ListBuffer[builder.AsyncState]() var stateBuilder = new builder.AsyncStateBuilder // current state builder - for (stat <- stats) { - stat match { - // the val name = await(..) pattern - case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => - stateBuilder.complete(args(0), name, tpt) - asyncStates += stateBuilder - stateBuilder = new builder.AsyncStateBuilder - - case _ => - stateBuilder += stat - } + for (stat <- stats) stat match { + // the val name = await(..) pattern + case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => + asyncStates += stateBuilder.complete(args(0), name, tpt).result // complete with await + stateBuilder = new builder.AsyncStateBuilder + + case _ => + stateBuilder += stat } // complete last state builder (representing the expressions after the last await) - stateBuilder += expr - asyncStates += stateBuilder + asyncStates += (stateBuilder += expr).result vprintln("states of current method:") asyncStates foreach vprintln @@ -288,7 +314,7 @@ object Async extends AsyncUtils { var handlerExpr = c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] var i = 1 - while (asyncStates(i).awaitable != null) { + while (asyncStates(i).isInstanceOf[builder.AsyncStateWithAwait]) { //val handlerForNextState = asyncStates(i).mkHandlerForState(i+1) val handlerTreeForNextState = asyncStates(i).mkHandlerTreeForState(i) @@ -310,14 +336,14 @@ object Async extends AsyncUtils { val localVarDefs = ListBuffer[c.Tree]() for (state <- asyncStates.init) // exclude last state (doesn't have await result) - localVarDefs += state.varDefForResult + localVarDefs ++= state.varDefForResult.toList // pad up to 5 var defs if (localVarDefs.size < 5) for (_ <- localVarDefs.size until 5) localVarDefs += EmptyTree val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")), - List(asyncStates(indexOfLastState).lastExprTree)) + List(asyncStates(indexOfLastState).body)) //builder.mkHandler(indexOfLastState + 1, c.Expr[Unit](tree)) builder.mkHandler(indexOfLastState, c.Expr[Unit](tree)) } |