From 302a07c64e4cd3db91a654dcbc893ade0837ba8c Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Mon, 11 Nov 2013 23:17:11 +0100 Subject: Don't aggressively null out captured vars Once they escape, we leave the references in the state machines fields untouched. --- src/main/scala/scala/async/internal/AsyncId.scala | 3 + .../scala/scala/async/internal/LiveVariables.scala | 60 +++++++++-- .../scala/async/internal/TransformUtils.scala | 2 +- src/test/scala/scala/async/TreeInterrogation.scala | 38 ++++++- .../scala/async/run/live/LiveVariablesSpec.scala | 120 ++++++++++++++++++++- 5 files changed, 205 insertions(+), 18 deletions(-) diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala index c123675..a794f93 100644 --- a/src/main/scala/scala/async/internal/AsyncId.scala +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -27,6 +27,9 @@ object AsyncTestLV extends AsyncBase { def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) var log: List[(String, Any)] = List() + def assertNulledOut(a: Any): Unit = assert(log.exists(_._2 == a), AsyncTestLV.log) + def assertNotNulledOut(a: Any): Unit = assert(!log.exists(_._2 == a), AsyncTestLV.log) + def clear() = log = Nil def apply(name: String, v: Any): Unit = log ::= (name -> v) diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala index 4d8c479..8753b3d 100644 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -68,19 +68,53 @@ trait LiveVariables { * @param as a state of an `async` expression * @return a set of lifted fields that are used within state `as` */ - def fieldsUsedIn(as: AsyncState): Set[Symbol] = { - class FindUseTraverser extends Traverser { + def fieldsUsedIn(as: AsyncState): ReferencedFields = { + class FindUseTraverser extends AsyncTraverser { var usedFields = Set[Symbol]() - override def traverse(tree: Tree) = tree match { - case Ident(_) if liftedSyms(tree.symbol) => - usedFields += tree.symbol - case _ => - super.traverse(tree) + var capturedFields = Set[Symbol]() + private def capturing[A](body: => A): A = { + val saved = capturing + try { + capturing = true + body + } finally capturing = saved } + private def capturingCheck(tree: Tree) = capturing(tree foreach check) + private var capturing: Boolean = false + private def check(tree: Tree) { + tree match { + case Ident(_) if liftedSyms(tree.symbol) => + if (capturing) + capturedFields += tree.symbol + else + usedFields += tree.symbol + case _ => + } + } + override def traverse(tree: Tree) = { + check(tree) + super.traverse(tree) + } + + override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef) + + override def nestedModule(module: ModuleDef): Unit = capturingCheck(module) + + override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef) + + override def byNameArgument(arg: Tree): Unit = capturingCheck(arg) + + override def function(function: Function): Unit = capturingCheck(function) + + override def patMatFunction(tree: Match): Unit = capturingCheck(tree) } + val findUses = new FindUseTraverser findUses.traverse(Block(as.stats: _*)) - findUses.usedFields + ReferencedFields(findUses.usedFields, findUses.capturedFields) + } + case class ReferencedFields(used: Set[Symbol], captured: Set[Symbol]) { + override def toString = s"used: ${used.mkString(",")}\ncaptured: ${captured.mkString(",")}" } /* Build the control-flow graph. @@ -104,7 +138,7 @@ trait LiveVariables { val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get for (as <- asyncStates) - AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as).mkString(", ")}") + AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as)}") /* Backwards data-flow analysis. Computes live variables information at entry and exit * of each async state. @@ -130,13 +164,16 @@ trait LiveVariables { var currStates = List(finalState) // start at final state var pred = List[AsyncState]() // current predecessor states var hasChanged = true // if something has changed we need to continue iterating + var captured: Set[Symbol] = Set() while (hasChanged) { hasChanged = false for (cs <- currStates) { val LVentryOld = LVentry(cs.state) - val LVentryNew = LVexit(cs.state) ++ fieldsUsedIn(cs) + val referenced = fieldsUsedIn(cs) + captured ++= referenced.captured + val LVentryNew = LVexit(cs.state) ++ referenced.used if (!LVentryNew.sameElements(LVentryOld)) { LVentry = LVentry + (cs.state -> LVentryNew) hasChanged = true @@ -164,6 +201,9 @@ trait LiveVariables { def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = if (avoid(at)) Set() + else if (captured(field.symbol)) { + Set() + } else LVentry get at.state match { case Some(fields) if fields.exists(_ == field.symbol) => Set(at.state) diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 9722610..92c9a4f 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -166,7 +166,7 @@ private[async] trait TransformUtils { def nestedModule(module: ModuleDef) { } - def nestedMethod(module: DefDef) { + def nestedMethod(defdef: DefDef) { } def byNameArgument(arg: Tree) { diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 524e1a2..c8fe2d6 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -66,13 +66,43 @@ object TreeInterrogation extends App { withDebug { val cm = reflect.runtime.currentMirror val tb = mkToolbox("-cp ${toolboxClasspath} -Xprint:typer -uniqid") - import scala.async.Async._ + import scala.async.internal.AsyncTestLV._ val tree = tb.parse( - """ import _root_.scala.async.internal.AsyncId.{async, await} + """ + | import scala.async.internal.AsyncTestLV._ + | import scala.async.internal.AsyncTestLV + | + | case class MCell[T](var v: T) + | val f = async { MCell(1) } + | + | def m1(x: MCell[Int], y: Int): Int = + | async { x.v + y } + | case class Cell[T](v: T) + | | async { - | implicit def view(a: Int): String = "" - | await(0).length + | // state #1 + | val a: MCell[Int] = await(f) // await$13$1 + | // state #2 + | var y = MCell(0) + | + | while (a.v < 10) { + | // state #4 + | a.v = a.v + 1 + | y = MCell(await(a).v + 1) // await$14$1 + | // state #7 + | } + | + | // state #3 + | assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1")) + | + | val b = await(m1(a, y.v)) // await$15$1 + | // state #8 + | assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10)))) + | assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11)))) + | b | } + | + | | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala index be62ed8..7d62f80 100644 --- a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala +++ b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala @@ -19,6 +19,7 @@ case class MCell[T](var v: T) class LiveVariablesSpec { + AsyncTestLV.clear() @Test def `zero out fields of reference type`() { @@ -35,7 +36,7 @@ class LiveVariablesSpec { // a == Cell(1) val b: Cell[Int] = await(m1(a)) // await$2$1 // b == Cell(2) - assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1)))) + assert(AsyncTestLV.log.exists(_ == ("await$1$1" -> Cell(1))), AsyncTestLV.log) val res = await(m2(b)) // await$3$1 assert(AsyncTestLV.log.exists(_ == ("await$2$1" -> Cell(2)))) res @@ -141,12 +142,125 @@ class LiveVariablesSpec { val b = await(m1(a, y.v)) // await$15$1 // state #8 - assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10)))) + assert(AsyncTestLV.log.exists(_ == ("a$1" -> MCell(10))), AsyncTestLV.log) assert(AsyncTestLV.log.exists(_ == ("y$1" -> MCell(11)))) b } - assert(m3() == 21) + assert(m3() == 21, m3()) } + @Test + def `don't zero captured fields captured lambda`() { + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + val f = () => assert(x != null) + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields captured by-name`() { + def func0[A](a: => A): () => A = () => a + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + val f = func0(assert(x != null)) + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields nested class`() { + def func0[A](a: => A): () => A = () => a + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + val f = new Function0[Unit] { + def apply = assert(x != null) + } + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields nested object`() { + def func0[A](a: => A): () => A = () => a + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + object f extends Function0[Unit] { + def apply = assert(x != null) + } + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `don't zero captured fields nested def`() { + val f = async { + val x = "x" + val y = "y" + await(0) + y.reverse + def xx = x + val f = xx _ + await(0) + f + } + AsyncTestLV.assertNotNulledOut("x") + AsyncTestLV.assertNulledOut("y") + f() + } + + @Test + def `capture bug`() { + sealed trait Base + case class B1() extends Base + case class B2() extends Base + val outer = List[(Base, Int)]((B1(), 8)) + + def getMore(b: Base) = 4 + + def baz = async { + outer.head match { + case (a @ B1(), r) => { + val ents = await(getMore(a)) + + { () => + println(a) + assert(a ne null) + } + } + case (b @ B2(), x) => + () => ??? + } + } + baz() + } } -- cgit v1.2.3