diff options
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncTransform.scala | 10 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/LiveVariables.scala | 33 |
2 files changed, 20 insertions, 23 deletions
diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 476c405..5c84f66 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -66,15 +66,11 @@ trait AsyncTransform { val assignsOf = fieldsToNullOut(asyncBlock.asyncStates, liftedFields) for ((state, flds) <- assignsOf) { - val asyncState = asyncBlock.asyncStates.find(_.state == state).get val assigns = flds.map { fld => val fieldSym = fld.symbol - Assign( - gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), - gen.mkZero(fieldSym.info) - ) - }.toList - // prepend those assigns + Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), gen.mkZero(fieldSym.info)) + } + val asyncState = asyncBlock.asyncStates.find(_.state == state).get asyncState.stats = assigns ++ asyncState.stats } diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala index 8963ebb..48f03cd 100644 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -7,22 +7,22 @@ trait LiveVariables { import global._ /** - * Returns for a given state the set of fields (as trees) that should be nulled out + * Returns for a given state a list of fields (as trees) that should be nulled out * upon resuming that state (at the beginning of `resume`). */ - def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, Set[Tree]] = { + def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = { // live variables analysis: // the result map indicates in which states a given field should be nulled out val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables) - var assignsOf = Map[Int, Set[Tree]]() + var assignsOf = Map[Int, List[Tree]]() for ((fld, where) <- liveVarsMap; state <- where) assignsOf get state match { case None => - assignsOf += (state -> Set[Tree](fld)) + assignsOf += (state -> List(fld)) case Some(trees) if !trees.exists(_.symbol == fld.symbol) => - assignsOf += (state -> (trees + fld)) + assignsOf += (state -> (fld +: trees)) case _ => /* do nothing */ } @@ -86,7 +86,7 @@ trait LiveVariables { */ def isPred(state1: Int, state2: Int, seen: Set[Int] = Set()): Boolean = if (seen(state1)) false // breaks cycles in the CFG - else cfg.get(state1) match { + else cfg get state1 match { case Some(nextStates) => nextStates.contains(state2) || nextStates.exists(isPred(_, state2, seen + state1)) case None => @@ -154,13 +154,15 @@ trait LiveVariables { AsyncUtils.vprintln(s"LVexit at state #${as.state}: ${LVexit(as.state).mkString(", ")}") } - def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = LVentry get at.state match { - case Some(fields) if fields.exists(_ == field.symbol) => - Set(at.state) - case _ => - val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet - preds.flatMap(p => lastUsagesOf(field, p, avoid + at)) - } + def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = + if (avoid(at)) Set() + else LVentry get at.state match { + case Some(fields) if fields.exists(_ == field.symbol) => + Set(at.state) + case _ => + val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet + preds.flatMap(p => lastUsagesOf(field, p, avoid + at)) + } val lastUsages: Map[Tree, Set[Int]] = liftables.map(fld => (fld -> lastUsagesOf(fld, finalState, Set()))).toMap @@ -171,9 +173,8 @@ trait LiveVariables { val nullOutAt: Map[Tree, Set[Int]] = for ((fld, lastStates) <- lastUsages) yield { val killAt = lastStates.flatMap { s => - if (s == finalState.state) { - Set[Int]() - } else { + if (s == finalState.state) Set() + else { val lastAsyncState = asyncStates.find(_.state == s).get val succNums = lastAsyncState.nextStates // all successor states that are not indirect predecessors |