From d6c5aeb6f6effcac4a054f0290711aa64ae3c191 Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Sun, 13 Oct 2013 23:44:18 +0200 Subject: Liveness analysis to avoid memory retention issues - Iterative, backwards data-flow analysis - Make sure fields captured by nested defs are never zeroed out. This is done elegantly by declaring such fields as being live at the exit of the final state; thus, they will never be zeroed out. --- .../scala/scala/async/internal/AsyncMacro.scala | 2 +- .../scala/async/internal/AsyncTransform.scala | 48 +++++- .../scala/scala/async/internal/ExprBuilder.scala | 24 ++- .../scala/scala/async/internal/LiveVariables.scala | 192 +++++++++++++++++++++ 4 files changed, 248 insertions(+), 18 deletions(-) create mode 100644 src/main/scala/scala/async/internal/LiveVariables.scala diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala index 8d93567..1c97ca7 100644 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -20,7 +20,7 @@ object AsyncMacro { private[async] trait AsyncMacro extends TypingTransformers with AnfTransform with TransformUtils with Lifter - with ExprBuilder with AsyncTransform with AsyncAnalysis { + with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables { val global: Global val callSiteTyper: global.analyzer.Typer diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 43e4a9c..476c405 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -29,6 +29,7 @@ trait AsyncTransform { val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T](uncheckedBoundsResultTag), futureSystemOps.execContextType)) + // Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`. 
val stateMachine: ClassDef = { val body: List[Tree] = { val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) @@ -42,24 +43,44 @@ trait AsyncTransform { } List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef) } - val template = { - Template(List(stateMachineType), emptyValDef, body) - } + + val template = Template(List(stateMachineType), emptyValDef, body) + val t = ClassDef(NoMods, name.stateMachineT, Nil, template) callSiteTyper.typedPos(macroPos)(Block(t :: Nil, Literal(Constant(())))) t } + val stateMachineClass = stateMachine.symbol val asyncBlock: AsyncBlock = { - val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol) + val symLookup = new SymLookup(stateMachineClass, applyDefDefDummyBody.vparamss.head.head.symbol) buildAsyncBlock(anfTree, symLookup) } logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) + val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates) + + // live variables analysis + // the result map indicates in which states a given field should be nulled out + val assignsOf = fieldsToNullOut(asyncBlock.asyncStates, liftedFields) + + for ((state, flds) <- assignsOf) { + val asyncState = asyncBlock.asyncStates.find(_.state == state).get + val assigns = flds.map { fld => + val fieldSym = fld.symbol + Assign( + gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), + gen.mkZero(fieldSym.info) + ) + }.toList + // prepend those assigns + asyncState.stats = assigns ++ asyncState.stats + } + def startStateMachine: Tree = { val stateMachineSpliced: Tree = spliceMethodBodies( - liftables(asyncBlock.asyncStates), + liftedFields, stateMachine, atMacroPos(asyncBlock.onCompleteHandler[T]), atMacroPos(asyncBlock.resumeFunTree[T].rhs) @@ -96,9 +117,16 @@ trait AsyncTransform { states foreach (s => AsyncUtils.vprintln(s)) } - def 
spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree, - resumeBody: Tree): Tree = { - + /** + * Build final `ClassDef` tree of state machine class. + * + * @param liftables trees of definitions that are lifted to fields of the state machine class + * @param tree `ClassDef` tree of the state machine class + * @param applyBody tree of onComplete handler (`apply` method) + * @param resumeBody RHS of definition tree of `resume` method + * @return transformed `ClassDef` tree of the state machine class + */ + def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = { val liftedSyms = liftables.map(_.symbol).toSet val stateMachineClass = tree.symbol liftedSyms.foreach { @@ -112,7 +140,7 @@ trait AsyncTransform { // Replace the ValDefs in the splicee with Assigns to the corresponding lifted // fields. Similarly, replace references to them with references to the field. // - // This transform will be only be run on the RHS of `def foo`. + // This transform will only be run on the RHS of `def foo`. 
class UseFields extends MacroTypingTransformer { override def transform(tree: Tree): Tree = tree match { case _ if currentOwner == stateMachineClass => @@ -150,6 +178,7 @@ trait AsyncTransform { } val treeSubst = tree + /* Fixes up DefDef: use lifted fields in `body` */ def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { val spliceeAnfFixedOwnerSyms = body val useField = new UseFields() @@ -171,6 +200,7 @@ trait AsyncTransform { (ctx: analyzer.Context) => val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) typedTree + case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => (ctx: analyzer.Context) => val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 438e59e..16e95dd 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -25,11 +25,13 @@ trait ExprBuilder { trait AsyncState { def state: Int + def nextStates: List[Int] + def mkHandlerCaseForState: CaseDef def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None - def stats: List[Tree] + var stats: List[Tree] final def allStats: List[Tree] = this match { case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef @@ -43,9 +45,12 @@ trait ExprBuilder { } /** A sequence of statements that concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) + final class SimpleAsyncState(var stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) extends AsyncState { + def nextStates: List[Int] = + List(nextState) + def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) @@ -56,21 +61,24 @@ trait ExprBuilder { 
/** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { + final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState { override def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats) override val toString: String = - s"AsyncStateWithoutAwait #$state" + s"AsyncStateWithoutAwait #$state, nextStates = $nextStates" } /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nextState`. */ - final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int, val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { + def nextStates: List[Int] = + List(nextState) + override def mkHandlerCaseForState: CaseDef = { val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree @@ -147,7 +155,7 @@ trait ExprBuilder { def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) this += If(condTree, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState)) } /** @@ -169,12 +177,12 @@ trait ExprBuilder { } // 2. 
insert changed match tree at the end of the current state this += Match(scrutTree, newCases) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, caseStates) } def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState)) } override def toString: String = { diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala new file mode 100644 index 0000000..8963ebb --- /dev/null +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -0,0 +1,192 @@ +package scala.async.internal + +import reflect.internal.Flags._ + +trait LiveVariables { + self: AsyncMacro => + import global._ + + /** + * Returns for a given state the set of fields (as trees) that should be nulled out + * upon resuming that state (at the beginning of `resume`). + */ + def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, Set[Tree]] = { + // live variables analysis: + // the result map indicates in which states a given field should be nulled out + val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables) + + var assignsOf = Map[Int, Set[Tree]]() + + for ((fld, where) <- liveVarsMap; state <- where) + assignsOf get state match { + case None => + assignsOf += (state -> Set[Tree](fld)) + case Some(trees) if !trees.exists(_.symbol == fld.symbol) => + assignsOf += (state -> (trees + fld)) + case _ => + /* do nothing */ + } + + assignsOf + } + + /** + * Live variables data-flow analysis. + * + * The goal is to find, for each lifted field, the last state where the field is used. 
+ * In all direct successor states which are not (indirect) predecessors of that last state + * (possible through loops), the corresponding field should be nulled out (at the beginning of + * `resume`). + * + * @return a map which indicates for a given field (the key) the states in which it should be nulled out + */ + def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = { + val liftedSyms: Set[Symbol] = // include only vars + liftables.filter { + case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE) + case _ => false + }.map(_.symbol).toSet + + // determine which fields should be live also at the end (will not be nulled out) + val noNull: Set[Symbol] = liftedSyms.filter { sym => + liftables.exists { tree => + !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym) + } + } + + /** + * Traverse statements of an `AsyncState`, collect `Ident`-s referring to lifted fields. + * + * @param as a state of an `async` expression + * @return a set of lifted fields that are used within state `as` + */ + def fieldsUsedIn(as: AsyncState): Set[Symbol] = { + class FindUseTraverser extends Traverser { + var usedFields = Set[Symbol]() + override def traverse(tree: Tree) = tree match { + case Ident(_) if liftedSyms(tree.symbol) => + usedFields += tree.symbol + case _ => + super.traverse(tree) + } + } + val findUses = new FindUseTraverser + findUses.traverse(Block(as.stats: _*)) + findUses.usedFields + } + + /* Build the control-flow graph. + * + * A state `i` is contained in the list that is the value to which + * key `j` maps iff control can flow from state `j` to state `i`. + */ + val cfg: Map[Int, List[Int]] = asyncStates.map(as => (as.state -> as.nextStates)).toMap + + /** Tests if `state1` is a predecessor of `state2`. 
+ */ + def isPred(state1: Int, state2: Int, seen: Set[Int] = Set()): Boolean = + if (seen(state1)) false // breaks cycles in the CFG + else cfg.get(state1) match { + case Some(nextStates) => + nextStates.contains(state2) || nextStates.exists(isPred(_, state2, seen + state1)) + case None => + false + } + + val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get + + for (as <- asyncStates) + AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as).mkString(", ")}") + + /* Backwards data-flow analysis. Computes live variables information at entry and exit + * of each async state. + * + * Compute using a simple fixed point iteration: + * + * 1. currStates = List(finalState) + * 2. for each cs \in currStates, compute LVentry(cs) from LVexit(cs) and used fields information for cs + * 3. record if LVentry(cs) has changed for some cs. + * 4. obtain predecessors pred of each cs \in currStates + * 5. for each p \in pred, compute LVexit(p) as union of the LVentry of its successors + * 6. currStates = pred + * 7. repeat if something has changed + */ + + var LVentry = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() + var LVexit = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() + + // All fields are declared to be dead at the exit of the final async state, except for the ones + // that cannot be nulled out at all (those in noNull), because they have been captured by a nested def. 
+ LVexit = LVexit + (finalState.state -> noNull) + + var currStates = List(finalState) // start at final state + var pred = List[AsyncState]() // current predecessor states + var hasChanged = true // if something has changed we need to continue iterating + + while (hasChanged) { + hasChanged = false + + for (cs <- currStates) { + val LVentryOld = LVentry(cs.state) + val LVentryNew = LVexit(cs.state) ++ fieldsUsedIn(cs) + if (!LVentryNew.sameElements(LVentryOld)) { + LVentry = LVentry + (cs.state -> LVentryNew) + hasChanged = true + } + } + + pred = currStates.flatMap(cs => asyncStates.filter(_.nextStates.contains(cs.state))) + + for (p <- pred) { + val LVexitOld = LVexit(p.state) + val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet + if (!LVexitNew.sameElements(LVexitOld)) { + LVexit = LVexit + (p.state -> LVexitNew) + hasChanged = true + } + } + + currStates = pred + } + + for (as <- asyncStates) { + AsyncUtils.vprintln(s"LVentry at state #${as.state}: ${LVentry(as.state).mkString(", ")}") + AsyncUtils.vprintln(s"LVexit at state #${as.state}: ${LVexit(as.state).mkString(", ")}") + } + + def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = LVentry get at.state match { + case Some(fields) if fields.exists(_ == field.symbol) => + Set(at.state) + case _ => + val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet + preds.flatMap(p => lastUsagesOf(field, p, avoid + at)) + } + + val lastUsages: Map[Tree, Set[Int]] = + liftables.map(fld => (fld -> lastUsagesOf(fld, finalState, Set()))).toMap + + for ((fld, lastStates) <- lastUsages) + AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}") + + val nullOutAt: Map[Tree, Set[Int]] = + for ((fld, lastStates) <- lastUsages) yield { + val killAt = lastStates.flatMap { s => + if (s == finalState.state) { + Set[Int]() + } else { + val lastAsyncState = asyncStates.find(_.state == s).get + val succNums = lastAsyncState.nextStates 
+ // all successor states that are not indirect predecessors + // filter out successor states where the field is live at the entry + succNums.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).exists(_ == fld.symbol)) + } + } + (fld, killAt) + } + + for ((fld, killAt) <- nullOutAt) + AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}") + + nullOutAt + } +} -- cgit v1.2.3