aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPhilipp Haller <hallerp@gmail.com>2013-10-13 23:44:18 +0200
committerPhilipp Haller <hallerp@gmail.com>2013-10-22 14:40:01 +0200
commitd6c5aeb6f6effcac4a054f0290711aa64ae3c191 (patch)
treec661375d5d9299466936c9f3df6949b4346e1a85
parent9ecbb7a54ed0e9927a0efba23fa4e61d06be761e (diff)
downloadscala-async-d6c5aeb6f6effcac4a054f0290711aa64ae3c191.tar.gz
scala-async-d6c5aeb6f6effcac4a054f0290711aa64ae3c191.tar.bz2
scala-async-d6c5aeb6f6effcac4a054f0290711aa64ae3c191.zip
Liveness analysis to avoid memory retention issues
- Iterative, backwards data-flow analysis - Make sure fields captured by nested defs are never zeroed out. This is done elegantly by declaring such fields a being live at the exit of the final state; thus, they will never be zeroed out.
-rw-r--r--src/main/scala/scala/async/internal/AsyncMacro.scala2
-rw-r--r--src/main/scala/scala/async/internal/AsyncTransform.scala48
-rw-r--r--src/main/scala/scala/async/internal/ExprBuilder.scala24
-rw-r--r--src/main/scala/scala/async/internal/LiveVariables.scala192
4 files changed, 248 insertions, 18 deletions
diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala
index 8d93567..1c97ca7 100644
--- a/src/main/scala/scala/async/internal/AsyncMacro.scala
+++ b/src/main/scala/scala/async/internal/AsyncMacro.scala
@@ -20,7 +20,7 @@ object AsyncMacro {
private[async] trait AsyncMacro
extends TypingTransformers
with AnfTransform with TransformUtils with Lifter
- with ExprBuilder with AsyncTransform with AsyncAnalysis {
+ with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables {
val global: Global
val callSiteTyper: global.analyzer.Typer
diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala
index 43e4a9c..476c405 100644
--- a/src/main/scala/scala/async/internal/AsyncTransform.scala
+++ b/src/main/scala/scala/async/internal/AsyncTransform.scala
@@ -29,6 +29,7 @@ trait AsyncTransform {
val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T](uncheckedBoundsResultTag), futureSystemOps.execContextType))
+ // Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`.
val stateMachine: ClassDef = {
val body: List[Tree] = {
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0)))
@@ -42,24 +43,44 @@ trait AsyncTransform {
}
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef)
}
- val template = {
- Template(List(stateMachineType), emptyValDef, body)
- }
+
+ val template = Template(List(stateMachineType), emptyValDef, body)
+
val t = ClassDef(NoMods, name.stateMachineT, Nil, template)
callSiteTyper.typedPos(macroPos)(Block(t :: Nil, Literal(Constant(()))))
t
}
+ val stateMachineClass = stateMachine.symbol
val asyncBlock: AsyncBlock = {
- val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol)
+ val symLookup = new SymLookup(stateMachineClass, applyDefDefDummyBody.vparamss.head.head.symbol)
buildAsyncBlock(anfTree, symLookup)
}
logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString))
+ val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates)
+
+ // live variables analysis
+ // the result map indicates in which states a given field should be nulled out
+ val assignsOf = fieldsToNullOut(asyncBlock.asyncStates, liftedFields)
+
+ for ((state, flds) <- assignsOf) {
+ val asyncState = asyncBlock.asyncStates.find(_.state == state).get
+ val assigns = flds.map { fld =>
+ val fieldSym = fld.symbol
+ Assign(
+ gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym),
+ gen.mkZero(fieldSym.info)
+ )
+ }.toList
+ // prepend those assigns
+ asyncState.stats = assigns ++ asyncState.stats
+ }
+
def startStateMachine: Tree = {
val stateMachineSpliced: Tree = spliceMethodBodies(
- liftables(asyncBlock.asyncStates),
+ liftedFields,
stateMachine,
atMacroPos(asyncBlock.onCompleteHandler[T]),
atMacroPos(asyncBlock.resumeFunTree[T].rhs)
@@ -96,9 +117,16 @@ trait AsyncTransform {
states foreach (s => AsyncUtils.vprintln(s))
}
- def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree,
- resumeBody: Tree): Tree = {
-
+ /**
+ * Build final `ClassDef` tree of state machine class.
+ *
+ * @param liftables trees of definitions that are lifted to fields of the state machine class
+ * @param tree `ClassDef` tree of the state machine class
+ * @param applyBody tree of onComplete handler (`apply` method)
+ * @param resumeBody RHS of definition tree of `resume` method
+ * @return transformed `ClassDef` tree of the state machine class
+ */
+ def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = {
val liftedSyms = liftables.map(_.symbol).toSet
val stateMachineClass = tree.symbol
liftedSyms.foreach {
@@ -112,7 +140,7 @@ trait AsyncTransform {
// Replace the ValDefs in the splicee with Assigns to the corresponding lifted
// fields. Similarly, replace references to them with references to the field.
//
- // This transform will be only be run on the RHS of `def foo`.
+ // This transform will only be run on the RHS of `def foo`.
class UseFields extends MacroTypingTransformer {
override def transform(tree: Tree): Tree = tree match {
case _ if currentOwner == stateMachineClass =>
@@ -150,6 +178,7 @@ trait AsyncTransform {
}
val treeSubst = tree
+ /* Fixes up DefDef: use lifted fields in `body` */
def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = {
val spliceeAnfFixedOwnerSyms = body
val useField = new UseFields()
@@ -171,6 +200,7 @@ trait AsyncTransform {
(ctx: analyzer.Context) =>
val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx)
typedTree
+
case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass =>
(ctx: analyzer.Context) =>
val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol)
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala
index 438e59e..16e95dd 100644
--- a/src/main/scala/scala/async/internal/ExprBuilder.scala
+++ b/src/main/scala/scala/async/internal/ExprBuilder.scala
@@ -25,11 +25,13 @@ trait ExprBuilder {
trait AsyncState {
def state: Int
+ def nextStates: List[Int]
+
def mkHandlerCaseForState: CaseDef
def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None
- def stats: List[Tree]
+ var stats: List[Tree]
final def allStats: List[Tree] = this match {
case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef
@@ -43,9 +45,12 @@ trait ExprBuilder {
}
/** A sequence of statements that concludes with a unconditional transition to `nextState` */
- final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup)
+ final class SimpleAsyncState(var stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup)
extends AsyncState {
+ def nextStates: List[Int] =
+ List(nextState)
+
def mkHandlerCaseForState: CaseDef =
mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup))
@@ -56,21 +61,24 @@ trait ExprBuilder {
/** A sequence of statements with a conditional transition to the next state, which will represent
* a branch of an `if` or a `match`.
*/
- final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState {
+ final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState {
override def mkHandlerCaseForState: CaseDef =
mkHandlerCase(state, stats)
override val toString: String =
- s"AsyncStateWithoutAwait #$state"
+ s"AsyncStateWithoutAwait #$state, nextStates = $nextStates"
}
/** A sequence of statements that concludes with an `await` call. The `onComplete`
* handler will unconditionally transition to `nextState`.
*/
- final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int,
+ final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int,
val awaitable: Awaitable, symLookup: SymLookup)
extends AsyncState {
+ def nextStates: List[Int] =
+ List(nextState)
+
override def mkHandlerCaseForState: CaseDef = {
val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr),
Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree
@@ -147,7 +155,7 @@ trait ExprBuilder {
def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = {
def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup))
this += If(condTree, mkBranch(thenState), mkBranch(elseState))
- new AsyncStateWithoutAwait(stats.toList, state)
+ new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState))
}
/**
@@ -169,12 +177,12 @@ trait ExprBuilder {
}
// 2. insert changed match tree at the end of the current state
this += Match(scrutTree, newCases)
- new AsyncStateWithoutAwait(stats.toList, state)
+ new AsyncStateWithoutAwait(stats.toList, state, caseStates)
}
def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = {
this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup))
- new AsyncStateWithoutAwait(stats.toList, state)
+ new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState))
}
override def toString: String = {
diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala
new file mode 100644
index 0000000..8963ebb
--- /dev/null
+++ b/src/main/scala/scala/async/internal/LiveVariables.scala
@@ -0,0 +1,192 @@
+package scala.async.internal
+
+import reflect.internal.Flags._
+
+trait LiveVariables {
+ self: AsyncMacro =>
+ import global._
+
+ /**
+ * Returns for a given state the set of fields (as trees) that should be nulled out
+ * upon resuming that state (at the beginning of `resume`).
+ */
+ def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, Set[Tree]] = {
+ // live variables analysis:
+ // the result map indicates in which states a given field should be nulled out
+ val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables)
+
+ var assignsOf = Map[Int, Set[Tree]]()
+
+ for ((fld, where) <- liveVarsMap; state <- where)
+ assignsOf get state match {
+ case None =>
+ assignsOf += (state -> Set[Tree](fld))
+ case Some(trees) if !trees.exists(_.symbol == fld.symbol) =>
+ assignsOf += (state -> (trees + fld))
+ case _ =>
+ /* do nothing */
+ }
+
+ assignsOf
+ }
+
+ /**
+ * Live variables data-flow analysis.
+ *
+ * The goal is to find, for each lifted field, the last state where the field is used.
+ * In all direct successor states which are not (indirect) predecessors of that last state
+ * (possible through loops), the corresponding field should be nulled out (at the beginning of
+ * `resume`).
+ *
+ * @return a map which indicates for a given field (the key) the states in which it should be nulled out
+ */
+ def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = {
+ val liftedSyms: Set[Symbol] = // include only vars
+ liftables.filter {
+ case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE)
+ case _ => false
+ }.map(_.symbol).toSet
+
+ // determine which fields should be live also at the end (will not be nulled out)
+ val noNull: Set[Symbol] = liftedSyms.filter { sym =>
+ liftables.exists { tree =>
+ !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym)
+ }
+ }
+
+ /**
+ * Traverse statements of an `AsyncState`, collect `Ident`-s refering to lifted fields.
+ *
+ * @param as a state of an `async` expression
+ * @return a set of lifted fields that are used within state `as`
+ */
+ def fieldsUsedIn(as: AsyncState): Set[Symbol] = {
+ class FindUseTraverser extends Traverser {
+ var usedFields = Set[Symbol]()
+ override def traverse(tree: Tree) = tree match {
+ case Ident(_) if liftedSyms(tree.symbol) =>
+ usedFields += tree.symbol
+ case _ =>
+ super.traverse(tree)
+ }
+ }
+ val findUses = new FindUseTraverser
+ findUses.traverse(Block(as.stats: _*))
+ findUses.usedFields
+ }
+
+ /* Build the control-flow graph.
+ *
+ * A state `i` is contained in the list that is the value to which
+ * key `j` maps iff control can flow from state `j` to state `i`.
+ */
+ val cfg: Map[Int, List[Int]] = asyncStates.map(as => (as.state -> as.nextStates)).toMap
+
+ /** Tests if `state1` is a predecessor of `state2`.
+ */
+ def isPred(state1: Int, state2: Int, seen: Set[Int] = Set()): Boolean =
+ if (seen(state1)) false // breaks cycles in the CFG
+ else cfg.get(state1) match {
+ case Some(nextStates) =>
+ nextStates.contains(state2) || nextStates.exists(isPred(_, state2, seen + state1))
+ case None =>
+ false
+ }
+
+ val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get
+
+ for (as <- asyncStates)
+ AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as).mkString(", ")}")
+
+ /* Backwards data-flow analysis. Computes live variables information at entry and exit
+ * of each async state.
+ *
+ * Compute using a simple fixed point iteration:
+ *
+ * 1. currStates = List(finalState)
+ * 2. for each cs \in currStates, compute LVentry(cs) from LVexit(cs) and used fields information for cs
+ * 3. record if LVentry(cs) has changed for some cs.
+ * 4. obtain predecessors pred of each cs \in currStates
+ * 5. for each p \in pred, compute LVexit(p) as union of the LVentry of its successors
+ * 6. currStates = pred
+ * 7. repeat if something has changed
+ */
+
+ var LVentry = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]()
+ var LVexit = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]()
+
+ // All fields are declared to be dead at the exit of the final async state, except for the ones
+ // that cannot be nulled out at all (those in noNull), because they have been captured by a nested def.
+ LVexit = LVexit + (finalState.state -> noNull)
+
+ var currStates = List(finalState) // start at final state
+ var pred = List[AsyncState]() // current predecessor states
+ var hasChanged = true // if something has changed we need to continue iterating
+
+ while (hasChanged) {
+ hasChanged = false
+
+ for (cs <- currStates) {
+ val LVentryOld = LVentry(cs.state)
+ val LVentryNew = LVexit(cs.state) ++ fieldsUsedIn(cs)
+ if (!LVentryNew.sameElements(LVentryOld)) {
+ LVentry = LVentry + (cs.state -> LVentryNew)
+ hasChanged = true
+ }
+ }
+
+ pred = currStates.flatMap(cs => asyncStates.filter(_.nextStates.contains(cs.state)))
+
+ for (p <- pred) {
+ val LVexitOld = LVexit(p.state)
+ val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet
+ if (!LVexitNew.sameElements(LVexitOld)) {
+ LVexit = LVexit + (p.state -> LVexitNew)
+ hasChanged = true
+ }
+ }
+
+ currStates = pred
+ }
+
+ for (as <- asyncStates) {
+ AsyncUtils.vprintln(s"LVentry at state #${as.state}: ${LVentry(as.state).mkString(", ")}")
+ AsyncUtils.vprintln(s"LVexit at state #${as.state}: ${LVexit(as.state).mkString(", ")}")
+ }
+
+ def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = LVentry get at.state match {
+ case Some(fields) if fields.exists(_ == field.symbol) =>
+ Set(at.state)
+ case _ =>
+ val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet
+ preds.flatMap(p => lastUsagesOf(field, p, avoid + at))
+ }
+
+ val lastUsages: Map[Tree, Set[Int]] =
+ liftables.map(fld => (fld -> lastUsagesOf(fld, finalState, Set()))).toMap
+
+ for ((fld, lastStates) <- lastUsages)
+ AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}")
+
+ val nullOutAt: Map[Tree, Set[Int]] =
+ for ((fld, lastStates) <- lastUsages) yield {
+ val killAt = lastStates.flatMap { s =>
+ if (s == finalState.state) {
+ Set[Int]()
+ } else {
+ val lastAsyncState = asyncStates.find(_.state == s).get
+ val succNums = lastAsyncState.nextStates
+ // all successor states that are not indirect predecessors
+ // filter out successor states where the field is live at the entry
+ succNums.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).exists(_ == fld.symbol))
+ }
+ }
+ (fld, killAt)
+ }
+
+ for ((fld, killAt) <- nullOutAt)
+ AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}")
+
+ nullOutAt
+ }
+}