aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-11-03 00:06:53 +0100
committerphaller <hallerp@gmail.com>2012-11-03 00:06:53 +0100
commitf2e7ed60de323ec6d274ee8ce08141a26d0ff0ee (patch)
treebabef0dacdfa3aebb411bde5121b31c0a16e93ac
parentf22998d343c01951254fc2020731986fb3219ff0 (diff)
downloadscala-async-f2e7ed60de323ec6d274ee8ce08141a26d0ff0ee.tar.gz
scala-async-f2e7ed60de323ec6d274ee8ce08141a26d0ff0ee.tar.bz2
scala-async-f2e7ed60de323ec6d274ee8ce08141a26d0ff0ee.zip
Name-mangle lifted local vars
-rw-r--r--src/async/library/scala/async/Async.scala2
-rw-r--r--src/async/library/scala/async/AsyncUtils.scala4
-rw-r--r--src/async/library/scala/async/ExprBuilder.scala52
-rw-r--r--test/files/run/if-else2/MinimalScalaTest.scala102
-rw-r--r--test/files/run/if-else2/if-else2.scala50
-rw-r--r--test/files/run/if-else3/MinimalScalaTest.scala102
-rw-r--r--test/files/run/if-else3/if-else3.scala52
7 files changed, 346 insertions, 18 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala
index 1d9b2d9..d7f4101 100644
--- a/src/async/library/scala/async/Async.scala
+++ b/src/async/library/scala/async/Async.scala
@@ -44,7 +44,7 @@ object Async extends AsyncUtils {
try {
body.tree match {
case Block(stats, expr) =>
- val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000)
+ val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map())
vprintln(s"states of current method (${asyncBlockBuilder.asyncStates}):")
asyncBlockBuilder.asyncStates foreach vprintln
diff --git a/src/async/library/scala/async/AsyncUtils.scala b/src/async/library/scala/async/AsyncUtils.scala
index adc8c87..19e9d92 100644
--- a/src/async/library/scala/async/AsyncUtils.scala
+++ b/src/async/library/scala/async/AsyncUtils.scala
@@ -28,5 +28,9 @@ trait AsyncUtils {
val tpe = asyncMod.moduleClass.asType.toType
tpe.member(c.universe.newTermName("awaitCps"))
}
+
+ private var cnt = 0
+ protected[async] def freshString(prefix: String): String =
+ prefix + "$async$" + { cnt += 1; cnt }
}
diff --git a/src/async/library/scala/async/ExprBuilder.scala b/src/async/library/scala/async/ExprBuilder.scala
index 4d068b5..a47cd1a 100644
--- a/src/async/library/scala/async/ExprBuilder.scala
+++ b/src/async/library/scala/async/ExprBuilder.scala
@@ -287,7 +287,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
/*
* Builder for a single state of an async method.
*/
- class AsyncStateBuilder(state: Int) extends Builder[c.Tree, AsyncState] {
+ class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) extends Builder[c.Tree, AsyncState] {
self =>
/* Statements preceding an await call. */
@@ -306,14 +306,25 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
private val varDefs = ListBuffer[(c.universe.TermName, c.universe.Type)]()
+ private val renamer = new Transformer {
+ override def transform(tree: Tree) = tree match {
+ case Ident(_) if nameMap.keySet contains tree.symbol =>
+ Ident(nameMap(tree.symbol))
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
def += (stat: c.Tree): this.type = {
- stats += c.resetAllAttrs(stat.duplicate)
+ stats += c.resetAllAttrs(renamer.transform(stat).duplicate)
this
}
//TODO do not ignore `mods`
- def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree): this.type = {
+ def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = {
varDefs += (name -> tpt.tpe)
+ nameMap ++= extNameMap // update name map
+ this += Assign(Ident(name), c.resetAllAttrs(renamer.transform(rhs).duplicate))
this
}
@@ -344,8 +355,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
* @param awaitResultName the name of the variable that the result of await is assigned to
* @param awaitResultType the type of the result of await
*/
- def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, nextState: Int = state + 1): this.type = {
- awaitable = c.resetAllAttrs(awaitArg.duplicate)
+ def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = {
+ nameMap ++= extNameMap
+ awaitable = c.resetAllAttrs(renamer.transform(awaitArg).duplicate)
resultName = awaitResultName
resultType = awaitResultType.tpe
this.nextState = nextState
@@ -375,10 +387,10 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
}
}
- class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int) {
+ class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int, private var toRename: Map[c.Symbol, c.Name]) {
val asyncStates = ListBuffer[builder.AsyncState]()
- private var stateBuilder = new builder.AsyncStateBuilder(startState) // current state builder
+ private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder
private var currState = startState
private var remainingBudget = budget
@@ -393,20 +405,24 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
for (stat <- stats) stat match {
// the val name = await(..) pattern
case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod =>
- asyncStates += stateBuilder.complete(args(0), name, tpt).result // complete with await
+ val newName = newTermName(Async.freshString(name.toString()))
+ toRename += (stat.symbol -> newName)
+
+ asyncStates += stateBuilder.complete(args(0), newName, tpt, toRename).result // complete with await
if (remainingBudget > 0)
remainingBudget -= 1
else
assert(false, "too many invocations of `await` in current method")
currState += 1
- stateBuilder = new builder.AsyncStateBuilder(currState)
+ stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
case ValDef(mods, name, tpt, rhs) =>
checkForUnsupportedAwait(rhs)
- stateBuilder.addVarDef(mods, name, tpt)
- stateBuilder += // instead of adding `stat` we add a simple assignment
- Assign(Ident(name), c.resetAllAttrs(rhs.duplicate))
+ val newName = newTermName(Async.freshString(name.toString()))
+ toRename += (stat.symbol -> newName)
+ // when adding assignment need to take `toRename` into account
+ stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename)
case If(cond, thenp, elsep) =>
checkForUnsupportedAwait(cond)
@@ -425,9 +441,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
val thenBuilder = thenp match {
case Block(thenStats, thenExpr) =>
- new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget)
+ new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename)
case _ =>
- new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget)
+ new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget, toRename)
}
vprintln("ASYNC IF: thenBuilder: "+thenBuilder)
vprintln("ASYNC IF: states of thenp:")
@@ -436,12 +452,13 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
// insert states of thenBuilder into asyncStates
asyncStates ++= thenBuilder.asyncStates
+ toRename ++= thenBuilder.toRename
val elseBuilder = elsep match {
case Block(elseStats, elseExpr) =>
- new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget)
+ new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename)
case _ =>
- new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget)
+ new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget, toRename)
}
vprintln("ASYNC IF: elseBuilder: "+elseBuilder)
vprintln("ASYNC IF: states of elsep:")
@@ -450,10 +467,11 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
// insert states of elseBuilder into asyncStates
asyncStates ++= elseBuilder.asyncStates
+ toRename ++= elseBuilder.toRename
// create new state builder for state `currState + ifBudget`
currState = currState + ifBudget
- stateBuilder = new builder.AsyncStateBuilder(currState)
+ stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
case _ =>
checkForUnsupportedAwait(stat)
diff --git a/test/files/run/if-else2/MinimalScalaTest.scala b/test/files/run/if-else2/MinimalScalaTest.scala
new file mode 100644
index 0000000..91de1fc
--- /dev/null
+++ b/test/files/run/if-else2/MinimalScalaTest.scala
@@ -0,0 +1,102 @@
+import language.reflectiveCalls
+import language.postfixOps
+import language.implicitConversions
+
+import scala.reflect.{ ClassTag, classTag }
+
+import scala.collection.mutable
+import scala.concurrent.{ Future, Awaitable, CanAwait }
+import java.util.concurrent.{ TimeoutException, CountDownLatch, TimeUnit }
+import scala.concurrent.duration.Duration
+
+
+
+trait Output {
+ val buffer = new StringBuilder
+
+ def bufferPrintln(a: Any): Unit = buffer.synchronized {
+ buffer.append(a.toString + "\n")
+ }
+}
+
+
+trait MinimalScalaTest extends Output {
+
+ val throwables = mutable.ArrayBuffer[Throwable]()
+
+ def check() {
+ if (throwables.nonEmpty) println(buffer.toString)
+ }
+
+ implicit def stringops(s: String) = new {
+
+ def should[U](snippets: =>U): U = {
+ bufferPrintln(s + " should:")
+ snippets
+ }
+
+ def in[U](snippet: =>U): Unit = {
+ try {
+ bufferPrintln("- " + s)
+ snippet
+ bufferPrintln("[OK] Test passed.")
+ } catch {
+ case e: Throwable =>
+ bufferPrintln("[FAILED] " + e)
+ bufferPrintln(e.getStackTrace().mkString("\n"))
+ throwables += e
+ }
+ }
+
+ }
+
+ implicit def objectops(obj: Any) = new {
+
+ def mustBe(other: Any) = assert(obj == other, obj + " is not " + other)
+ def mustEqual(other: Any) = mustBe(other)
+
+ }
+
+ def intercept[T <: Throwable: ClassTag](body: =>Any): T = {
+ try {
+ body
+ throw new Exception("Exception of type %s was not thrown".format(classTag[T]))
+ } catch {
+ case t: Throwable =>
+ if (classTag[T].runtimeClass != t.getClass) throw t
+ else t.asInstanceOf[T]
+ }
+ }
+
+ def checkType[T: ClassTag, S](in: Future[T], refclasstag: ClassTag[S]): Boolean = classTag[T] == refclasstag
+}
+
+
+object TestLatch {
+ val DefaultTimeout = Duration(5, TimeUnit.SECONDS)
+
+ def apply(count: Int = 1) = new TestLatch(count)
+}
+
+
+class TestLatch(count: Int = 1) extends Awaitable[Unit] {
+ private var latch = new CountDownLatch(count)
+
+ def countDown() = latch.countDown()
+ def isOpen: Boolean = latch.getCount == 0
+ def open() = while (!isOpen) countDown()
+ def reset() = latch = new CountDownLatch(count)
+
+ @throws(classOf[TimeoutException])
+ def ready(atMost: Duration)(implicit permit: CanAwait) = {
+ val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS)
+ if (!opened) throw new TimeoutException("Timeout of %s." format (atMost.toString))
+ this
+ }
+
+ @throws(classOf[Exception])
+ def result(atMost: Duration)(implicit permit: CanAwait): Unit = {
+ ready(atMost)
+ }
+
+}
diff --git a/test/files/run/if-else2/if-else2.scala b/test/files/run/if-else2/if-else2.scala
new file mode 100644
index 0000000..262308c
--- /dev/null
+++ b/test/files/run/if-else2/if-else2.scala
@@ -0,0 +1,50 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+import language.{ reflectiveCalls, postfixOps }
+import scala.concurrent.{ Future, ExecutionContext, future, Await }
+import scala.concurrent.duration._
+import scala.async.Async.{ async, await }
+
+object Test extends App {
+
+ IfElse2Spec.check()
+
+}
+
+class TestIfElse2Class {
+ import ExecutionContext.Implicits.global
+
+ def base(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ val x = await(f)
+ z = x + 2
+ } else {
+ val x = await(f)
+ z = x - 2
+ }
+ z
+ }
+}
+
+
+object IfElse2Spec extends MinimalScalaTest {
+
+ "An async method" should {
+ "allow variables of the same name in different blocks" in {
+ val o = new TestIfElse2Class
+ val fut = o.m(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe(14)
+ }
+ }
+
+}
diff --git a/test/files/run/if-else3/MinimalScalaTest.scala b/test/files/run/if-else3/MinimalScalaTest.scala
new file mode 100644
index 0000000..91de1fc
--- /dev/null
+++ b/test/files/run/if-else3/MinimalScalaTest.scala
@@ -0,0 +1,102 @@
+import language.reflectiveCalls
+import language.postfixOps
+import language.implicitConversions
+
+import scala.reflect.{ ClassTag, classTag }
+
+import scala.collection.mutable
+import scala.concurrent.{ Future, Awaitable, CanAwait }
+import java.util.concurrent.{ TimeoutException, CountDownLatch, TimeUnit }
+import scala.concurrent.duration.Duration
+
+
+
+trait Output {
+ val buffer = new StringBuilder
+
+ def bufferPrintln(a: Any): Unit = buffer.synchronized {
+ buffer.append(a.toString + "\n")
+ }
+}
+
+
+trait MinimalScalaTest extends Output {
+
+ val throwables = mutable.ArrayBuffer[Throwable]()
+
+ def check() {
+ if (throwables.nonEmpty) println(buffer.toString)
+ }
+
+ implicit def stringops(s: String) = new {
+
+ def should[U](snippets: =>U): U = {
+ bufferPrintln(s + " should:")
+ snippets
+ }
+
+ def in[U](snippet: =>U): Unit = {
+ try {
+ bufferPrintln("- " + s)
+ snippet
+ bufferPrintln("[OK] Test passed.")
+ } catch {
+ case e: Throwable =>
+ bufferPrintln("[FAILED] " + e)
+ bufferPrintln(e.getStackTrace().mkString("\n"))
+ throwables += e
+ }
+ }
+
+ }
+
+ implicit def objectops(obj: Any) = new {
+
+ def mustBe(other: Any) = assert(obj == other, obj + " is not " + other)
+ def mustEqual(other: Any) = mustBe(other)
+
+ }
+
+ def intercept[T <: Throwable: ClassTag](body: =>Any): T = {
+ try {
+ body
+ throw new Exception("Exception of type %s was not thrown".format(classTag[T]))
+ } catch {
+ case t: Throwable =>
+ if (classTag[T].runtimeClass != t.getClass) throw t
+ else t.asInstanceOf[T]
+ }
+ }
+
+ def checkType[T: ClassTag, S](in: Future[T], refclasstag: ClassTag[S]): Boolean = classTag[T] == refclasstag
+}
+
+
+object TestLatch {
+ val DefaultTimeout = Duration(5, TimeUnit.SECONDS)
+
+ def apply(count: Int = 1) = new TestLatch(count)
+}
+
+
+class TestLatch(count: Int = 1) extends Awaitable[Unit] {
+ private var latch = new CountDownLatch(count)
+
+ def countDown() = latch.countDown()
+ def isOpen: Boolean = latch.getCount == 0
+ def open() = while (!isOpen) countDown()
+ def reset() = latch = new CountDownLatch(count)
+
+ @throws(classOf[TimeoutException])
+ def ready(atMost: Duration)(implicit permit: CanAwait) = {
+ val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS)
+ if (!opened) throw new TimeoutException("Timeout of %s." format (atMost.toString))
+ this
+ }
+
+ @throws(classOf[Exception])
+ def result(atMost: Duration)(implicit permit: CanAwait): Unit = {
+ ready(atMost)
+ }
+
+}
diff --git a/test/files/run/if-else3/if-else3.scala b/test/files/run/if-else3/if-else3.scala
new file mode 100644
index 0000000..ad95cea
--- /dev/null
+++ b/test/files/run/if-else3/if-else3.scala
@@ -0,0 +1,52 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+import language.{ reflectiveCalls, postfixOps }
+import scala.concurrent.{ Future, ExecutionContext, future, Await }
+import scala.concurrent.duration._
+import scala.async.Async.{ async, await }
+
+object Test extends App {
+
+ IfElse3Spec.check()
+
+}
+
+class TestIfElse3Class {
+ import ExecutionContext.Implicits.global
+
+ def base(x: Int): Future[Int] = future {
+ Thread.sleep(1000)
+ x + 2
+ }
+
+ def m(y: Int): Future[Int] = async {
+ val f = base(y)
+ var z = 0
+ if (y > 0) {
+ val x1 = await(f)
+ var w = x1 + 2
+ z = w + 2
+ } else {
+ val x2 = await(f)
+ var w = x2 + 2
+ z = w - 2
+ }
+ z
+ }
+}
+
+
+object IfElse3Spec extends MinimalScalaTest {
+
+ "An async method" should {
+ "allow variables of the same name in different blocks" in {
+ val o = new TestIfElse3Class
+ val fut = o.m(10)
+ val res = Await.result(fut, 2 seconds)
+ res mustBe(16)
+ }
+ }
+
+}