From d6c5aeb6f6effcac4a054f0290711aa64ae3c191 Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Sun, 13 Oct 2013 23:44:18 +0200 Subject: Liveness analysis to avoid memory retention issues - Iterative, backwards data-flow analysis - Make sure fields captured by nested defs are never zeroed out. This is done elegantly by declaring such fields a being live at the exit of the final state; thus, they will never be zeroed out. --- .../scala/scala/async/internal/AsyncMacro.scala | 2 +- .../scala/async/internal/AsyncTransform.scala | 48 +++++- .../scala/scala/async/internal/ExprBuilder.scala | 24 ++- .../scala/scala/async/internal/LiveVariables.scala | 192 +++++++++++++++++++++ 4 files changed, 248 insertions(+), 18 deletions(-) create mode 100644 src/main/scala/scala/async/internal/LiveVariables.scala diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala index 8d93567..1c97ca7 100644 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -20,7 +20,7 @@ object AsyncMacro { private[async] trait AsyncMacro extends TypingTransformers with AnfTransform with TransformUtils with Lifter - with ExprBuilder with AsyncTransform with AsyncAnalysis { + with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables { val global: Global val callSiteTyper: global.analyzer.Typer diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 43e4a9c..476c405 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -29,6 +29,7 @@ trait AsyncTransform { val stateMachineType = applied("scala.async.StateMachine", List(futureSystemOps.promType[T](uncheckedBoundsResultTag), futureSystemOps.execContextType)) + // Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`. val stateMachine: ClassDef = { val body: List[Tree] = { val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(0))) @@ -42,24 +43,44 @@ trait AsyncTransform { } List(emptyConstructor, stateVar, result, execContextValDef) ++ List(resumeFunTreeDummyBody, applyDefDefDummyBody, apply0DefDef) } - val template = { - Template(List(stateMachineType), emptyValDef, body) - } + + val template = Template(List(stateMachineType), emptyValDef, body) + val t = ClassDef(NoMods, name.stateMachineT, Nil, template) callSiteTyper.typedPos(macroPos)(Block(t :: Nil, Literal(Constant(())))) t } + val stateMachineClass = stateMachine.symbol val asyncBlock: AsyncBlock = { - val symLookup = new SymLookup(stateMachine.symbol, applyDefDefDummyBody.vparamss.head.head.symbol) + val symLookup = new SymLookup(stateMachineClass, applyDefDefDummyBody.vparamss.head.head.symbol) buildAsyncBlock(anfTree, symLookup) } logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString)) + val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates) + + // live variables analysis + // the result map indicates in which states a given field should be nulled out + val assignsOf = fieldsToNullOut(asyncBlock.asyncStates, liftedFields) + + for ((state, flds) <- assignsOf) { + val asyncState = asyncBlock.asyncStates.find(_.state == state).get + val assigns = flds.map { fld => + val fieldSym = fld.symbol + Assign( + gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), + gen.mkZero(fieldSym.info) + ) + }.toList + // prepend those assigns + asyncState.stats = assigns ++ asyncState.stats + } + def startStateMachine: Tree = { val stateMachineSpliced: Tree = spliceMethodBodies( - liftables(asyncBlock.asyncStates), + liftedFields, stateMachine, atMacroPos(asyncBlock.onCompleteHandler[T]), atMacroPos(asyncBlock.resumeFunTree[T].rhs) @@ -96,9 +117,16 @@ trait AsyncTransform { states foreach (s => AsyncUtils.vprintln(s)) } - def spliceMethodBodies(liftables: List[Tree], tree: Tree, applyBody: Tree, - resumeBody: Tree): Tree = { - + /** + * Build final `ClassDef` tree of state machine class. + * + * @param liftables trees of definitions that are lifted to fields of the state machine class + * @param tree `ClassDef` tree of the state machine class + * @param applyBody tree of onComplete handler (`apply` method) + * @param resumeBody RHS of definition tree of `resume` method + * @return transformed `ClassDef` tree of the state machine class + */ + def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree, resumeBody: Tree): Tree = { val liftedSyms = liftables.map(_.symbol).toSet val stateMachineClass = tree.symbol liftedSyms.foreach { @@ -112,7 +140,7 @@ trait AsyncTransform { // Replace the ValDefs in the splicee with Assigns to the corresponding lifted // fields. Similarly, replace references to them with references to the field. // - // This transform will be only be run on the RHS of `def foo`. + // This transform will only be run on the RHS of `def foo`. class UseFields extends MacroTypingTransformer { override def transform(tree: Tree): Tree = tree match { case _ if currentOwner == stateMachineClass => @@ -150,6 +178,7 @@ trait AsyncTransform { } val treeSubst = tree + /* Fixes up DefDef: use lifted fields in `body` */ def fixup(dd: DefDef, body: Tree, ctx: analyzer.Context): Tree = { val spliceeAnfFixedOwnerSyms = body val useField = new UseFields() @@ -171,6 +200,7 @@ trait AsyncTransform { (ctx: analyzer.Context) => val typedTree = fixup(dd, changeOwner(applyBody, callSiteTyper.context.owner, dd.symbol), ctx) typedTree + case dd@DefDef(_, name.resume, _, _, _, _) if dd.symbol.owner == stateMachineClass => (ctx: analyzer.Context) => val changed = changeOwner(resumeBody, callSiteTyper.context.owner, dd.symbol) diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala index 438e59e..16e95dd 100644 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ b/src/main/scala/scala/async/internal/ExprBuilder.scala @@ -25,11 +25,13 @@ trait ExprBuilder { trait AsyncState { def state: Int + def nextStates: List[Int] + def mkHandlerCaseForState: CaseDef def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None - def stats: List[Tree] + var stats: List[Tree] final def allStats: List[Tree] = this match { case a: AsyncStateWithAwait => stats :+ a.awaitable.resultValDef @@ -43,9 +45,12 @@ trait ExprBuilder { } /** A sequence of statements that concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(val stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) + final class SimpleAsyncState(var stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) extends AsyncState { + def nextStates: List[Int] = + List(nextState) + def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats :+ mkStateTree(nextState, symLookup) :+ mkResumeApply(symLookup)) @@ -56,21 +61,24 @@ trait ExprBuilder { /** A sequence of statements with a conditional transition to the next state, which will represent * a branch of an `if` or a `match`. */ - final class AsyncStateWithoutAwait(val stats: List[Tree], val state: Int) extends AsyncState { + final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: List[Int]) extends AsyncState { override def mkHandlerCaseForState: CaseDef = mkHandlerCase(state, stats) override val toString: String = - s"AsyncStateWithoutAwait #$state" + s"AsyncStateWithoutAwait #$state, nextStates = $nextStates" } /** A sequence of statements that concludes with an `await` call. The `onComplete` * handler will unconditionally transition to `nextState`. */ - final class AsyncStateWithAwait(val stats: List[Tree], val state: Int, nextState: Int, + final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, nextState: Int, val awaitable: Awaitable, symLookup: SymLookup) extends AsyncState { + def nextStates: List[Int] = + List(nextState) + override def mkHandlerCaseForState: CaseDef = { val callOnComplete = futureSystemOps.onComplete(Expr(awaitable.expr), Expr(This(tpnme.EMPTY)), Expr(Ident(name.execContext))).tree @@ -147,7 +155,7 @@ trait ExprBuilder { def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { def mkBranch(state: Int) = Block(mkStateTree(state, symLookup) :: Nil, mkResumeApply(symLookup)) this += If(condTree, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, List(thenState, elseState)) } /** @@ -169,12 +177,12 @@ trait ExprBuilder { } // 2. insert changed match tree at the end of the current state this += Match(scrutTree, newCases) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, caseStates) } def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { this += Block(mkStateTree(startLabelState, symLookup) :: Nil, mkResumeApply(symLookup)) - new AsyncStateWithoutAwait(stats.toList, state) + new AsyncStateWithoutAwait(stats.toList, state, List(startLabelState)) } override def toString: String = { diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala new file mode 100644 index 0000000..8963ebb --- /dev/null +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -0,0 +1,192 @@ +package scala.async.internal + +import reflect.internal.Flags._ + +trait LiveVariables { + self: AsyncMacro => + import global._ + + /** + * Returns for a given state the set of fields (as trees) that should be nulled out + * upon resuming that state (at the beginning of `resume`). + */ + def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, Set[Tree]] = { + // live variables analysis: + // the result map indicates in which states a given field should be nulled out + val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables) + + var assignsOf = Map[Int, Set[Tree]]() + + for ((fld, where) <- liveVarsMap; state <- where) + assignsOf get state match { + case None => + assignsOf += (state -> Set[Tree](fld)) + case Some(trees) if !trees.exists(_.symbol == fld.symbol) => + assignsOf += (state -> (trees + fld)) + case _ => + /* do nothing */ + } + + assignsOf + } + + /** + * Live variables data-flow analysis. + * + * The goal is to find, for each lifted field, the last state where the field is used. + * In all direct successor states which are not (indirect) predecessors of that last state + * (possible through loops), the corresponding field should be nulled out (at the beginning of + * `resume`). + * + * @return a map which indicates for a given field (the key) the states in which it should be nulled out + */ + def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = { + val liftedSyms: Set[Symbol] = // include only vars + liftables.filter { + case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE) + case _ => false + }.map(_.symbol).toSet + + // determine which fields should be live also at the end (will not be nulled out) + val noNull: Set[Symbol] = liftedSyms.filter { sym => + liftables.exists { tree => + !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym) + } + } + + /** + * Traverse statements of an `AsyncState`, collect `Ident`-s refering to lifted fields. + * + * @param as a state of an `async` expression + * @return a set of lifted fields that are used within state `as` + */ + def fieldsUsedIn(as: AsyncState): Set[Symbol] = { + class FindUseTraverser extends Traverser { + var usedFields = Set[Symbol]() + override def traverse(tree: Tree) = tree match { + case Ident(_) if liftedSyms(tree.symbol) => + usedFields += tree.symbol + case _ => + super.traverse(tree) + } + } + val findUses = new FindUseTraverser + findUses.traverse(Block(as.stats: _*)) + findUses.usedFields + } + + /* Build the control-flow graph. + * + * A state `i` is contained in the list that is the value to which + * key `j` maps iff control can flow from state `j` to state `i`. + */ + val cfg: Map[Int, List[Int]] = asyncStates.map(as => (as.state -> as.nextStates)).toMap + + /** Tests if `state1` is a predecessor of `state2`. + */ + def isPred(state1: Int, state2: Int, seen: Set[Int] = Set()): Boolean = + if (seen(state1)) false // breaks cycles in the CFG + else cfg.get(state1) match { + case Some(nextStates) => + nextStates.contains(state2) || nextStates.exists(isPred(_, state2, seen + state1)) + case None => + false + } + + val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get + + for (as <- asyncStates) + AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as).mkString(", ")}") + + /* Backwards data-flow analysis. Computes live variables information at entry and exit + * of each async state. + * + * Compute using a simple fixed point iteration: + * + * 1. currStates = List(finalState) + * 2. for each cs \in currStates, compute LVentry(cs) from LVexit(cs) and used fields information for cs + * 3. record if LVentry(cs) has changed for some cs. + * 4. obtain predecessors pred of each cs \in currStates + * 5. for each p \in pred, compute LVexit(p) as union of the LVentry of its successors + * 6. currStates = pred + * 7. repeat if something has changed + */ + + var LVentry = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() + var LVexit = Map[Int, Set[Symbol]]() withDefaultValue Set[Symbol]() + + // All fields are declared to be dead at the exit of the final async state, except for the ones + // that cannot be nulled out at all (those in noNull), because they have been captured by a nested def. + LVexit = LVexit + (finalState.state -> noNull) + + var currStates = List(finalState) // start at final state + var pred = List[AsyncState]() // current predecessor states + var hasChanged = true // if something has changed we need to continue iterating + + while (hasChanged) { + hasChanged = false + + for (cs <- currStates) { + val LVentryOld = LVentry(cs.state) + val LVentryNew = LVexit(cs.state) ++ fieldsUsedIn(cs) + if (!LVentryNew.sameElements(LVentryOld)) { + LVentry = LVentry + (cs.state -> LVentryNew) + hasChanged = true + } + } + + pred = currStates.flatMap(cs => asyncStates.filter(_.nextStates.contains(cs.state))) + + for (p <- pred) { + val LVexitOld = LVexit(p.state) + val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet + if (!LVexitNew.sameElements(LVexitOld)) { + LVexit = LVexit + (p.state -> LVexitNew) + hasChanged = true + } + } + + currStates = pred + } + + for (as <- asyncStates) { + AsyncUtils.vprintln(s"LVentry at state #${as.state}: ${LVentry(as.state).mkString(", ")}") + AsyncUtils.vprintln(s"LVexit at state #${as.state}: ${LVexit(as.state).mkString(", ")}") + } + + def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = LVentry get at.state match { + case Some(fields) if fields.exists(_ == field.symbol) => + Set(at.state) + case _ => + val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet + preds.flatMap(p => lastUsagesOf(field, p, avoid + at)) + } + + val lastUsages: Map[Tree, Set[Int]] = + liftables.map(fld => (fld -> lastUsagesOf(fld, finalState, Set()))).toMap + + for ((fld, lastStates) <- lastUsages) + AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}") + + val nullOutAt: Map[Tree, Set[Int]] = + for ((fld, lastStates) <- lastUsages) yield { + val killAt = lastStates.flatMap { s => + if (s == finalState.state) { + Set[Int]() + } else { + val lastAsyncState = asyncStates.find(_.state == s).get + val succNums = lastAsyncState.nextStates + // all successor states that are not indirect predecessors + // filter out successor states where the field is live at the entry + succNums.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).exists(_ == fld.symbol)) + } + } + (fld, killAt) + } + + for ((fld, killAt) <- nullOutAt) + AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}") + + nullOutAt + } +} -- cgit v1.2.3 From 480fffd487d53cfdb943a2287788af2bad409b88 Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Fri, 18 Oct 2013 12:32:35 +0200 Subject: Fix looping issue when computing last usages of fields - A missing condition could cause an infinite loop - Various clean-ups --- .../scala/async/internal/AsyncTransform.scala | 10 ++----- .../scala/scala/async/internal/LiveVariables.scala | 33 +++++++++++----------- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 476c405..5c84f66 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -66,15 +66,11 @@ trait AsyncTransform { val assignsOf = fieldsToNullOut(asyncBlock.asyncStates, liftedFields) for ((state, flds) <- assignsOf) { - val asyncState = asyncBlock.asyncStates.find(_.state == state).get val assigns = flds.map { fld => val fieldSym = fld.symbol - Assign( - gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), - gen.mkZero(fieldSym.info) - ) - }.toList - // prepend those assigns + Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), gen.mkZero(fieldSym.info)) + } + val asyncState = asyncBlock.asyncStates.find(_.state == state).get asyncState.stats = assigns ++ asyncState.stats } diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala index 8963ebb..48f03cd 100644 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -7,22 +7,22 @@ trait LiveVariables { import global._ /** - * Returns for a given state the set of fields (as trees) that should be nulled out + * Returns for a given state a list of fields (as trees) that should be nulled out * upon resuming that state (at the beginning of `resume`). */ - def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, Set[Tree]] = { + def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = { // live variables analysis: // the result map indicates in which states a given field should be nulled out val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables) - var assignsOf = Map[Int, Set[Tree]]() + var assignsOf = Map[Int, List[Tree]]() for ((fld, where) <- liveVarsMap; state <- where) assignsOf get state match { case None => - assignsOf += (state -> Set[Tree](fld)) + assignsOf += (state -> List(fld)) case Some(trees) if !trees.exists(_.symbol == fld.symbol) => - assignsOf += (state -> (trees + fld)) + assignsOf += (state -> (fld +: trees)) case _ => /* do nothing */ } @@ -86,7 +86,7 @@ trait LiveVariables { */ def isPred(state1: Int, state2: Int, seen: Set[Int] = Set()): Boolean = if (seen(state1)) false // breaks cycles in the CFG - else cfg.get(state1) match { + else cfg get state1 match { case Some(nextStates) => nextStates.contains(state2) || nextStates.exists(isPred(_, state2, seen + state1)) case None => @@ -154,13 +154,15 @@ trait LiveVariables { AsyncUtils.vprintln(s"LVexit at state #${as.state}: ${LVexit(as.state).mkString(", ")}") } - def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = LVentry get at.state match { - case Some(fields) if fields.exists(_ == field.symbol) => - Set(at.state) - case _ => - val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet - preds.flatMap(p => lastUsagesOf(field, p, avoid + at)) - } + def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = + if (avoid(at)) Set() + else LVentry get at.state match { + case Some(fields) if fields.exists(_ == field.symbol) => + Set(at.state) + case _ => + val preds = asyncStates.filter(_.nextStates.contains(at.state)).toSet + preds.flatMap(p => lastUsagesOf(field, p, avoid + at)) + } val lastUsages: Map[Tree, Set[Int]] = liftables.map(fld => (fld -> lastUsagesOf(fld, finalState, Set()))).toMap @@ -171,9 +173,8 @@ trait LiveVariables { val nullOutAt: Map[Tree, Set[Int]] = for ((fld, lastStates) <- lastUsages) yield { val killAt = lastStates.flatMap { s => - if (s == finalState.state) { - Set[Int]() - } else { + if (s == finalState.state) Set() + else { val lastAsyncState = asyncStates.find(_.state == s).get val succNums = lastAsyncState.nextStates // all successor states that are not indirect predecessors -- cgit v1.2.3 From 62f22d41cfabc7d0d87c5afef64c1c9015e2cf5e Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Sat, 19 Oct 2013 12:45:00 +0200 Subject: Enables testing the resetting of lifted local variables - Adds a hook that lets a derived macro insert additional code when zero-ing out a lifted field. - Adds a variant of the `AsyncId` macro that logs zeroed-out fields. - Adds a test using this mechanism --- .../continuations/AsyncBaseWithCPSFallback.scala | 2 +- .../scala/scala/async/internal/AsyncBase.scala | 6 +++- src/main/scala/scala/async/internal/AsyncId.scala | 18 ++++++++++ .../scala/scala/async/internal/AsyncMacro.scala | 13 ++++--- .../scala/async/internal/AsyncTransform.scala | 10 +++++- .../scala/async/run/live/LiveVariablesSpec.scala | 40 ++++++++++++++++++++++ 6 files changed, 81 insertions(+), 8 deletions(-) create mode 100644 src/test/scala/scala/async/run/live/LiveVariablesSpec.scala diff --git a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala index 1a6ac87..20f5cce 100644 --- a/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala +++ b/src/main/scala/scala/async/continuations/AsyncBaseWithCPSFallback.scala @@ -92,7 +92,7 @@ trait AsyncBaseWithCPSFallback extends internal.AsyncBase { (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { AsyncUtils.vprintln("AsyncBaseWithCPSFallback.asyncImpl") - val asyncMacro = AsyncMacro(c, futureSystem) + val asyncMacro = AsyncMacro(c, this) if (!asyncMacro.reportUnsupportedAwaits(body.tree.asInstanceOf[asyncMacro.global.Tree], report = fallbackEnabled)) super.asyncImpl[T](c)(body)(execContext) // no unsupported awaits diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala index 3690c2d..e44c27f 100644 --- a/src/main/scala/scala/async/internal/AsyncBase.scala +++ b/src/main/scala/scala/async/internal/AsyncBase.scala @@ -6,6 +6,7 @@ package scala.async.internal import scala.reflect.internal.annotations.compileTimeOnly import scala.reflect.macros.Context +import scala.reflect.api.Universe /** * A base class for the `async` macro. Subclasses must provide: @@ -45,7 +46,7 @@ abstract class AsyncBase { (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { import c.universe._ - val asyncMacro = AsyncMacro(c, futureSystem) + val asyncMacro = AsyncMacro(c, self) val code = asyncMacro.asyncTransform[T]( body.tree.asInstanceOf[asyncMacro.global.Tree], @@ -59,4 +60,7 @@ abstract class AsyncBase { AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}") c.Expr[futureSystem.Fut[T]](code) } + + protected[async] def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] = + u.reify { () } } diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala index b9d82e2..7f7807f 100644 --- a/src/main/scala/scala/async/internal/AsyncId.scala +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -6,6 +6,7 @@ package scala.async.internal import language.experimental.macros import scala.reflect.macros.Context +import scala.reflect.api.Universe import scala.reflect.internal.SymbolTable object AsyncId extends AsyncBase { @@ -17,6 +18,23 @@ object AsyncId extends AsyncBase { def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) } +object AsyncTestLV extends AsyncBase { + lazy val futureSystem = IdentityFutureSystem + type FS = IdentityFutureSystem.type + + def async[T](body: T) = macro asyncIdImpl[T] + + def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) + + var log: List[(String, Any)] = List() + + def apply(name: String, v: Any): Unit = + log ::= (name -> v) + + protected[async] override def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] = + u.reify { scala.async.internal.AsyncTestLV(name.splice, v.splice) } +} + /** * A trivial implementation of [[FutureSystem]] that performs computations * on the current thread. Useful for testing. diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala index 1c97ca7..ee49923 100644 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ b/src/main/scala/scala/async/internal/AsyncMacro.scala @@ -4,15 +4,18 @@ import scala.tools.nsc.Global import scala.tools.nsc.transform.TypingTransformers object AsyncMacro { - def apply(c: reflect.macros.Context, futureSystem0: FutureSystem): AsyncMacro = { + def apply(c: reflect.macros.Context, base: AsyncBase): AsyncMacro = { import language.reflectiveCalls val powerContext = c.asInstanceOf[c.type { val universe: Global; val callsiteTyper: universe.analyzer.Typer }] new AsyncMacro { - val global: powerContext.universe.type = powerContext.universe + val global: powerContext.universe.type = powerContext.universe val callSiteTyper: global.analyzer.Typer = powerContext.callsiteTyper - val futureSystem: futureSystem0.type = futureSystem0 - val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem0.mkOps(global) - val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] + val macroApplication: global.Tree = c.macroApplication.asInstanceOf[global.Tree] + // This member is required by `AsyncTransform`: + val asyncBase: AsyncBase = base + // These members are required by `ExprBuilder`: + val futureSystem: FutureSystem = base.futureSystem + val futureSystemOps: futureSystem.Ops {val universe: global.type} = futureSystem.mkOps(global) } } } diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 5c84f66..18caea4 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -5,6 +5,8 @@ trait AsyncTransform { import global._ + val asyncBase: AsyncBase + def asyncTransform[T](body: Tree, execContext: Tree, cpsFallbackEnabled: Boolean) (resultType: WeakTypeTag[T]): Tree = { @@ -68,7 +70,13 @@ trait AsyncTransform { for ((state, flds) <- assignsOf) { val assigns = flds.map { fld => val fieldSym = fld.symbol - Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), gen.mkZero(fieldSym.info)) + val zero = gen.mkZero(fieldSym.info) + Block( + List( + asyncBase.nullOut(global)(Expr[String](Literal(Constant(fieldSym.name.toString))), Expr[Any](zero)).tree + ), + Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), zero) + ) } val asyncState = asyncBlock.asyncStates.find(_.state == state).get asyncState.stats = assigns ++ asyncState.stats diff --git a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala new file mode 100644 index 0000000..2cecffa --- /dev/null +++ b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2012-2013 Typesafe Inc. + */ + +package scala.async +package run +package live + +import org.junit.Test + +import internal.AsyncTestLV +import AsyncTestLV._ + +class LiveVariablesSpec { + + @Test + def liveVars1() { + val f = async { 1 } + + def m1(x: Int): Int = + async { x + 1 } + + def m2(x: Int): String = + async { x.toString } + + def m3() = async { + val a = await(f) // await$1$1 + // a == 1 + val b = await(m1(a)) // await$2$1 + // b == 2 + assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> 0))) + val res = await(m2(b)) // await$3$1 + assert(AsyncTestLV.log.exists(_ == ("await$2$1" -> 0))) + res + } + + assert(m3() == "2") + } + +} -- cgit v1.2.3 From 01b11f71fadf60c8dbf2f5f38f32ec82c437feb0 Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Sat, 19 Oct 2013 16:20:26 +0200 Subject: Avoid zero-ing out dead fields of primitive value class type - Zero out fields of type Any - Zero out fields of value class type --- .../scala/async/internal/AsyncTransform.scala | 5 +- .../scala/scala/async/internal/LiveVariables.scala | 3 +- .../scala/async/run/live/LiveVariablesSpec.scala | 132 +++++++++++++++++++-- 3 files changed, 126 insertions(+), 14 deletions(-) diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala index 18caea4..27d95a4 100644 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ b/src/main/scala/scala/async/internal/AsyncTransform.scala @@ -70,12 +70,11 @@ trait AsyncTransform { for ((state, flds) <- assignsOf) { val assigns = flds.map { fld => val fieldSym = fld.symbol - val zero = gen.mkZero(fieldSym.info) Block( List( - asyncBase.nullOut(global)(Expr[String](Literal(Constant(fieldSym.name.toString))), Expr[Any](zero)).tree + asyncBase.nullOut(global)(Expr[String](Literal(Constant(fieldSym.name.toString))), Expr[Any](Ident(fieldSym))).tree ), - Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), zero) + Assign(gen.mkAttributedStableRef(fieldSym.owner.thisType, fieldSym), gen.mkZero(fieldSym.info)) ) } val asyncState = asyncBlock.asyncStates.find(_.state == state).get diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala index 48f03cd..0f95bca 100644 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -49,10 +49,11 @@ trait LiveVariables { // determine which fields should be live also at the end (will not be nulled out) val noNull: Set[Symbol] = liftedSyms.filter { sym => - liftables.exists { tree => + sym.tpe.typeSymbol.isPrimitiveValueClass || liftables.exists { tree => !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym) } } + AsyncUtils.vprintln(s"fields never zero-ed out: ${noNull.mkString(", ")}") /** * Traverse statements of an `AsyncState`, collect `Ident`-s refering to lifted fields. diff --git a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala index 2cecffa..be62ed8 100644 --- a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala +++ b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala @@ -11,30 +11,142 @@ import org.junit.Test import internal.AsyncTestLV import AsyncTestLV._ +case class Cell[T](v: T) + +class Meter(val len: Long) extends AnyVal + +case class MCell[T](var v: T) + + class LiveVariablesSpec { @Test - def liveVars1() { + def `zero out fields of reference type`() { + val f = async { Cell(1) } + + def m1(x: Cell[Int]): Cell[Int] = + async { Cell(x.v + 1) } + + def m2(x: Cell[Int]): String = + async { x.v.toString } + + def m3() = async { + val a: Cell[Int] = await(f) // await$1$1 + // a == Cell(1) + val b: Cell[Int] = await(m1(a)) // await$2$1 + // b == Cell(2) + assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1)))) + val res = await(m2(b)) // await$3$1 + assert(AsyncTestLV.log.exists(_ == ("await$2$1" -> Cell(2)))) + res + } + + assert(m3() == "2") + } + + @Test + def `zero out fields of type Any`() { + val f = async { Cell(1) } + + def m1(x: Cell[Int]): Cell[Int] = + async { Cell(x.v + 1) } + + def m2(x: Any): String = + async { x.toString } + + def m3() = async { + val a: Cell[Int] = await(f) // await$4$1 + // a == Cell(1) + val b: Any = await(m1(a)) // await$5$1 + // b == Cell(2) + assert(AsyncTestLV.log.exists(_ == ("await$4$1" -> Cell(1)))) + val res = await(m2(b)) // await$6$1 + assert(AsyncTestLV.log.exists(_ == ("await$5$1" -> Cell(2)))) + res + } + + assert(m3() == "Cell(2)") + } + + @Test + def `do not zero out fields of primitive type`() { val f = async { 1 } - def m1(x: Int): Int = - async { x + 1 } + def m1(x: Int): Cell[Int] = + async { Cell(x + 1) } - def m2(x: Int): String = + def m2(x: Any): String = async { x.toString } def m3() = async { - val a = await(f) // await$1$1 + val a: Int = await(f) // await$7$1 // a == 1 - val b = await(m1(a)) // await$2$1 - // b == 2 - assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> 0))) - val res = await(m2(b)) // await$3$1 - assert(AsyncTestLV.log.exists(_ == ("await$2$1" -> 0))) + val b: Any = await(m1(a)) // await$8$1 + // b == Cell(2) + assert(!AsyncTestLV.log.exists(p => p._1 == "await$7$1")) + val res = await(m2(b)) // await$9$1 + assert(AsyncTestLV.log.exists(_ == ("await$8$1" -> Cell(2)))) + res + } + + assert(m3() == "Cell(2)") + } + + @Test + def `zero out fields of value class type`() { + val f = async { Cell(1) } + + def m1(x: Cell[Int]): Meter = + async { new Meter(x.v + 1) } + + def m2(x: Any): String = + async { x.toString } + + def m3() = async { + val a: Cell[Int] = await(f) // await$10$1 + // a == Cell(1) + val b: Meter = await(m1(a)) // await$11$1 + // b == Meter(2) + assert(AsyncTestLV.log.exists(_ == ("await$10$1" -> Cell(1)))) + val res = await(m2(b.len)) // await$12$1 + assert(AsyncTestLV.log.exists(entry => entry._1 == "await$11$1" && entry._2.asInstanceOf[Meter].len == 2L)) res } assert(m3() == "2") } + @Test + def `zero out fields after use in loop`() { + val f = async { MCell(1) } + + def m1(x: MCell[Int], y: Int): Int = + async { x.v + y } + + def m3() = async { + // state #1 + val a: MCell[Int] = await(f) // await$13$1 + // state #2 + var y = MCell(0) + + while (a.v < 10) { + // state #4 + a.v = a.v + 1 + y = MCell(await(a).v + 1) // await$14$1 + // state #7 + } + + // state #3 + assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1")) + + val b = await(m1(a, y.v)) // await$15$1 + // state #8 + assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10)))) + assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11)))) + b + } + + assert(m3() == 21) + } + } -- cgit v1.2.3 From 9ac022c1329e06da1f4dd45abc24e6d482bfbdaf Mon Sep 17 00:00:00 2001 From: Philipp Haller Date: Tue, 22 Oct 2013 15:15:42 +0200 Subject: Add more doc comments --- src/main/scala/scala/async/internal/LiveVariables.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala index 0f95bca..4d8c479 100644 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -9,6 +9,11 @@ trait LiveVariables { /** * Returns for a given state a list of fields (as trees) that should be nulled out * upon resuming that state (at the beginning of `resume`). + * + * @param asyncStates the states of an `async` block + * @param liftables the lifted fields + * @return a map mapping a state to the fields that should be nulled out + * upon resuming that state */ def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = { // live variables analysis: @@ -38,7 +43,9 @@ trait LiveVariables { * (possible through loops), the corresponding field should be nulled out (at the beginning of * `resume`). * - * @return a map which indicates for a given field (the key) the states in which it should be nulled out + * @param asyncStates the states of an `async` block + * @param liftables the lifted fields + * @return a map which indicates for a given field (the key) the states in which it should be nulled out */ def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = { val liftedSyms: Set[Symbol] = // include only vars -- cgit v1.2.3