From d5409fdfbc3302e1547c69d0d1fda390fcb14883 Mon Sep 17 00:00:00 2001 From: phaller Date: Mon, 5 Nov 2012 12:49:06 +0100 Subject: Combine cases of several states into a single partial function --- src/async/library/scala/async/Async.scala | 2 +- src/async/library/scala/async/ExprBuilder.scala | 52 ++++++++++++++++++------- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala index 9466df5..2c81bc3 100644 --- a/src/async/library/scala/async/Async.scala +++ b/src/async/library/scala/async/Async.scala @@ -49,7 +49,7 @@ object Async extends AsyncUtils { vprintln(s"states of current method:") asyncBlockBuilder.asyncStates foreach vprintln - val handlerExpr = asyncBlockBuilder.mkHandlerExpr() + val handlerExpr = asyncBlockBuilder.mkCombinedHandlerExpr() vprintln(s"GENERATED handler expr:") vprintln(handlerExpr) diff --git a/src/async/library/scala/async/ExprBuilder.scala b/src/async/library/scala/async/ExprBuilder.scala index 4eb06e9..b7d6446 100644 --- a/src/async/library/scala/async/ExprBuilder.scala +++ b/src/async/library/scala/async/ExprBuilder.scala @@ -60,40 +60,46 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs) } - def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = { + def mkHandlerCase(num: Int, rhs: c.Tree): CaseDef = + CaseDef( + // pattern + Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), Ident(definitions.IntClass))), + // guard + Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))), + rhs + ) + + def mkHandlerTreeFor(cases: List[(CaseDef, Int)]): c.Tree = { val partFunIdent = Ident(c.mirror.staticClass("scala.PartialFunction")) val intIdent = Ident(definitions.IntClass) val unitIdent = Ident(definitions.UnitClass) + val caseCheck = + Apply(Select(Apply(Select(Ident(newTermName("List")), newTermName("apply")), + cases.map(p => Literal(Constant(p._2)))), newTermName("contains")), List(Ident(newTermName("x$1")))) + Block(List( // anonymous subclass of PartialFunction[Int, Unit] ClassDef(Modifiers(FINAL), newTypeName("$anon"), List(), Template(List(AppliedTypeTree(partFunIdent, List(intIdent, unitIdent))), emptyValDef, List( - DefDef(Modifiers(), nme.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(Apply(Select(Super(This(tpnme.EMPTY), tpnme.EMPTY), nme.CONSTRUCTOR), List())), Literal(Constant(())))), - + DefDef(Modifiers(), newTermName("isDefinedAt"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), - Apply(Select(Ident(newTermName("x$1")), newTermName("$eq$eq")), List(Literal(Constant(num))))), + caseCheck), DefDef(Modifiers(), newTermName("apply"), List(), List(List(ValDef(Modifiers(PARAM), newTermName("x$1"), intIdent, EmptyTree))), TypeTree(), - Match(Ident(newTermName("x$1")), List( - CaseDef( - // pattern - Bind(newTermName("any"), Typed(Ident(nme.WILDCARD), intIdent)), - // guard - Apply(Select(Ident(newTermName("any")), newTermName("$eq$eq")), List(Literal(Constant(num)))), - rhs - ) - )) + Match(Ident(newTermName("x$1")), cases.map(_._1)) // combine all cases into a single match ) - )) )), Apply(Select(New(Ident(newTypeName("$anon"))), nme.CONSTRUCTOR), List()) ) } + def mkHandlerTree(num: Int, rhs: c.Tree): c.Tree = + mkHandlerTreeFor(List(mkHandlerCase(num, rhs) -> num)) + class AsyncState(stats: List[c.Tree], val state: Int, val nextState: Int) { val body: c.Tree = if (stats.size == 1) stats.head @@ -101,6 +107,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val varDefs: List[(c.universe.TermName, c.universe.Type)] = List() + def mkHandlerCaseForState(): CaseDef = + mkHandlerCase(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) + def mkHandlerTreeForState(): c.Tree = mkHandlerTree(state, Block((stats :+ mkStateTree(nextState) :+ Apply(Ident("resume"), List())): _*)) @@ -125,6 +134,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { //TODO mkHandlerTreeForState(nextState: Int) + override def mkHandlerCaseForState(): CaseDef = + mkHandlerCase(state, Block(stats: _*)) + override val toString: String = s"AsyncStateWithIf #$state, next = $nextState" } @@ -267,6 +279,11 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*)) } + override def mkHandlerCaseForState(): CaseDef = { + assert(awaitable != null) + mkHandlerCase(state, Block((stats :+ mkOnCompleteIncrStateTree): _*)) + } + override def varDefForResult: Option[c.Tree] = { val rhs = if (resultType <:< definitions.IntTpe) Literal(Constant(0)) @@ -469,6 +486,13 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val lastState = stateBuilder.complete(endState).result asyncStates += lastState + def mkCombinedHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { + assert(asyncStates.size > 1) + + val cases = for (state <- asyncStates.toList) yield state.mkHandlerCaseForState() + c.Expr(mkHandlerTreeFor(cases zip asyncStates.init.map(_.state))).asInstanceOf[c.Expr[PartialFunction[Int, Unit]]] + } + /* Builds the handler expression for a sequence of async states. */ def mkHandlerExpr(): c.Expr[PartialFunction[Int, Unit]] = { -- cgit v1.2.3