From e2b840b96a16f7d41dc43c3cf6d905e0db568629 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Sun, 25 Nov 2012 13:12:03 +0100 Subject: Lift local defs and functions. Any vals referred to in the body of these must also be lifted. Fixes #36 --- src/main/scala/scala/async/Async.scala | 8 +++-- src/main/scala/scala/async/AsyncAnalysis.scala | 39 +++++++++++++++++---- src/main/scala/scala/async/ExprBuilder.scala | 1 + src/test/scala/scala/async/TreeInterrogation.scala | 4 +-- .../scala/async/run/nesteddef/NestedDef.scala | 40 ++++++++++++++++++++++ 5 files changed, 81 insertions(+), 11 deletions(-) create mode 100644 src/test/scala/scala/async/run/nesteddef/NestedDef.scala diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index ef506a5..f868f79 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -11,6 +11,7 @@ import scala.reflect.macros.Context * @author Philipp Haller */ object Async extends AsyncBase { + import scala.concurrent.Future lazy val futureSystem = ScalaConcurrentFutureSystem @@ -87,9 +88,9 @@ abstract class AsyncBase { // states of our generated state machine, e.g. a value assigned before // an `await` and read afterwards. val renameMap: Map[Symbol, TermName] = { - anaylzer.valDefsUsedInSubsequentStates(anfTree).map { + anaylzer.defTreesUsedInSubsequentStates(anfTree).map { vd => - (vd.symbol, name.fresh(vd.name)) + (vd.symbol, name.fresh(vd.name.toTermName)) }.toMap } @@ -97,9 +98,12 @@ abstract class AsyncBase { import asyncBlock.asyncStates logDiagnostics(c)(anfTree, asyncStates.map(_.toString)) + // Important to retain the original declaration order here! val localVarTrees = anfTree.collect { case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol => utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol)) + case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) => + DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap))) } val onCompleteHandler = asyncBlock.onCompleteHandler diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index ecd5054..6e281e4 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -31,10 +31,11 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { * * Must be called on the ANF transformed tree. */ - def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = { + def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = { val analyzer = new AsyncDefinitionUseAnalyzer analyzer.traverse(tree) - analyzer.valDefsToLift.toList + val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct + liftable } private class UnsupportedAwaitAnalyzer extends AsyncTraverser { @@ -68,12 +69,12 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { override def traverse(tree: Tree) { def containsAwait = tree exists isAwait tree match { - case Try(_, _, _) if containsAwait => + case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) - case Return(_) => + case Return(_) => c.abort(tree.pos, "return is illegal within a async block") - case _ => + case _ => super.traverse(tree) } } @@ -100,7 +101,33 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { private var valDefChunkId = Map[Symbol, (ValDef, Int)]() - val valDefsToLift = mutable.Set[ValDef]() + val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set() + val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set() + + override def nestedMethod(defDef: DefDef) { + nestedMethodsToLift += defDef + defDef.rhs foreach { + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) => + valDefsToLift += vd // lift all vals referred to by nested methods. + case _ => + } + case _ => + } + } + + override def function(function: Function) { + function foreach { + case rt: RefTree => + valDefChunkId.get(rt.symbol) match { + case Some((vd, defChunkId)) => + valDefsToLift += vd // lift all vals referred to by nested functions. + case _ => + } + case _ => + } + } override def traverse(tree: Tree) = { tree match { diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index cc2cde5..f8065f2 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -102,6 +102,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c: assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") def addStat() = stats += renameReset(stat) stat match { + case _: DefDef => // these have been lifted. case Apply(fun, Nil) => labelDefStates get fun.symbol match { case Some(nextState) => nextJumpState = Some(nextState) diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index e3012c7..ca4a309 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -56,9 +56,7 @@ object TreeInterrogation extends App { val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") val tree = tb.parse( """ import _root_.scala.async.AsyncId._ - | async { - | await(0) match { case _ => 0 } - | } + | async { val a = 0; val x = await(a) - 1; def foo(z: Any) = (a.toDouble, x.toDouble, z); foo(await(2)) } | """.stripMargin) println(tree) val tree1 = tb.typeCheck(tree.duplicate) diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala new file mode 100644 index 0000000..2baef0d --- /dev/null +++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala @@ -0,0 +1,40 @@ +package scala.async +package run +package nesteddef + +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.junit.Test + +@RunWith(classOf[JUnit4]) +class NestedDef { + + @Test + def nestedDef() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + val local = 43 + def bar(d: Double) = -d + a + local + def foo(z: Any) = (a.toDouble, bar(x).toDouble, z) + foo(await(2)) + } + result mustBe (0d, 44d, 2) + } + + + @Test + def nestedFunction() { + import AsyncId._ + val result = async { + val a = 0 + val x = await(a) - 1 + val local = 43 + val bar = (d: Double) => -d + a + local + val foo = (z: Any) => (a.toDouble, bar(x).toDouble, z) + foo(await(2)) + } + result mustBe (0d, 44d, 2) + } +} -- cgit v1.2.3