/* * Copyright (C) 2012 Typesafe Inc. */ package scala.async import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable /* * @author Philipp Haller */ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val futureSystem: FS) extends TransformUtils(c) { builder => import c.universe._ import defn._ lazy val futureSystemOps = futureSystem.mkOps(c) private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) private def mkResumeApply = Apply(Ident(name.resume), Nil) private def mkStateTree(nextState: Int): c.Tree = mkStateTree(c.literal(nextState).tree) private def mkStateTree(nextState: Tree): c.Tree = Assign(Ident(name.state), nextState) def mkVarDefTree(resultType: Type, resultName: TermName): c.Tree = { ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) } private def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = mkHandlerCase(num, Block(rhs: _*)) private def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = { CaseDef(c.literal(num).tree, EmptyTree, rhs) } class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = stats match { case stat :: Nil => stat case _ => Block(stats: _*) } val varDefs: List[(TermName, Type)] = Nil def mkHandlerCaseForState(): CaseDef = mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) def mkOnCompleteHandler(): Option[CaseDef] = { this match { case aw: AsyncStateWithAwait => val tryGetTree = Assign( Ident(aw.resultName), TypeApply(Select(Select(Ident(name.tr), Try_get), newTermName("asInstanceOf")), List(TypeTree(aw.resultType))) ) val updateState = mkStateTree(nextState) Some(mkHandlerCase(state, List(tryGetTree, updateState, mkResumeApply))) case _ => None } } override val toString: String = s"AsyncState #$state, next = $nextState" } class AsyncStateWithoutAwait(stats: List[c.Tree], state: Int) extends AsyncState(stats, state, 0) { // nextState unused, since encoded in then and else branches override def mkHandlerCaseForState(): CaseDef = mkHandlerCase(state, stats) override val toString: String = s"AsyncStateWithIf #$state, next = $nextState" } abstract class AsyncStateWithAwait(stats: List[c.Tree], state: Int, nextState: Int) extends AsyncState(stats, state, nextState) { val awaitable : c.Tree val resultName: TermName val resultType: Type protected def tryType = appliedType(TryClass.toType, List(resultType)) override val toString: String = s"AsyncStateWithAwait #$state, next = $nextState" private def mkOnCompleteTree: c.Tree = { futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(Ident(name.onCompleteHandler)), c.Expr(Ident(name.execContext))).tree } override def mkHandlerCaseForState(): CaseDef = { assert(awaitable != null) mkHandlerCase(state, stats :+ mkOnCompleteTree) } } /* * Builder for a single state of an async method. */ class AsyncStateBuilder(state: Int, private val nameMap: Map[Symbol, c.Name]) { self => /* Statements preceding an await call. */ private val stats = ListBuffer[c.Tree]() /* Argument of an await call. */ var awaitable: c.Tree = null /* Result name of an await call. */ var resultName: TermName = null /* Result type of an await call. */ var resultType: Type = null var nextState : Int = -1 var nextJumpState: Option[Int] = None private val renamer = new Transformer { override def transform(tree: Tree) = tree match { case Ident(_) if nameMap.keySet contains tree.symbol => Ident(nameMap(tree.symbol)) case _ => super.transform(tree) } } def +=(stat: c.Tree): this.type = { assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") def addStat() = stats += resetDuplicate(renamer.transform(stat)) stat match { case Apply(fun, Nil) => labelDefStates get fun.symbol match { case Some(nextState) => nextJumpState = Some(nextState) case None => addStat() } case _ => addStat() } this } //TODO do not ignore `mods` def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree): this.type = { this += Assign(Ident(name), rhs) this } def result(): AsyncState = { val effectiveNestState = nextJumpState.getOrElse(nextState) if (awaitable == null) new AsyncState(stats.toList, state, effectiveNestState) else new AsyncStateWithAwait(stats.toList, state, effectiveNestState) { val awaitable = self.awaitable val resultName = self.resultName val resultType = self.resultType } } /* Result needs to be created as a var at the beginning of the transformed method body, so that * it is visible in subsequent states of the state machine. * * @param awaitArg the argument of await * @param awaitResultName the name of the variable that the result of await is assigned to * @param awaitResultType the type of the result of await */ def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, nextState: Int): this.type = { val renamed = renamer.transform(awaitArg) awaitable = resetDuplicate(renamed) resultName = awaitResultName resultType = awaitResultType.tpe this.nextState = nextState this } def complete(nextState: Int): this.type = { this.nextState = nextState this } def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { // 1. build changed if-else tree // 2. insert that tree at the end of the current state val cond = resetDuplicate(renamer.transform(condTree)) def mkBranch(state: Int) = Block(mkStateTree(state), mkResumeApply) this += If(cond, mkBranch(thenState), mkBranch(elseState)) new AsyncStateWithoutAwait(stats.toList, state) } /** * Build `AsyncState` ending with a match expression. * * The cases of the match simply resume at the state of their corresponding right-hand side. * * @param scrutTree tree of the scrutinee * @param cases list of case definitions * @param caseStates starting state of the right-hand side of the each case * @return an `AsyncState` representing the match expression */ def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], caseStates: List[Int]): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(caseStates(num)), mkResumeApply)) } // 2. insert changed match tree at the end of the current state this += Match(resetDuplicate(scrutTree), newCases) new AsyncStateWithoutAwait(stats.toList, state) } def resultWithLabel(startLabelState: Int): AsyncState = { this += Block(mkStateTree(startLabelState), mkResumeApply) new AsyncStateWithoutAwait(stats.toList, state) } override def toString: String = { val statsBeforeAwait = stats.mkString("\n") s"ASYNC STATE:\n$statsBeforeAwait \nawaitable: $awaitable \nresult name: $resultName" } } val stateAssigner = new StateAssigner val labelDefStates = collection.mutable.Map[Symbol, Int]() /** * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). * * @param stats a list of expressions * @param expr the last expression of the block * @param startState the start state * @param endState the state to continue with * @param toRename a `Map` for renaming the given key symbols to the mangled value names */ class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, private val toRename: Map[Symbol, c.Name]) { val asyncStates = ListBuffer[builder.AsyncState]() private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) private var currState = startState /* TODO Fall back to CPS plug-in if tree contains an `await` call. */ def checkForUnsupportedAwait(tree: c.Tree) = if (tree exists { case Apply(fun, _) if isAwait(fun) => true case _ => false }) c.abort(tree.pos, "await must not be used in this position") //throw new FallbackToCpsException def builderForBranch(tree: c.Tree, state: Int, nextState: Int): AsyncBlockBuilder = { val (branchStats, branchExpr) = statsAndExpr(tree) new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, toRename) } // populate asyncStates for (stat <- stats) stat match { // the val name = await(..) pattern case ValDef(mods, name, tpt, Apply(fun, args)) if isAwait(fun) => val afterAwaitState = stateAssigner.nextState() asyncStates += stateBuilder.complete(args.head, toRename(stat.symbol).toTermName, tpt, afterAwaitState).result // complete with await currState = afterAwaitState stateBuilder = new builder.AsyncStateBuilder(currState, toRename) case ValDef(mods, name, tpt, rhs) if toRename contains stat.symbol => checkForUnsupportedAwait(rhs) stateBuilder += Assign(Ident(toRename(stat.symbol).toTermName), rhs) case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) val thenStartState = stateAssigner.nextState() val elseStartState = stateAssigner.nextState() val afterIfState = stateAssigner.nextState() asyncStates += // the two Int arguments are the start state of the then branch and the else branch, respectively stateBuilder.resultWithIf(cond, thenStartState, elseStartState) List((thenp, thenStartState), (elsep, elseStartState)) foreach { case (tree, state) => val builder = builderForBranch(tree, state, afterIfState) asyncStates ++= builder.asyncStates } currState = afterIfState stateBuilder = new builder.AsyncStateBuilder(currState, toRename) case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) val caseStates = cases.map(_ => stateAssigner.nextState()) val afterMatchState = stateAssigner.nextState() asyncStates += stateBuilder.resultWithMatch(scrutinee, cases, caseStates) for ((cas, num) <- cases.zipWithIndex) { val (casStats, casExpr) = statsAndExpr(cas.body) val builder = new AsyncBlockBuilder(casStats, casExpr, caseStates(num), afterMatchState, toRename) asyncStates ++= builder.asyncStates } currState = afterMatchState stateBuilder = new AsyncStateBuilder(currState, toRename) case ld@LabelDef(name, params, rhs) if rhs exists isAwait => val startLabelState = stateAssigner.nextState() val afterLabelState = stateAssigner.nextState() asyncStates += stateBuilder.resultWithLabel(startLabelState) val (stats, expr) = statsAndExpr(rhs) labelDefStates(ld.symbol) = startLabelState val builder = new AsyncBlockBuilder(stats, expr, startLabelState, afterLabelState, toRename) asyncStates ++= builder.asyncStates currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, toRename) case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } // complete last state builder (representing the expressions after the last await) stateBuilder += expr val lastState = stateBuilder.complete(endState).result() asyncStates += lastState def mkCombinedHandlerCases[T](): List[CaseDef] = { val caseForLastState: CaseDef = { val lastState = asyncStates.last val lastStateBody = c.Expr[T](lastState.body) val rhs = futureSystemOps.completeProm(c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) mkHandlerCase(lastState.state, rhs.tree) } asyncStates.toList match { case s :: Nil => List(caseForLastState) case _ => val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState() initCases :+ caseForLastState } } } }