aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-25 13:12:03 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-26 16:15:46 +0100
commite2b840b96a16f7d41dc43c3cf6d905e0db568629 (patch)
tree9d32f3a69a02aedbe2d16a9a1226beb60b3fb658
parent26038aebf1555b582dba35e8bfc3698f126705c5 (diff)
downloadscala-async-e2b840b96a16f7d41dc43c3cf6d905e0db568629.tar.gz
scala-async-e2b840b96a16f7d41dc43c3cf6d905e0db568629.tar.bz2
scala-async-e2b840b96a16f7d41dc43c3cf6d905e0db568629.zip
Lift local defs and functions.
Any vals referred to in the body of these must also be lifted. Fixes #36
-rw-r--r--src/main/scala/scala/async/Async.scala8
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala39
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala1
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala4
-rw-r--r--src/test/scala/scala/async/run/nesteddef/NestedDef.scala40
5 files changed, 81 insertions, 11 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index ef506a5..f868f79 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -11,6 +11,7 @@ import scala.reflect.macros.Context
* @author Philipp Haller
*/
object Async extends AsyncBase {
+
import scala.concurrent.Future
lazy val futureSystem = ScalaConcurrentFutureSystem
@@ -87,9 +88,9 @@ abstract class AsyncBase {
// states of our generated state machine, e.g. a value assigned before
// an `await` and read afterwards.
val renameMap: Map[Symbol, TermName] = {
- anaylzer.valDefsUsedInSubsequentStates(anfTree).map {
+ anaylzer.defTreesUsedInSubsequentStates(anfTree).map {
vd =>
- (vd.symbol, name.fresh(vd.name))
+ (vd.symbol, name.fresh(vd.name.toTermName))
}.toMap
}
@@ -97,9 +98,12 @@ abstract class AsyncBase {
import asyncBlock.asyncStates
logDiagnostics(c)(anfTree, asyncStates.map(_.toString))
+ // Important to retain the original declaration order here!
val localVarTrees = anfTree.collect {
case vd@ValDef(_, _, tpt, _) if renameMap contains vd.symbol =>
utils.mkVarDefTree(tpt.tpe, renameMap(vd.symbol))
+ case dd@DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ DefDef(mods, renameMap(dd.symbol), tparams, vparamss, tpt, c.resetAllAttrs(utils.substituteNames(rhs, renameMap)))
}
val onCompleteHandler = asyncBlock.onCompleteHandler
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index ecd5054..6e281e4 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -31,10 +31,11 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
*
* Must be called on the ANF transformed tree.
*/
- def valDefsUsedInSubsequentStates(tree: Tree): List[ValDef] = {
+ def defTreesUsedInSubsequentStates(tree: Tree): List[DefTree] = {
val analyzer = new AsyncDefinitionUseAnalyzer
analyzer.traverse(tree)
- analyzer.valDefsToLift.toList
+ val liftable: List[DefTree] = (analyzer.valDefsToLift ++ analyzer.nestedMethodsToLift).toList.distinct
+ liftable
}
private class UnsupportedAwaitAnalyzer extends AsyncTraverser {
@@ -68,12 +69,12 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
override def traverse(tree: Tree) {
def containsAwait = tree exists isAwait
tree match {
- case Try(_, _, _) if containsAwait =>
+ case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
- case Return(_) =>
+ case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
- case _ =>
+ case _ =>
super.traverse(tree)
}
}
@@ -100,7 +101,33 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
private var valDefChunkId = Map[Symbol, (ValDef, Int)]()
- val valDefsToLift = mutable.Set[ValDef]()
+ val valDefsToLift : mutable.Set[ValDef] = collection.mutable.Set()
+ val nestedMethodsToLift: mutable.Set[DefDef] = collection.mutable.Set()
+
+ override def nestedMethod(defDef: DefDef) {
+ nestedMethodsToLift += defDef
+ defDef.rhs foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested methods.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
+
+ override def function(function: Function) {
+ function foreach {
+ case rt: RefTree =>
+ valDefChunkId.get(rt.symbol) match {
+ case Some((vd, defChunkId)) =>
+ valDefsToLift += vd // lift all vals referred to by nested functions.
+ case _ =>
+ }
+ case _ =>
+ }
+ }
override def traverse(tree: Tree) = {
tree match {
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index cc2cde5..f8065f2 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -102,6 +102,7 @@ private[async] final case class ExprBuilder[C <: Context, FS <: FutureSystem](c:
assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat")
def addStat() = stats += renameReset(stat)
stat match {
+ case _: DefDef => // these have been lifted.
case Apply(fun, Nil) =>
labelDefStates get fun.symbol match {
case Some(nextState) => nextJumpState = Some(nextState)
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index e3012c7..ca4a309 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -56,9 +56,7 @@ object TreeInterrogation extends App {
val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all")
val tree = tb.parse(
""" import _root_.scala.async.AsyncId._
- | async {
- | await(0) match { case _ => 0 }
- | }
+ | async { val a = 0; val x = await(a) - 1; def foo(z: Any) = (a.toDouble, x.toDouble, z); foo(await(2)) }
| """.stripMargin)
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
new file mode 100644
index 0000000..2baef0d
--- /dev/null
+++ b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala
@@ -0,0 +1,40 @@
+package scala.async
+package run
+package nesteddef
+
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Test
+
+@RunWith(classOf[JUnit4])
+class NestedDef {
+
+ @Test
+ def nestedDef() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ val local = 43
+ def bar(d: Double) = -d + a + local
+ def foo(z: Any) = (a.toDouble, bar(x).toDouble, z)
+ foo(await(2))
+ }
+ result mustBe (0d, 44d, 2)
+ }
+
+
+ @Test
+ def nestedFunction() {
+ import AsyncId._
+ val result = async {
+ val a = 0
+ val x = await(a) - 1
+ val local = 43
+ val bar = (d: Double) => -d + a + local
+ val foo = (z: Any) => (a.toDouble, bar(x).toDouble, z)
+ foo(await(2))
+ }
+ result mustBe (0d, 44d, 2)
+ }
+}