From 23403a5ba6e7d045231d57572813859f6d344377 Mon Sep 17 00:00:00 2001 From: phaller Date: Mon, 29 Oct 2012 14:41:42 +0100 Subject: WIP: support await in if-else expressions --- src/async/library/scala/async/Async.scala | 276 +++++++++++++++++++++---- src/async/library/scala/async/AsyncUtils.scala | 2 +- 2 files changed, 239 insertions(+), 39 deletions(-) (limited to 'src/async/library/scala') diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala index 92d6fcd..53ed062 100644 --- a/src/async/library/scala/async/Async.scala +++ b/src/async/library/scala/async/Async.scala @@ -20,6 +20,8 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { import c.universe._ import Flag._ + private val awaitMethod = awaitSym(c) + /* Make a partial function literal handling case #num: * * { @@ -49,6 +51,20 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { Ident(newTermName("state")), Apply(Select(Ident(newTermName("state")), newTermName("$plus")), List(Literal(Constant(1))))) + def mkStateTree(nextState: Int): c.Tree = + Assign( + Ident(newTermName("state")), + Literal(Constant(nextState))) + + def mkVarDefTree(resultType: c.universe.Type, resultName: c.universe.TermName): c.Tree = { + val rhs = + if (resultType <:< definitions.IntTpe) Literal(Constant(0)) + else if (resultType <:< definitions.LongTpe) Literal(Constant(0L)) + else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false)) + else Literal(Constant(null)) + ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) + } + def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = { val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction")) val intIdent = Ident(definitions.IntClass) @@ -83,23 +99,50 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { ) } - class AsyncState(stats: List[c.Tree]) { + class AsyncState(stats: List[c.Tree], protected val state: Int, protected val nextState: Int) { val body: c.Tree = if (stats.size == 1) stats.head else Block(stats: _*) - def mkHandlerTreeForState(num: Int): c.Tree = - mkHandlerTree(num, Block((stats :+ mkIncrStateTree()): _*)) + val varDefs: List[(c.universe.TermName, c.universe.Type)] = List() + + def mkHandlerTreeForState(): c.Tree = + mkHandlerTree(state, Block((stats :+ mkStateTree(nextState)): _*)) + + def mkHandlerTreeForState(nextState: Int): c.Tree = + mkHandlerTree(state, Block((stats :+ mkStateTree(nextState)): _*)) def varDefForResult: Option[c.Tree] = None + + def allVarDefs: List[c.Tree] = + varDefForResult.toList ++ varDefs.map(p => mkVarDefTree(p._2, p._1)) + + override val toString: String = + s"AsyncState #$state, next = $nextState" } - abstract class AsyncStateWithAwait(stats: List[c.Tree]) extends AsyncState(stats) { + class AsyncStateWithIf(stats: List[c.Tree], state: Int) + extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches + + override def mkHandlerTreeForState(): c.Tree = + mkHandlerTree(state, Block(stats: _*)) + + //TODO mkHandlerTreeForState(nextState: Int) + + override val toString: String = + s"AsyncStateWithIf #$state, next = $nextState" + } + + abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int) + extends AsyncState(stats, state, nextState) { val awaitable: c.Tree val resultName: c.universe.TermName val resultType: c.universe.Type + override val toString: String = + s"AsyncStateWithAwait #$state, next = $nextState" + /* Make an `onComplete` invocation: * * awaitable.onComplete { @@ -159,6 +202,36 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { ) } + /* Make an `onComplete` invocation which sets the state to `nextState` upon resuming: + * + * awaitable.onComplete { + * case tr => + * resultName = tr.get + * state = `nextState` + * resume() + * } + */ + def mkOnCompleteStateTree(nextState: Int): c.Tree = { + val tryGetTree = + Assign( + Ident(resultName.toString), + Select(Ident("tr"), c.universe.newTermName("get")) + ) + val handlerTree = + Match( + EmptyTree, + List( + CaseDef(Bind(c.universe.newTermName("tr"), Ident("_")), EmptyTree, + Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List())) // rhs of case + ) + ) + ) + Apply( + Select(awaitable, c.universe.newTermName("onComplete")), + List(handlerTree) + ) + } + /* Make a partial function literal handling case #num: * * { @@ -189,9 +262,14 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { * } * } */ - override def mkHandlerTreeForState(num: Int): c.Tree = { + override def mkHandlerTreeForState(): c.Tree = { + assert(awaitable != null) + mkHandlerTree(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + } + + override def mkHandlerTreeForState(nextState: Int): c.Tree = { assert(awaitable != null) - builder.mkHandlerTree(num, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*)) } //TODO: complete for other primitive types, how to handle value classes? @@ -210,7 +288,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { /* * Builder for a single state of an async method. */ - class AsyncStateBuilder extends Builder[c.Tree, AsyncState] { + class AsyncStateBuilder(state: Int) extends Builder[c.Tree, AsyncState] { self => /* Statements preceding an await call. */ @@ -225,19 +303,32 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { /* Result type of an await call. */ var resultType: c.universe.Type = null + var nextState: Int = state + 1 + + private val varDefs = ListBuffer[(c.universe.TermName, c.universe.Type)]() + def += (stat: c.Tree): this.type = { stats += c.resetAllAttrs(stat.duplicate) this } + //TODO do not ignore `mods` + def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree): this.type = { + varDefs += (name -> tpt.tpe) + this + } + def result(): AsyncState = if (awaitable == null) - new AsyncState(stats.toList) + new AsyncState(stats.toList, state, nextState) { + override val varDefs = self.varDefs.toList + } else - new AsyncStateWithAwait(stats.toList) { + new AsyncStateWithAwait(stats.toList, state, nextState) { val awaitable = self.awaitable val resultName = self.resultName val resultType = self.resultType + override val varDefs = self.varDefs.toList } def clear(): Unit = { @@ -254,56 +345,164 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { * @param awaitResultName the name of the variable that the result of await is assigned to * @param awaitResultType the type of the result of await */ - def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree): this.type = { + def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, nextState: Int = state + 1): this.type = { awaitable = c.resetAllAttrs(awaitArg.duplicate) resultName = awaitResultName resultType = awaitResultType.tpe + this.nextState = nextState this } + def complete(nextState: Int): this.type = { + this.nextState = nextState + this + } + + def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { + // 1. build changed if-else tree + // 2. insert that tree at the end of the current state + val cond = c.resetAllAttrs(condTree.duplicate) + this += If(cond, mkStateTree(thenState), mkStateTree(elseState)) + new AsyncStateWithIf(stats.toList, state) { + override val varDefs = self.varDefs.toList + } + } + override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" } } - class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int) { + /* current issue: + def m2(y: Int): Future[Int] = async { + val f = m1(y) + if (y > 0) { + val x = await(f) + x + 2 + } else { + val x = await(f) + x - 2 + } + } + + */ + class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int) { val asyncStates = ListBuffer[builder.AsyncState]() - private var stateBuilder = new builder.AsyncStateBuilder // current state builder - private val awaitMethod = awaitSym(c) + + private var stateBuilder = new builder.AsyncStateBuilder(startState) // current state builder + private var currState = startState + + private var remainingBudget = budget // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => asyncStates += stateBuilder.complete(args(0), name, tpt).result // complete with await - stateBuilder = new builder.AsyncStateBuilder - + currState += 1 + stateBuilder = new builder.AsyncStateBuilder(currState) + + case ValDef(mods, name, tpt, rhs) => + stateBuilder.addVarDef(mods, name, tpt) + stateBuilder += // instead of adding `stat` we add a simple assignment + Assign(Ident(name), c.resetAllAttrs(rhs.duplicate)) + + case If(cond, thenp, elsep) => + val ifBudget: Int = remainingBudget / 2 + remainingBudget -= ifBudget + println(s"ASYNC IF: ifBudget = $ifBudget") + // state that we continue with after if-else: currState + ifBudget + + val thenBudget: Int = ifBudget / 2 + val elseBudget = ifBudget - thenBudget + + asyncStates += + stateBuilder.resultWithIf(cond, currState + 1, currState + thenBudget) + + val thenBuilder = thenp match { + case Block(thenStats, thenExpr) => + new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget) + case _ => + new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget) + } + println("ASYNC IF: states of thenp:") + for (s <- thenBuilder.asyncStates) + println(s.toString) + + // insert states of thenBuilder into asyncStates + asyncStates ++= thenBuilder.asyncStates + + val elseBuilder = elsep match { + case Block(elseStats, elseExpr) => + new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget) + case _ => + new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget) + } + // insert states of elseBuilder into asyncStates + asyncStates ++= elseBuilder.asyncStates + + // create new state builder for state `currState + ifBudget` + currState = currState + ifBudget + stateBuilder = new builder.AsyncStateBuilder(currState) + case _ => stateBuilder += stat } // complete last state builder (representing the expressions after the last await) - asyncStates += (stateBuilder += expr).result + stateBuilder += expr + val lastState = stateBuilder.complete(endState).result + asyncStates += lastState /* Builds the handler expression for a sequence of async states. - * Also returns the index of the last state. */ - def mkHandlerExpr(): (c.Expr[PartialFunction[Int, Unit]], Int) = { - //var handlerExpr = asyncStates(0).mkHandlerForState(1) // state 0 but { case 1 => ... } - var handlerTree = asyncStates(0).mkHandlerTreeForState(0) - var handlerExpr = c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] - - var i = 1 - for (asyncState <- asyncStates.tail.init) { - //val handlerForNextState = asyncStates(i).mkHandlerForState(i+1) - val handlerTreeForNextState = asyncState.mkHandlerTreeForState(i) + def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { + assert(asyncStates.size > 1) + + println(s"!!ASYNC mkHandlerExpr: asyncStates.size = ${asyncStates.size}") + println(s"!!ASYNC state 0: ${asyncStates(0)}") + + var handlerTree = + if (asyncStates.size > 2) asyncStates(0).mkHandlerTreeForState() + else asyncStates(0).mkHandlerTreeForState(endState) + var handlerExpr = + c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + + if (asyncStates.size == 2) + handlerExpr + else if (asyncStates.size == 3) { + // asyncStates(1) must continue with endState + val handlerTreeForLastState = asyncStates(1).mkHandlerTreeForState(endState) + val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) + c.Expr( + Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), + List(handlerTreeForLastState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + } else { // asyncStates.size > 3 + var i = startState + 1 + + println("!!ASYNC start for loop") + + // do not traverse first state: asyncStates.tail + // do not traverse last state: asyncStates.tail.init + // handle second to last state specially: asyncStates.tail.init.init + for (asyncState <- asyncStates.tail.init.init) { + println(s"!!ASYNC current asyncState: $asyncState") + val handlerTreeForNextState = asyncState.mkHandlerTreeForState() + val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) + handlerExpr = c.Expr( + Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), + List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + i += 1 + } + + val lastState = asyncStates.tail.init.last + println(s"!!ASYNC current asyncState (forced to $endState): $lastState") + val handlerTreeForLastState = lastState.mkHandlerTreeForState(endState) val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) - handlerExpr = c.Expr( - Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), List(handlerTreeForNextState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] - i += 1 + c.Expr( + Apply(Select(currentHandlerTreeNaked, newTermName("orElse")), + List(handlerTreeForLastState))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] } - // asyncStates(i) does not end with `await` (asyncStates(i).awaitable == null) - (handlerExpr, i) } } @@ -327,28 +526,29 @@ object Async extends AsyncUtils { body.tree match { case Block(stats, expr) => - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000) + val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000) - vprintln("states of current method:") + vprintln(s"states of current method (${ asyncBlockBuilder.asyncStates }):") asyncBlockBuilder.asyncStates foreach vprintln - val (handlerExpr, indexOfLastState) = asyncBlockBuilder.mkHandlerExpr() + val handlerExpr = asyncBlockBuilder.mkHandlerExpr() - vprintln(s"GENERATED handler expr ($indexOfLastState):") + vprintln(s"GENERATED handler expr:") vprintln(handlerExpr) val localVarDefs = ListBuffer[c.Tree]() for (state <- asyncBlockBuilder.asyncStates.init) // exclude last state (doesn't have await result) - localVarDefs ++= state.varDefForResult.toList + localVarDefs ++= //state.varDefForResult.toList + state.allVarDefs // pad up to 5 var defs if (localVarDefs.size < 5) for (_ <- localVarDefs.size until 5) localVarDefs += EmptyTree val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { val tree = Apply(Select(Ident("result"), c.universe.newTermName("success")), - List(asyncBlockBuilder.asyncStates(indexOfLastState).body)) + List(asyncBlockBuilder.asyncStates.last.body)) //builder.mkHandler(indexOfLastState + 1, c.Expr[Unit](tree)) - builder.mkHandler(indexOfLastState, c.Expr[Unit](tree)) + builder.mkHandler(1000, c.Expr[Unit](tree)) } vprintln("GENERATED handler for last state:") diff --git a/src/async/library/scala/async/AsyncUtils.scala b/src/async/library/scala/async/AsyncUtils.scala index 820541b..98330a5 100644 --- a/src/async/library/scala/async/AsyncUtils.scala +++ b/src/async/library/scala/async/AsyncUtils.scala @@ -10,7 +10,7 @@ import scala.reflect.macros.Context */ trait AsyncUtils { - val verbose = false + val verbose = true protected def vprintln(s: Any): Unit = if (verbose) println("[async] "+s) -- cgit v1.2.3