aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-10-26 16:51:29 +0200
committerphaller <hallerp@gmail.com>2012-10-26 16:51:29 +0200
commit0aa0110cdbb303531436d580c7b2c588c7dd1057 (patch)
tree6d66c0a3e750c5485f9a2d52f063e0a1fcfe4ad5
parenta3978cd531915920597889845c40096395d5b8d8 (diff)
downloadscala-async-0aa0110cdbb303531436d580c7b2c588c7dd1057.tar.gz
scala-async-0aa0110cdbb303531436d580c7b2c588c7dd1057.tar.bz2
scala-async-0aa0110cdbb303531436d580c7b2c588c7dd1057.zip
Introduce immutable AsyncState class
- Refactor AsyncStateBuilder to extend collection.mutable.Builder - Reset attributes of duplicated trees only once inside the builder
-rw-r--r--src/async/library/scala/async/Async.scala180
1 files changed, 103 insertions, 77 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala
index e480607..d6ebc47 100644
--- a/src/async/library/scala/async/Async.scala
+++ b/src/async/library/scala/async/Async.scala
@@ -9,7 +9,7 @@ import scala.reflect.runtime.universe
import scala.concurrent.{ Future, Promise }
import scala.util.control.NonFatal
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.{ ListBuffer, Builder }
/*
* @author Philipp Haller
@@ -43,7 +43,12 @@ class ExprBuilder[C <: Context with Singleton](val c: C) {
val handlerTree = mkHandlerTree(num, rhsTree)
c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
}
-
+
+ def mkIncrStateTree(): c.Tree =
+ Assign(
+ Ident(newTermName("state")),
+ Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1)))))
+
def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = {
val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction"))
val intIdent = Ident(definitions.IntClass)
@@ -78,42 +83,22 @@ class ExprBuilder[C <: Context with Singleton](val c: C) {
)
}
- /*
- * Builder for a single state of an async method.
- */
- class AsyncStateBuilder {
- /* Statements preceding an await call. */
- private val stats = ListBuffer[c.Tree]()
-
- /* Argument of an await call. */
- var awaitable: c.Tree = null
+ class AsyncState(stats: List[c.Tree]) {
+ val body: c.Tree =
+ if (stats.size == 1) stats.head
+ else Block(stats: _*)
- /* Result name of an await call. */
- var resultName: c.universe.TermName = null
-
- /* Result type of an await call. */
- var resultType: c.universe.Type = null
-
- def += (stat: c.Tree): Unit =
- stats += stat
-
- /* Result needs to be created as a var at the beginning of the transformed method body, so that
- * it is visible in subsequent states of the state machine.
- *
- * @param awaitArg the argument of await
- * @param awaitResultName the name of the variable that the result of await is assigned to
- * @param awaitResultType the type of the result of await
- */
- def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): Unit = {
- awaitable = c.resetAllAttrs(awaitArg.duplicate)
- resultName = awaitResultName
- resultType = awaitResultType.tpe
- }
+ def mkHandlerTreeForState(num: Int): c.Tree =
+ mkHandlerTree(num, Block((stats :+ mkIncrStateTree()): _*))
- override def toString: String = {
- val statsBeforeAwait = stats.mkString("\n")
- s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
- }
+ def varDefForResult: Option[c.Tree] =
+ None
+ }
+
+ abstract class AsyncStateWithAwait(stats: List[c.Tree]) extends AsyncState(stats) {
+ val awaitable: c.Tree
+ val resultName: c.universe.TermName
+ val resultType: c.universe.Type
/* Make an `onComplete` invocation:
*
@@ -153,23 +138,18 @@ class ExprBuilder[C <: Context with Singleton](val c: C) {
* resume()
* }
*/
- def mkOnCompleteTreeIncrState: c.Tree = {
+ def mkOnCompleteIncrStateTree: c.Tree = {
val tryGetTree =
Assign(
Ident(resultName.toString),
Select(Ident("tr"), c.universe.newTermName("get"))
)
- val incrementStateTree =
- Assign(
- Ident(newTermName("state")),
- Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1))))
- )
val handlerTree =
Match(
EmptyTree,
List(
CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree,
- Block(tryGetTree, incrementStateTree, Apply(Ident("resume"), List())) // rhs of case
+ Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List())) // rhs of case
)
)
)
@@ -193,8 +173,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) {
*/
def mkHandlerForState(num: Int): c.Expr[PartialFunction[Int, Unit]] = {
assert(awaitable != null)
- val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate))
- builder.mkHandler(num, c.Expr[Unit](Block((nakedStats :+ mkOnCompleteTree): _*)))
+ builder.mkHandler(num, c.Expr[Unit](Block((stats :+ mkOnCompleteTree): _*)))
}
/* Make a partial function literal handling case #num:
@@ -210,30 +189,81 @@ class ExprBuilder[C <: Context with Singleton](val c: C) {
* }
* }
*/
- def mkHandlerTreeForState(num: Int): c.Tree = {
+ override def mkHandlerTreeForState(num: Int): c.Tree = {
assert(awaitable != null)
- val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate))
- builder.mkHandlerTree(num, Block((nakedStats :+ mkOnCompleteTreeIncrState): _*))
- }
-
- def lastExprTree: c.Tree = {
- assert(awaitable == null)
- if (stats.size == 1)
- c.resetAllAttrs(stats(0).duplicate)
- else {
- val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate))
- Block(nakedStats: _*)
- }
+ builder.mkHandlerTree(num, Block((stats :+ mkOnCompleteIncrStateTree): _*))
}
//TODO: complete for other primitive types, how to handle value classes?
- def varDefForResult: c.Tree = {
+ override def varDefForResult: Option[c.Tree] = {
val rhs =
if (resultType <:< definitions.IntTpe) Literal(Constant(0))
else if (resultType <:< definitions.LongTpe) Literal(Constant(0L))
else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false))
else Literal(Constant(null))
- ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs)
+ Some(
+ ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs)
+ )
+ }
+ }
+
+ /*
+ * Builder for a single state of an async method.
+ */
+ class AsyncStateBuilder extends Builder[c.Tree, AsyncState] {
+ self =>
+
+ /* Statements preceding an await call. */
+ private val stats = ListBuffer[c.Tree]()
+
+ /* Argument of an await call. */
+ var awaitable: c.Tree = null
+
+ /* Result name of an await call. */
+ var resultName: c.universe.TermName = null
+
+ /* Result type of an await call. */
+ var resultType: c.universe.Type = null
+
+ def += (stat: c.Tree): this.type = {
+ stats += c.resetAllAttrs(stat.duplicate)
+ this
+ }
+
+ def result(): AsyncState =
+ if (awaitable == null)
+ new AsyncState(stats.toList)
+ else
+ new AsyncStateWithAwait(stats.toList) {
+ val awaitable = self.awaitable
+ val resultName = self.resultName
+ val resultType = self.resultType
+ }
+
+ def clear(): Unit = {
+ stats.clear()
+ awaitable = null
+ resultName = null
+ resultType = null
+ }
+
+ /* Result needs to be created as a var at the beginning of the transformed method body, so that
+ * it is visible in subsequent states of the state machine.
+ *
+ * @param awaitArg the argument of await
+ * @param awaitResultName the name of the variable that the result of await is assigned to
+ * @param awaitResultType the type of the result of await
+ */
+ def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): this.type = {
+ awaitable = c.resetAllAttrs(awaitArg.duplicate)
+ resultName = awaitResultName
+ resultType = awaitResultType.tpe
+ this
+ }
+
+ override def toString: String = {
+ val statsBeforeAwait = stats.mkString("\n")
+ s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName"
}
}
@@ -257,24 +287,20 @@ object Async extends AsyncUtils {
body.tree match {
case Block(stats, expr) =>
- val asyncStates = ListBuffer[builder.AsyncStateBuilder]()
+ val asyncStates = ListBuffer[builder.AsyncState]()
var stateBuilder = new builder.AsyncStateBuilder // current state builder
- for (stat <- stats) {
- stat match {
- // the val name = await(..) pattern
- case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod =>
- stateBuilder.complete(args(0), name, tpt)
- asyncStates += stateBuilder
- stateBuilder = new builder.AsyncStateBuilder
-
- case _ =>
- stateBuilder += stat
- }
+ for (stat <- stats) stat match {
+ // the val name = await(..) pattern
+ case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod =>
+ asyncStates += stateBuilder.complete(args(0), name, tpt).result // complete with await
+ stateBuilder = new builder.AsyncStateBuilder
+
+ case _ =>
+ stateBuilder += stat
}
// complete last state builder (representing the expressions after the last await)
- stateBuilder += expr
- asyncStates += stateBuilder
+ asyncStates += (stateBuilder += expr).result
vprintln("states of current method:")
asyncStates foreach vprintln
@@ -288,7 +314,7 @@ object Async extends AsyncUtils {
var handlerExpr = c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]]
var i = 1
- while (asyncStates(i).awaitable != null) {
+ while (asyncStates(i).isInstanceOf[builder.AsyncStateWithAwait]) {
//val handlerForNextState = asyncStates(i).mkHandlerForState(i+1)
val handlerTreeForNextState = asyncStates(i).mkHandlerTreeForState(i)
@@ -310,14 +336,14 @@ object Async extends AsyncUtils {
val localVarDefs = ListBuffer[c.Tree]()
for (state <- asyncStates.init) // exclude last state (doesn't have await result)
- localVarDefs += state.varDefForResult
+ localVarDefs ++= state.varDefForResult.toList
// pad up to 5 var defs
if (localVarDefs.size < 5)
for (_ <- localVarDefs.size until 5) localVarDefs += EmptyTree
val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = {
val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")),
- List(asyncStates(indexOfLastState).lastExprTree))
+ List(asyncStates(indexOfLastState).body))
//builder.mkHandler(indexOfLastState + 1, c.Expr[Unit](tree))
builder.mkHandler(indexOfLastState, c.Expr[Unit](tree))
}