diff options
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 11 | ||||
-rw-r--r-- | src/main/scala/scala/async/ExprBuilder.scala | 46 |
2 files changed, 14 insertions, 43 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index d69fd95..cdeadd8 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -84,16 +84,9 @@ abstract class AsyncBase extends AsyncUtils { asyncBlockBuilder.asyncStates foreach vprintln - val handlerCases: List[(CaseDef, Int)] = asyncBlockBuilder.mkCombinedHandlerCases() + val handlerCases: List[CaseDef] = asyncBlockBuilder.mkCombinedHandlerCases[T]() - val caseForLastState: (CaseDef, Int) = { - val lastState = asyncBlockBuilder.asyncStates.last - val lastStateBody = c.Expr[T](lastState.body) - val rhs = futureSystemOps.completeProm(c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) - builder.mkHandlerCase(lastState.state, rhs.tree) -> lastState.state - } - - val combinedHander = c.Expr[PartialFunction[Int, Unit]](builder.mkHandlerTreeFor(handlerCases :+ caseForLastState)) + val combinedHander = c.Expr[Int => Unit](builder.mkFunction(handlerCases)) val localVarTrees = asyncBlockBuilder.asyncStates.init.flatMap(_.allVarDefs).toList diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 6e8827d..7d4bf35 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -83,35 +83,8 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy private def paramValDef(name: TermName, tpe: Type) = ValDef(Modifiers(PARAM), name, TypeTree(tpe), EmptyTree) - def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = { - val partFunIdent = Ident(defn.PartialFunctionClass) - val intIdent = Ident(definitions.IntClass) - val unitIdent = Ident(definitions.UnitClass) - val (caseDefs, states) = cases.unzip - - val caseCheck = - defn.mkList_contains(defn.mkList_apply(states map c.literal))(c.Expr(Ident(name.x1))) - val handlerName = name.asyncHander - - Block(List( - // anonymous subclass of PartialFunction[Int, Unit] - // TODO subclass AbstractPartialFunction - ClassDef(Modifiers(FINAL), handlerName, Nil, Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), - emptyValDef, List( - DefDef(NoMods, nme.CONSTRUCTOR, Nil, List(Nil), TypeTree(), - Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), Nil)), c.literalUnit.tree)), - - DefDef(NoMods, name.isDefinedAt, Nil, List(List(paramValDef(name.x1, definitions.IntClass))), TypeTree(), - caseCheck.tree), - - DefDef(NoMods, name.apply, Nil, List(List(paramValDef(name.x1, definitions.IntClass))), TypeTree(), - Match(Ident(name.x1), caseDefs) // combine all cases into a single match - ) - )) - )), - Apply(Select(New(Ident(handlerName)), nme.CONSTRUCTOR), Nil) - ) - } + def mkFunction(cases: List[CaseDef]): c.Tree = + Function(List(paramValDef(name.x1, definitions.IntClass)), Match(Ident(name.x1), cases)) class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = stats match { @@ -442,11 +415,18 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy val lastState = stateBuilder.complete(endState).result asyncStates += lastState - def mkCombinedHandlerCases(): List[(CaseDef, Int)] = { + def mkCombinedHandlerCases[T](): List[CaseDef] = { assert(asyncStates.size > 1) + val initCases = for (state <- asyncStates.toList.init) yield state.mkHandlerCaseForState() - val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState() - cases zip asyncStates.init.map(_.state) + val caseForLastState: CaseDef = { + val lastState = asyncStates.last + val lastStateBody = c.Expr[T](lastState.body) + val rhs = futureSystemOps.completeProm(c.Expr[futureSystem.Prom[T]](Ident(name.result)), reify(scala.util.Success(lastStateBody.splice))) + mkHandlerCase(lastState.state, rhs.tree) + } + + initCases :+ caseForLastState } } @@ -483,9 +463,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](val c: C, val futureSy val Try_get = methodSym(reify((null: scala.util.Try[Any]).get)) - val PartialFunctionClass = c.mirror.staticClass("scala.PartialFunction") val TryClass = c.mirror.staticClass("scala.util.Try") val NonFatalClass = c.mirror.staticModule("scala.util.control.NonFatal") } - } |