diff options
-rw-r--r-- | src/async/library/scala/async/Async.scala | 216 |
1 files changed, 111 insertions, 105 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala index b72d088..424a67d 100644 --- a/src/async/library/scala/async/Async.scala +++ b/src/async/library/scala/async/Async.scala @@ -11,9 +11,14 @@ import scala.concurrent.{ Future, Promise } import scala.util.control.NonFatal import scala.collection.mutable.ListBuffer +/* + * @author Philipp Haller + */ class ExprBuilder[C <: Context with Singleton](val c: C) { - import c.universe.{ reify, Literal, Constant } - + builder => + + import c.universe._ + /* Make a partial function literal handling case #num: * * { @@ -32,8 +37,109 @@ class ExprBuilder[C <: Context with Singleton](val c: C) { } }) } + + class AsyncStateBuilder { + /* Statements preceding an await call. */ + private val stats = ListBuffer[c.Tree]() + + /* Argument of an await call. */ + var awaitable: c.Tree = null + + /* Result name of an await call. */ + var resultName: c.universe.TermName = null + + /* Result type of an await call. */ + var resultType: c.universe.Type = null + + def += (stat: c.Tree): Unit = + stats += stat + + /* Result needs to be created as a var at the beginning of the transformed method bodyso that + it is visible in subsequent states of the state machine. + */ + def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): Unit = { + awaitable = c.resetAllAttrs(awaitArg.duplicate) + resultName = awaitResultName + resultType = awaitResultType.tpe + } + + override def toString: String = { + val statsBeforeAwait = stats.mkString("\n") + s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" + } + + /* Make an `onComplete` invocation: + * + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * resume() + * } + */ + def mkOnCompleteTree: c.Tree = { + val assignTree = + Assign( + Ident(resultName.toString), + Select(Ident("tr"), c.universe.newTermName("get")) + ) + val handlerTree = + Match( + EmptyTree, + List( + CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + Block(assignTree, Apply(Ident("resume"), List())) // rhs of case + ) + ) + ) + Apply( + Select(awaitable, c.universe.newTermName("onComplete")), + List(handlerTree) + ) + } + + /* Make a partial function literal handling case #num: + * + * { + * case any if any == num => + * stats + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * resume() + * } + * } + */ + def mkHandlerForState(num: Int): c.Expr[PartialFunction[Int, Unit]] = { + assert(awaitable != null) + val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate)) + val block = Block((nakedStats :+ mkOnCompleteTree): _*) + builder.mkHandler(num, c.Expr[Unit](block)) + } + + def lastExprTree: c.Tree = { + assert(awaitable == null) + if (stats.size == 1) + c.resetAllAttrs(stats(0).duplicate) + else { + val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate)) + Block(nakedStats: _*) + } + } + + //TODO: complete for other primitive types, how to handle value classes? + def varDefForResult: c.Tree = { + val rhs = + if (resultType <:< definitions.IntTpe) Literal(Constant(0)) + else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) + else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) + else Literal(Constant(null)) + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) + } + } + } + /* * @author Philipp Haller */ @@ -48,110 +154,10 @@ object Async extends AsyncUtils { val builder = new ExprBuilder[c.type](c) - class AsyncStateBuilder { - /* Statements preceding an await call. */ - private val stats = ListBuffer[c.Tree]() - - /* Argument of an await call. */ - var awaitable: c.Tree = null - - /* Result name of an await call. */ - var resultName: c.universe.TermName = null - - /* Result type of an await call. */ - var resultType: c.universe.Type = null - - def += (stat: c.Tree): Unit = - stats += stat - - /* Result needs to be created as a var at the beginning of the transformed method body, so that - it is visible in subsequent states of the state machine. - */ - def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: c.Tree): Unit = { - awaitable = c.resetAllAttrs(awaitArg.duplicate) - resultName = awaitResultName - resultType = awaitResultType.tpe - } - - override def toString: String = { - val statsBeforeAwait = stats.mkString("\n") - s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" - } - - /* Make an `onComplete` invocation: - * - * awaitable.onComplete { - * case tr => - * resultName = tr.get - * resume() - * } - */ - def mkOnCompleteTree: c.Tree = { - val assignTree = - Assign( - Ident(resultName.toString), - Select(Ident("tr"), c.universe.newTermName("get")) - ) - val handlerTree = - Match( - EmptyTree, - List( - CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, - Block(assignTree, Apply(Ident("resume"), List())) // rhs of case - ) - ) - ) - Apply( - Select(awaitable, c.universe.newTermName("onComplete")), - List(handlerTree) - ) - } - - /* Make a partial function literal handling case #num: - * - * { - * case any if any == num => - * stats - * awaitable.onComplete { - * case tr => - * resultName = tr.get - * resume() - * } - * } - */ - def mkHandlerForState(num: Int): c.Expr[PartialFunction[Int, Unit]] = { - assert(awaitable != null) - val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate)) - val block = Block((nakedStats :+ mkOnCompleteTree): _*) - builder.mkHandler(num, c.Expr[Unit](block)) - } - - def lastExprTree: c.Tree = { - assert(awaitable == null) - if (stats.size == 1) - c.resetAllAttrs(stats(0).duplicate) - else { - val nakedStats = stats.map(stat => c.resetAllAttrs(stat.duplicate)) - Block(nakedStats: _*) - } - } - - //TODO: complete for other primitive types, how to handle value classes? - def varDefForResult: c.Tree = { - val rhs = - if (resultType <:< definitions.IntTpe) Literal(Constant(0)) - else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) - else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) - else Literal(Constant(null)) - ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) - } - } - - body.tree match { case Block(stats, expr) => - val asyncStates = ListBuffer[AsyncStateBuilder]() - var stateBuilder = new AsyncStateBuilder // current state builder + val asyncStates = ListBuffer[builder.AsyncStateBuilder]() + var stateBuilder = new builder.AsyncStateBuilder // current state builder val awaitMethod = awaitSym(c) for (stat <- stats) { @@ -160,7 +166,7 @@ object Async extends AsyncUtils { case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => stateBuilder.complete(args(0), name, tpt) asyncStates += stateBuilder - stateBuilder = new AsyncStateBuilder + stateBuilder = new builder.AsyncStateBuilder case _ => stateBuilder += stat |