diff options
author | Jason Zaugg <jzaugg@gmail.com> | 2013-11-11 23:17:11 +0100 |
---|---|---|
committer | Jason Zaugg <jzaugg@gmail.com> | 2013-11-12 08:37:32 +0100 |
commit | 302a07c64e4cd3db91a654dcbc893ade0837ba8c (patch) | |
tree | 8eb296957c0d2bbf7e30600aab7a92bb25e26459 /src/main/scala/scala | |
parent | 490238d54d3476d681bfb0b7a04ac090e4e52d9f (diff) | |
download | scala-async-302a07c64e4cd3db91a654dcbc893ade0837ba8c.tar.gz scala-async-302a07c64e4cd3db91a654dcbc893ade0837ba8c.tar.bz2 scala-async-302a07c64e4cd3db91a654dcbc893ade0837ba8c.zip |
Don't aggressively null out captured vars
Once they escape, we leave the references in the state
machines fields untouched.
Diffstat (limited to 'src/main/scala/scala')
-rw-r--r-- | src/main/scala/scala/async/internal/AsyncId.scala | 3 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/LiveVariables.scala | 60 | ||||
-rw-r--r-- | src/main/scala/scala/async/internal/TransformUtils.scala | 2 |
3 files changed, 54 insertions, 11 deletions
diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala index c123675..a794f93 100644 --- a/src/main/scala/scala/async/internal/AsyncId.scala +++ b/src/main/scala/scala/async/internal/AsyncId.scala @@ -27,6 +27,9 @@ object AsyncTestLV extends AsyncBase { def asyncIdImpl[T: c.WeakTypeTag](c: Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) var log: List[(String, Any)] = List() + def assertNulledOut(a: Any): Unit = assert(log.exists(_._2 == a), AsyncTestLV.log) + def assertNotNulledOut(a: Any): Unit = assert(!log.exists(_._2 == a), AsyncTestLV.log) + def clear() = log = Nil def apply(name: String, v: Any): Unit = log ::= (name -> v) diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala index 4d8c479..8753b3d 100644 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ b/src/main/scala/scala/async/internal/LiveVariables.scala @@ -68,19 +68,53 @@ trait LiveVariables { * @param as a state of an `async` expression * @return a set of lifted fields that are used within state `as` */ - def fieldsUsedIn(as: AsyncState): Set[Symbol] = { - class FindUseTraverser extends Traverser { + def fieldsUsedIn(as: AsyncState): ReferencedFields = { + class FindUseTraverser extends AsyncTraverser { var usedFields = Set[Symbol]() - override def traverse(tree: Tree) = tree match { - case Ident(_) if liftedSyms(tree.symbol) => - usedFields += tree.symbol - case _ => - super.traverse(tree) + var capturedFields = Set[Symbol]() + private def capturing[A](body: => A): A = { + val saved = capturing + try { + capturing = true + body + } finally capturing = saved } + private def capturingCheck(tree: Tree) = capturing(tree foreach check) + private var capturing: Boolean = false + private def check(tree: Tree) { + tree match { + case Ident(_) if liftedSyms(tree.symbol) => + if (capturing) + capturedFields += tree.symbol + else + usedFields += tree.symbol + case _ => + } + } + override def traverse(tree: Tree) = { + check(tree) + super.traverse(tree) + } + + override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef) + + override def nestedModule(module: ModuleDef): Unit = capturingCheck(module) + + override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef) + + override def byNameArgument(arg: Tree): Unit = capturingCheck(arg) + + override def function(function: Function): Unit = capturingCheck(function) + + override def patMatFunction(tree: Match): Unit = capturingCheck(tree) } + val findUses = new FindUseTraverser findUses.traverse(Block(as.stats: _*)) - findUses.usedFields + ReferencedFields(findUses.usedFields, findUses.capturedFields) + } + case class ReferencedFields(used: Set[Symbol], captured: Set[Symbol]) { + override def toString = s"used: ${used.mkString(",")}\ncaptured: ${captured.mkString(",")}" } /* Build the control-flow graph. @@ -104,7 +138,7 @@ trait LiveVariables { val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get for (as <- asyncStates) - AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as).mkString(", ")}") + AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as)}") /* Backwards data-flow analysis. Computes live variables information at entry and exit * of each async state. @@ -130,13 +164,16 @@ trait LiveVariables { var currStates = List(finalState) // start at final state var pred = List[AsyncState]() // current predecessor states var hasChanged = true // if something has changed we need to continue iterating + var captured: Set[Symbol] = Set() while (hasChanged) { hasChanged = false for (cs <- currStates) { val LVentryOld = LVentry(cs.state) - val LVentryNew = LVexit(cs.state) ++ fieldsUsedIn(cs) + val referenced = fieldsUsedIn(cs) + captured ++= referenced.captured + val LVentryNew = LVexit(cs.state) ++ referenced.used if (!LVentryNew.sameElements(LVentryOld)) { LVentry = LVentry + (cs.state -> LVentryNew) hasChanged = true @@ -164,6 +201,9 @@ trait LiveVariables { def lastUsagesOf(field: Tree, at: AsyncState, avoid: Set[AsyncState]): Set[Int] = if (avoid(at)) Set() + else if (captured(field.symbol)) { + Set() + } else LVentry get at.state match { case Some(fields) if fields.exists(_ == field.symbol) => Set(at.state) diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 9722610..92c9a4f 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -166,7 +166,7 @@ private[async] trait TransformUtils { def nestedModule(module: ModuleDef) { } - def nestedMethod(module: DefDef) { + def nestedMethod(defdef: DefDef) { } def byNameArgument(arg: Tree) { |