diff options
Diffstat (limited to 'src/main/scala/scala/async/internal/ExprBuilder.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/ExprBuilder.scala | 56 |
1 files changed, 38 insertions, 18 deletions
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index ce2345d..3ef9da5 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -3,7 +3,6 @@ */ package scala.async.internal -import scala.reflect.macros.Context import scala.collection.mutable.ListBuffer import collection.mutable import language.existentials @@ -34,18 +33,17 @@ trait ExprBuilder { var stats: List[Tree] - def statsAnd(trees: List[Tree]): List[Tree] = { - val body = stats match { + def treesThenStats(trees: List[Tree]): List[Tree] = { + (stats match { case init :+ last if tpeOf(last) =:= definitions.NothingTpe => - adaptToUnit(init :+ Typed(last, TypeTree(definitions.AnyTpe))) + adaptToUnit((trees ::: init) :+ Typed(last, TypeTree(definitions.AnyTpe))) case _ => - adaptToUnit(stats) - } - Try(body, Nil, adaptToUnit(trees)) :: Nil + adaptToUnit(trees ::: stats) + }) :: Nil } final def allStats: List[Tree] = this match { - case a: AsyncStateWithAwait => statsAnd(a.awaitable.resultValDef :: Nil) + case a: AsyncStateWithAwait => treesThenStats(a.awaitable.resultValDef :: Nil) case _ => stats } @@ -63,7 +61,7 @@ trait ExprBuilder { List(nextState) def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { - mkHandlerCase(state, statsAnd(mkStateTree(nextState, symLookup) :: Nil)) + mkHandlerCase(state, treesThenStats(mkStateTree(nextState, symLookup) :: Nil)) } override val toString: String = @@ -99,10 +97,10 @@ trait ExprBuilder { if (futureSystemOps.continueCompletedFutureOnSameThread) If(futureSystemOps.isCompleted(c.Expr[futureSystem.Fut[_]](awaitable.expr)).tree, adaptToUnit(ifIsFailureTree[T](futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) :: Nil), - Block(toList(callOnComplete), Return(literalUnit))) + Block(toList(callOnComplete), Return(literalUnit))) :: Nil else - Block(toList(callOnComplete), Return(literalUnit)) - mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup), tryGetOrCallOnComplete)) + toList(callOnComplete) ::: Return(literalUnit) :: Nil + mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup)) ++ tryGetOrCallOnComplete) } private def tryGetTree(tryReference: => Tree) = @@ -251,12 +249,17 @@ trait ExprBuilder { case LabelDef(name, _, _) => name.toString.startsWith("case") case _ => false } - val (before, _ :: after) = (stats :+ expr).span(_ ne t) - before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + val span = (stats :+ expr).filterNot(isLiteralUnit).span(_ ne t) + span match { + case (before, _ :: after) => + before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) + case _ => + stats :+ expr + } } // populate asyncStates - for (stat <- (stats :+ expr)) stat match { + def add(stat: Tree): Unit = stat match { // the val name = await(..) pattern case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => val onCompleteState = nextState() @@ -315,10 +318,13 @@ trait ExprBuilder { asyncStates ++= builder.asyncStates currState = afterLabelState stateBuilder = new AsyncStateBuilder(currState, symLookup) + case b @ Block(stats, expr) => + (stats :+ expr) foreach (add) case _ => checkForUnsupportedAwait(stat) stateBuilder += stat } + for (stat <- (stats :+ expr)) add(stat) val lastState = stateBuilder.resultSimple(endState) asyncStates += lastState } @@ -357,8 +363,8 @@ trait ExprBuilder { val caseForLastState: CaseDef = { val lastState = asyncStates.last val lastStateBody = c.Expr[T](lastState.body) - val rhs = futureSystemOps.completeProm( - c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryySuccess[T](lastStateBody)) + val rhs = futureSystemOps.completeWithSuccess( + c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), lastStateBody) mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit))) } asyncStates.toList match { @@ -392,7 +398,10 @@ trait ExprBuilder { * } */ private def resumeFunTree[T: WeakTypeTag]: Tree = { - val body = Match(symLookup.memberRef(name.state), mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T])) + val stateMemberSymbol = symLookup.stateMachineMember(name.state) + val stateMemberRef = symLookup.memberRef(name.state) + val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(nme.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List()))))) + Try( body, List( @@ -462,13 +471,24 @@ trait ExprBuilder { private def tpeOf(t: Tree): Type = t match { case _ if t.tpe != null => t.tpe case Try(body, Nil, _) => tpeOf(body) + case Block(_, expr) => tpeOf(expr) + case Literal(Constant(value)) if value == () => definitions.UnitTpe + case Return(_) => definitions.NothingTpe case _ => NoType } private def adaptToUnit(rhs: List[Tree]): c.universe.Block = { rhs match { + case (rhs: Block) :: Nil if tpeOf(rhs) <:< definitions.UnitTpe => + rhs case init :+ last if tpeOf(last) <:< definitions.UnitTpe => Block(init, last) + case init :+ (last @ Literal(Constant(()))) => + Block(init, last) + case init :+ (last @ Block(_, Return(_) | Literal(Constant(())))) => + Block(init, last) + case init :+ (Block(stats, expr)) => + Block(init, Block(stats :+ expr, literalUnit)) case _ => Block(rhs, literalUnit) } |