From 370be9ea41c582f033a2eeef05157e77a5077144 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 9 Nov 2012 14:29:35 +0100 Subject: Further cleanup in AST generation - centralize names - centralize more module/class lookup - reduce duplication - centralize use of resetAllAttrs - remove uses of asInstanceOf --- src/main/scala/scala/async/Async.scala | 22 ++-- src/main/scala/scala/async/ExprBuilder.scala | 166 +++++++++++++++------------ 2 files changed, 104 insertions(+), 84 deletions(-) (limited to 'src') diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index d64e04a..acd5128 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -65,13 +65,14 @@ abstract class AsyncBase extends AsyncUtils { @deprecated("`await` must be enclosed in an `async` block", "0.1") def await[T](awaitable: futureSystem.Fut[T]): T = ??? - def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]) = { + def asyncImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ import Flag._ - val builder = new ExprBuilder[c.type, self.FS](c, self.futureSystem) + val builder = new ExprBuilder[c.type, futureSystem.type](c, self.futureSystem) import builder.defn._ + import builder.name import builder.futureSystemOps val awaitMethod = awaitSym(c) @@ -91,7 +92,7 @@ abstract class AsyncBase extends AsyncUtils { val handlerForLastState: c.Expr[PartialFunction[Int, Unit]] = { val lastState = asyncBlockBuilder.asyncStates.last val lastStateBody = c.Expr[T](lastState.body) - builder.mkHandler(lastState.state, futureSystemOps.completeProm(c.Expr[futureSystem.Prom[T]](Ident("result")), reify(scala.util.Success(lastStateBody.splice)))) + builder.mkHandler(lastState.state, futureSystemOps.completeProm(c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice)))) } vprintln("GENERATED handler for last state:") @@ -108,30 +109,31 @@ abstract class AsyncBase extends AsyncUtils { } } */ - val nonFatalModule = c.mirror.staticModule("scala.util.control.NonFatal") - val resumeFunTree: c.Tree = DefDef(Modifiers(), newTermName("resume"), List(), List(List()), Ident(definitions.UnitClass), + val nonFatalModule = builder.defn.NonFatalClass + val resumeFunTree: c.Tree = DefDef(Modifiers(), name.resume, List(), List(List()), Ident(definitions.UnitClass), Try( reify { val combinedHandler = mkPartialFunction_orElse(handlerExpr)(handlerForLastState).splice - combinedHandler.apply(c.Expr[Int](Ident(newTermName("state"))).splice) + combinedHandler.apply(c.Expr[Int](Ident(name.state)).splice) }.tree , List( CaseDef( - Apply(Ident(nonFatalModule), List(Bind(newTermName("t"), Ident(nme.WILDCARD)))), + Apply(Ident(nonFatalModule), List(Bind(name.tr, Ident(nme.WILDCARD)))), EmptyTree, Block(List({ - val t = c.Expr[Throwable](Ident(newTermName("t"))) - futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(newTermName("result"))), reify(scala.util.Failure(t.splice))).tree + val t = c.Expr[Throwable](Ident(name.tr)) + futureSystemOps.completeProm[T](c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Failure(t.splice))).tree }), c.literalUnit.tree))), EmptyTree)) + val prom: Expr[futureSystem.Prom[T]] = reify { val result = futureSystemOps.createProm[T].splice var state = 0 futureSystemOps.future[Unit] { c.Expr[Unit](Block( localVarTrees :+ resumeFunTree, - Apply(Ident(newTermName("resume")), List()))) + Apply(Ident(name.resume), List()))) }(futureSystemOps.execContext).splice result } diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 4beaa34..65e98e0 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -19,11 +19,28 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy import Flag._ import defn._ - val execContextType = c.weakTypeOf[futureSystem.ExecContext] - val execContext = futureSystemOps.execContext + object name { + // TODO do we need to freshen any of these? + val resume = newTermName("resume") + val state = newTermName("state") + val result = newTermName("result") + val tr = newTermName("tr") + val any = newTermName("any") + val x1 = newTermName("x$1") + val apply = newTermName("apply") + val isDefinedAt = newTermName("isDefinedAt") + + val anon = newTypeName("$anon") + } + + private val execContext = futureSystemOps.execContext + + private def resetDuplicate(tree: Tree) = c.resetAllAttrs(tree.duplicate) private val awaitMethod = awaitSym(c) + private def mkResumeApply = Apply(Ident(name.resume), List()) + /* Make a partial function literal handling case #num: * * { @@ -43,22 +60,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy } }) */ - val rhsTree = c.resetAllAttrs(rhs.tree.duplicate) + val rhsTree = resetDuplicate(rhs.tree) val handlerTree = mkHandlerTree(num, rhsTree) - c.Expr(handlerTree).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] - } - - def mkIncrStateTree(): c.Tree = { - Assign( - Ident(newTermName("state")), - mkInt_+(c.Expr[Int](Ident(newTermName("state"))))(c.literal(1)).tree - ) + c.Expr(handlerTree) } def mkStateTree(nextState: Int): c.Tree = + mkStateTree(c.literal(nextState).tree) + + def mkStateTree(nextState: Tree): c.Tree = Assign( - Ident(newTermName("state")), - Literal(Constant(nextState))) + Ident(name.state), + nextState) def defaultValue(tpe: Type): Literal = { val defaultValue: Any = @@ -72,61 +85,68 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), defaultValue(resultType)) } + def mkHandlerCase(num: Int, rhs: List[c.Tree]): CaseDef = + mkHandlerCase(num, Block(rhs: _*)) + def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = CaseDef( // pattern - Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))), + Bind(name.any, Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))), // guard - mkAny_==(c.Expr(Ident(newTermName("any"))))(c.literal(num)).tree, + mkAny_==(c.Expr(Ident(name.any)))(c.literal(num)).tree, rhs ) def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = { - val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction")) + val partFunIdent = Ident(defn.PartialFunctionClass) val intIdent = Ident(definitions.IntClass) val unitIdent = Ident(definitions.UnitClass) val caseCheck = - defn.mkList_contains(defn.mkList_apply(cases.map(p => c.literal(p._2))))(c.Expr(Ident(newTermName("x$1")))) + defn.mkList_contains(defn.mkList_apply(cases.map(p => c.literal(p._2))))(c.Expr(Ident(name.x1))) Block(List( // anonymous subclass of PartialFunction[Int, Unit] // TODO subclass AbstractPartialFunction - ClassDef(Modifiers(FINAL), newTypeName("$anon"), List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), + ClassDef(Modifiers(FINAL), name.anon, List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), emptyValDef, List( DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(), - Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))), + Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), c.literalUnit.tree)), - DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), + DefDef(Modifiers(), name.isDefinedAt, List(), List(List(ValDef(Modifiers(PARAM), name.x1, intIdent, EmptyTree))), TypeTree(), caseCheck.tree), - DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), - Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match + DefDef(Modifiers(), name.apply, List(), List(List(ValDef(Modifiers(PARAM), name.x1, intIdent, EmptyTree))), TypeTree(), + Match(Ident(name.x1), cases.map(_._1)) // combine all cases into a single match ) )) )), - Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List()) + Apply(Select(New(Ident(name.anon)), nme.CONSTRUCTOR), List()) ) } def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = mkHandlerTreeFor(List(mkHandlerCase(num, rhs) -> num)) + def mkHandlerTree(num: Int, rhs: List[c.Tree]): c.Tree = + mkHandlerTree(num, Block(rhs: _*)) + class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { - val body: c.Tree = - if (stats.size == 1) stats.head - else Block(stats: _*) + val body: c.Tree = stats match { + case stat :: Nil => stat + case _ => Block(stats: _*) + } val varDefs: List[(TermName, Type)] = List() def mkHandlerCaseForState(): CaseDef = - mkHandlerCase(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + mkHandlerCase(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) def mkHandlerTreeForState(): c.Tree = - mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + mkHandlerTree(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) def mkHandlerTreeForState(nextState: Int): c.Tree = - mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + mkHandlerTree(state, stats :+ mkStateTree(nextState) :+ mkResumeApply) def varDefForResult: Option[c.Tree] = None @@ -143,12 +163,12 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy // nextState unused, since encoded in then and else branches override def mkHandlerTreeForState(): c.Tree = - mkHandlerTree(state, Block(stats: _*)) + mkHandlerTree(state, stats) //TODO mkHandlerTreeForState(nextState: Int) override def mkHandlerCaseForState(): CaseDef = - mkHandlerCase(state, Block(stats: _*)) + mkHandlerCase(state, stats) override val toString: String = s"AsyncStateWithIf #$state, next = $nextState" @@ -160,6 +180,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy val resultName: TermName val resultType: Type + protected def tryType = appliedType(TryClass.toType, List(resultType)) + override val toString: String = s"AsyncStateWithAwait #$state, next = $nextState" @@ -174,15 +196,15 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy def mkOnCompleteTree: c.Tree = { val assignTree = Assign( - Ident(resultName.toString), - mkTry_get(c.Expr(Ident("tr"))).tree + Ident(resultName), + mkTry_get(c.Expr(Ident(name.tr))).tree ) val handlerTree = Match( EmptyTree, List( - CaseDef(Bind(newTermName("tr"), Ident("_")), EmptyTree, - Block(assignTree, Apply(Ident("resume"), List())) // rhs of case + CaseDef(Bind(name.tr, Ident("_")), EmptyTree, + Block(assignTree, mkResumeApply) // rhs of case ) ) ) @@ -198,20 +220,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy * resume() * } */ - def mkOnCompleteIncrStateTree: c.Tree = { - val tryGetTree = - Assign( - Ident(resultName.toString), - Select(Ident("tr"), Try_get) - ) - - val handlerTree = - Function(List(ValDef(Modifiers(PARAM), newTermName("tr"), TypeTree(tryType), EmptyTree)), Block(tryGetTree, mkIncrStateTree(), Apply(Ident("resume"), List()))) - - futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree - } - - def tryType = appliedType(c.mirror.staticClass("scala.util.Try").toType, List(resultType)) + def mkOnCompleteIncrStateTree: c.Tree = + mkOnCompleteTree(mkInt_+(c.Expr[Int](Ident(name.state)))(c.literal(1)).tree) /* Make an `onComplete` invocation which sets the state to `nextState` upon resuming: * @@ -222,14 +232,20 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy * resume() * } */ - def mkOnCompleteStateTree(nextState: Int): c.Tree = { + def mkOnCompleteStateTree(nextState: Int): c.Tree = + mkOnCompleteTree(c.literal(nextState).tree) + + private def mkOnCompleteTree(nextState: Tree): c.Tree = { val tryGetTree = Assign( - Ident(resultName.toString), - Select(Ident("tr"), Try_get) + Ident(resultName), + Select(Ident(name.tr), Try_get) ) + + val updateState = mkStateTree(nextState) + val handlerTree = - Function(List(ValDef(Modifiers(PARAM), newTermName("tr"), TypeTree(tryType), EmptyTree)), Block(tryGetTree, mkStateTree(nextState), Apply(Ident("resume"), List()))) + Function(List(ValDef(Modifiers(PARAM), name.tr, TypeTree(tryType), EmptyTree)), Block(tryGetTree, updateState, mkResumeApply)) futureSystemOps.onComplete(c.Expr(awaitable), c.Expr(handlerTree), execContext).tree } @@ -240,7 +256,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy * case any if any == num => * stats * awaitable.onComplete { - * case tr => + * (try: Try[A]) => * resultName = tr.get * resume() * } @@ -266,17 +282,17 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy */ override def mkHandlerTreeForState(): c.Tree = { assert(awaitable != null) - mkHandlerTree(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + mkHandlerTree(state, stats :+ mkOnCompleteIncrStateTree) } override def mkHandlerTreeForState(nextState: Int): c.Tree = { assert(awaitable != null) - mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*)) + mkHandlerTree(state, stats :+ mkOnCompleteStateTree(nextState)) } override def mkHandlerCaseForState(): CaseDef = { assert(awaitable != null) - mkHandlerCase(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + mkHandlerCase(state, stats :+ mkOnCompleteIncrStateTree) } override def varDefForResult: Option[c.Tree] = @@ -315,7 +331,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy } def +=(stat: c.Tree): this.type = { - stats += c.resetAllAttrs(renamer.transform(stat).duplicate) + stats += resetDuplicate(renamer.transform(stat)) this } @@ -323,7 +339,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy def addVarDef(mods: Any, name: TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { varDefs += (name -> tpt.tpe) nameMap ++= extNameMap // update name map - this += Assign(Ident(name), c.resetAllAttrs(renamer.transform(rhs).duplicate)) + this += Assign(Ident(name), rhs) this } @@ -357,7 +373,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy def complete(awaitArg: c.Tree, awaitResultName: TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { nameMap ++= extNameMap - awaitable = c.resetAllAttrs(renamer.transform(awaitArg).duplicate) + awaitable = resetDuplicate(renamer.transform(awaitArg)) resultName = awaitResultName resultType = awaitResultType.tpe this.nextState = nextState @@ -372,10 +388,10 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy def resultWithIf(condTree: c.Tree, thenState: Int, elseState: Int): AsyncState = { // 1. build changed if-else tree // 2. insert that tree at the end of the current state - val cond = c.resetAllAttrs(condTree.duplicate) + val cond = resetDuplicate(condTree) this += If(cond, - Block(mkStateTree(thenState), Apply(Ident("resume"), List())), - Block(mkStateTree(elseState), Apply(Ident("resume"), List()))) + Block(mkStateTree(thenState), mkResumeApply), + Block(mkStateTree(elseState), mkResumeApply)) new AsyncStateWithoutAwait(stats.toList, state) { override val varDefs = self.varDefs.toList } @@ -395,10 +411,10 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy def resultWithMatch(scrutTree: c.Tree, cases: List[CaseDef], stateFirstCase: Int, perCasebudget: Int): AsyncState = { // 1. build list of changed cases val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), Apply(Ident("resume"), List()))) + case CaseDef(pat, guard, rhs) => CaseDef(pat, guard, Block(mkStateTree(num * perCasebudget + stateFirstCase), mkResumeApply)) } // 2. insert changed match tree at the end of the current state - this += Match(c.resetAllAttrs(scrutTree.duplicate), newCases) + this += Match(resetDuplicate(scrutTree), newCases) new AsyncStateWithoutAwait(stats.toList, state) { override val varDefs = self.varDefs.toList } @@ -439,7 +455,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy def builderForBranch(tree: c.Tree, state: Int, nextState: Int, budget: Int, nameMap: Map[c.Symbol, c.Name]): AsyncBlockBuilder = { val (branchStats, branchExpr) = tree match { case Block(s, e) => (s, e) - case _ => (List(tree), Literal(Constant(()))) + case _ => (List(tree), c.literalUnit.tree) } new AsyncBlockBuilder(branchStats, branchExpr, state, nextState, budget, nameMap) } @@ -451,7 +467,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy val newName = c.fresh(name) toRename += (stat.symbol -> newName) - asyncStates += stateBuilder.complete(args(0), newName, tpt, toRename).result // complete with await + asyncStates += stateBuilder.complete(args.head, newName, tpt, toRename).result // complete with await if (remainingBudget > 0) remainingBudget -= 1 else @@ -508,7 +524,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy for ((cas, num) <- cases.zipWithIndex) { val (casStats, casExpr) = cas match { case CaseDef(_, _, Block(s, e)) => (s, e) - case CaseDef(_, _, rhs) => (List(rhs), Literal(Constant(()))) + case CaseDef(_, _, rhs) => (List(rhs), c.literalUnit.tree) } val builder = new AsyncBlockBuilder(casStats, casExpr, currState + (num * perCaseBudget) + 1, currState + matchBudget, perCaseBudget, toRename) asyncStates ++= builder.asyncStates @@ -532,9 +548,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy assert(asyncStates.size > 1) val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState() - reify { - c.Expr[PartialFunction[Int, Unit]](mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).splice: PartialFunction[Int, Unit] - } + c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))) } /* Builds the handler expression for a sequence of async states. @@ -543,7 +557,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy assert(asyncStates.size > 1) var handlerExpr = - c.Expr(asyncStates(0).mkHandlerTreeForState()).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + c.Expr[PartialFunction[Int, Unit]](asyncStates.head.mkHandlerTreeForState()) if (asyncStates.size == 2) handlerExpr @@ -551,7 +565,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy for (asyncState <- asyncStates.tail.init) { // do not traverse first or last state val handlerTreeForNextState = asyncState.mkHandlerTreeForState() - val currentHandlerTreeNaked = c.resetAllAttrs(handlerExpr.tree.duplicate) + val currentHandlerTreeNaked = resetDuplicate(handlerExpr.tree) handlerExpr = mkPartialFunction_orElse(c.Expr(currentHandlerTreeNaked))(c.Expr(handlerTreeForNextState)) } handlerExpr @@ -594,7 +608,11 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy self.splice.get } - val Try_get = methodSym(reify((null.asInstanceOf[scala.util.Try[Any]]).get)) + val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) + + val PartialFunctionClass = c.mirror.staticClass("scala.PartialFunction") + val TryClass = c.mirror.staticClass("scala.util.Try") + val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") } } -- cgit v1.2.3