diff options
author | phaller <hallerp@gmail.com> | 2012-11-03 00:06:53 +0100 |
---|---|---|
committer | phaller <hallerp@gmail.com> | 2012-11-03 00:06:53 +0100 |
commit | f2e7ed60de323ec6d274ee8ce08141a26d0ff0ee (patch) | |
tree | babef0dacdfa3aebb411bde5121b31c0a16e93ac | |
parent | f22998d343c01951254fc2020731986fb3219ff0 (diff) | |
download | scala-async-f2e7ed60de323ec6d274ee8ce08141a26d0ff0ee.tar.gz scala-async-f2e7ed60de323ec6d274ee8ce08141a26d0ff0ee.tar.bz2 scala-async-f2e7ed60de323ec6d274ee8ce08141a26d0ff0ee.zip |
Name-mangle lifted local vars
-rw-r--r-- | src/async/library/scala/async/Async.scala | 2 | ||||
-rw-r--r-- | src/async/library/scala/async/AsyncUtils.scala | 4 | ||||
-rw-r--r-- | src/async/library/scala/async/ExprBuilder.scala | 52 | ||||
-rw-r--r-- | test/files/run/if-else2/MinimalScalaTest.scala | 102 | ||||
-rw-r--r-- | test/files/run/if-else2/if-else2.scala | 50 | ||||
-rw-r--r-- | test/files/run/if-else3/MinimalScalaTest.scala | 102 | ||||
-rw-r--r-- | test/files/run/if-else3/if-else3.scala | 52 |
7 files changed, 346 insertions, 18 deletions
diff --git a/src/async/library/scala/async/Async.scala b/src/async/library/scala/async/Async.scala index 1d9b2d9..d7f4101 100644 --- a/src/async/library/scala/async/Async.scala +++ b/src/async/library/scala/async/Async.scala @@ -44,7 +44,7 @@ object Async extends AsyncUtils { try { body.tree match { case Block(stats, expr) => - val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000) + val asyncBlockBuilder = new builder.AsyncBlockBuilder(stats, expr, 0, 1000, 1000, Map()) vprintln(s"states of current method (${asyncBlockBuilder.asyncStates}):") asyncBlockBuilder.asyncStates foreach vprintln diff --git a/src/async/library/scala/async/AsyncUtils.scala b/src/async/library/scala/async/AsyncUtils.scala index adc8c87..19e9d92 100644 --- a/src/async/library/scala/async/AsyncUtils.scala +++ b/src/async/library/scala/async/AsyncUtils.scala @@ -28,5 +28,9 @@ trait AsyncUtils { val tpe = asyncMod.moduleClass.asType.toType tpe.member(c.universe.newTermName("awaitCps")) } + + private var cnt = 0 + protected[async] def freshString(prefix: String): String = + prefix + "$async$" + { cnt += 1; cnt } } diff --git a/src/async/library/scala/async/ExprBuilder.scala b/src/async/library/scala/async/ExprBuilder.scala index 4d068b5..a47cd1a 100644 --- a/src/async/library/scala/async/ExprBuilder.scala +++ b/src/async/library/scala/async/ExprBuilder.scala @@ -287,7 +287,7 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { /* * Builder for a single state of an async method. */ - class AsyncStateBuilder(state: Int) extends Builder[c.Tree, AsyncState] { + class AsyncStateBuilder(state: Int, private var nameMap: Map[c.Symbol, c.Name]) extends Builder[c.Tree, AsyncState] { self => /* Statements preceding an await call. */ @@ -306,14 +306,25 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { private val varDefs = ListBuffer[(c.universe.TermName, c.universe.Type)]() + private val renamer = new Transformer { + override def transform(tree: Tree) = tree match { + case Ident(_) if nameMap.keySet contains tree.symbol => + Ident(nameMap(tree.symbol)) + case _ => + super.transform(tree) + } + } + def += (stat: c.Tree): this.type = { - stats += c.resetAllAttrs(stat.duplicate) + stats += c.resetAllAttrs(renamer.transform(stat).duplicate) this } //TODO do not ignore `mods` - def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree): this.type = { + def addVarDef(mods: Any, name: c.universe.TermName, tpt: c.Tree, rhs: c.Tree, extNameMap: Map[c.Symbol, c.Name]): this.type = { varDefs += (name -> tpt.tpe) + nameMap ++= extNameMap // update name map + this += Assign(Ident(name), c.resetAllAttrs(renamer.transform(rhs).duplicate)) this } @@ -344,8 +355,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { * @param awaitResultName the name of the variable that the result of await is assigned to * @param awaitResultType the type of the result of await */ - def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, nextState: Int = state + 1): this.type = { - awaitable = c.resetAllAttrs(awaitArg.duplicate) + def complete(awaitArg: c.Tree, awaitResultName: c.universe.TermName, awaitResultType: Tree, extNameMap: Map[c.Symbol, c.Name], nextState: Int = state + 1): this.type = { + nameMap ++= extNameMap + awaitable = c.resetAllAttrs(renamer.transform(awaitArg).duplicate) resultName = awaitResultName resultType = awaitResultType.tpe this.nextState = nextState @@ -375,10 +387,10 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { } } - class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int) { + class AsyncBlockBuilder(stats: List[c.Tree], expr: c.Tree, startState: Int, endState: Int, budget: Int, private var toRename: Map[c.Symbol, c.Name]) { val asyncStates = ListBuffer[builder.AsyncState]() - private var stateBuilder = new builder.AsyncStateBuilder(startState) // current state builder + private var stateBuilder = new builder.AsyncStateBuilder(startState, toRename) // current state builder private var currState = startState private var remainingBudget = budget @@ -393,20 +405,24 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { for (stat <- stats) stat match { // the val name = await(..) pattern case ValDef(mods, name, tpt, Apply(fun, args)) if fun.symbol == awaitMethod => - asyncStates += stateBuilder.complete(args(0), name, tpt).result // complete with await + val newName = newTermName(Async.freshString(name.toString())) + toRename += (stat.symbol -> newName) + + asyncStates += stateBuilder.complete(args(0), newName, tpt, toRename).result // complete with await if (remainingBudget > 0) remainingBudget -= 1 else assert(false, "too many invocations of `await` in current method") currState += 1 - stateBuilder = new builder.AsyncStateBuilder(currState) + stateBuilder = new builder.AsyncStateBuilder(currState, toRename) case ValDef(mods, name, tpt, rhs) => checkForUnsupportedAwait(rhs) - stateBuilder.addVarDef(mods, name, tpt) - stateBuilder += // instead of adding `stat` we add a simple assignment - Assign(Ident(name), c.resetAllAttrs(rhs.duplicate)) + val newName = newTermName(Async.freshString(name.toString())) + toRename += (stat.symbol -> newName) + // when adding assignment need to take `toRename` into account + stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) case If(cond, thenp, elsep) => checkForUnsupportedAwait(cond) @@ -425,9 +441,9 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { val thenBuilder = thenp match { case Block(thenStats, thenExpr) => - new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget) + new AsyncBlockBuilder(thenStats, thenExpr, currState + 1, currState + ifBudget, thenBudget, toRename) case _ => - new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget) + new AsyncBlockBuilder(List(thenp), Literal(Constant(())), currState + 1, currState + ifBudget, thenBudget, toRename) } vprintln("ASYNC IF: thenBuilder: "+thenBuilder) vprintln("ASYNC IF: states of thenp:") @@ -436,12 +452,13 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { // insert states of thenBuilder into asyncStates asyncStates ++= thenBuilder.asyncStates + toRename ++= thenBuilder.toRename val elseBuilder = elsep match { case Block(elseStats, elseExpr) => - new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget) + new AsyncBlockBuilder(elseStats, elseExpr, currState + thenBudget, currState + ifBudget, elseBudget, toRename) case _ => - new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget) + new AsyncBlockBuilder(List(elsep), Literal(Constant(())), currState + thenBudget, currState + ifBudget, elseBudget, toRename) } vprintln("ASYNC IF: elseBuilder: "+elseBuilder) vprintln("ASYNC IF: states of elsep:") @@ -450,10 +467,11 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils { // insert states of elseBuilder into asyncStates asyncStates ++= elseBuilder.asyncStates + toRename ++= elseBuilder.toRename // create new state builder for state `currState + ifBudget` currState = currState + ifBudget - stateBuilder = new builder.AsyncStateBuilder(currState) + stateBuilder = new builder.AsyncStateBuilder(currState, toRename) case _ => checkForUnsupportedAwait(stat) diff --git a/test/files/run/if-else2/MinimalScalaTest.scala b/test/files/run/if-else2/MinimalScalaTest.scala new file mode 100644 index 0000000..91de1fc --- /dev/null +++ b/test/files/run/if-else2/MinimalScalaTest.scala @@ -0,0 +1,102 @@ +import language.reflectiveCalls +import language.postfixOps +import language.implicitConversions + +import scala.reflect.{ ClassTag, classTag } + +import scala.collection.mutable +import scala.concurrent.{ Future, Awaitable, CanAwait } +import java.util.concurrent.{ TimeoutException, CountDownLatch, TimeUnit } +import scala.concurrent.duration.Duration + + + +trait Output { + val buffer = new StringBuilder + + def bufferPrintln(a: Any): Unit = buffer.synchronized { + buffer.append(a.toString + "\n") + } +} + + +trait MinimalScalaTest extends Output { + + val throwables = mutable.ArrayBuffer[Throwable]() + + def check() { + if (throwables.nonEmpty) println(buffer.toString) + } + + implicit def stringops(s: String) = new { + + def should[U](snippets: =>U): U = { + bufferPrintln(s + " should:") + snippets + } + + def in[U](snippet: =>U): Unit = { + try { + bufferPrintln("- " + s) + snippet + bufferPrintln("[OK] Test passed.") + } catch { + case e: Throwable => + bufferPrintln("[FAILED] " + e) + bufferPrintln(e.getStackTrace().mkString("\n")) + throwables += e + } + } + + } + + implicit def objectops(obj: Any) = new { + + def mustBe(other: Any) = assert(obj == other, obj + " is not " + other) + def mustEqual(other: Any) = mustBe(other) + + } + + def intercept[T <: Throwable: ClassTag](body: =>Any): T = { + try { + body + throw new Exception("Exception of type %s was not thrown".format(classTag[T])) + } catch { + case t: Throwable => + if (classTag[T].runtimeClass != t.getClass) throw t + else t.asInstanceOf[T] + } + } + + def checkType[T: ClassTag, S](in: Future[T], refclasstag: ClassTag[S]): Boolean = classTag[T] == refclasstag +} + + +object TestLatch { + val DefaultTimeout = Duration(5, TimeUnit.SECONDS) + + def apply(count: Int = 1) = new TestLatch(count) +} + + +class TestLatch(count: Int = 1) extends Awaitable[Unit] { + private var latch = new CountDownLatch(count) + + def countDown() = latch.countDown() + def isOpen: Boolean = latch.getCount == 0 + def open() = while (!isOpen) countDown() + def reset() = latch = new CountDownLatch(count) + + @throws(classOf[TimeoutException]) + def ready(atMost: Duration)(implicit permit: CanAwait) = { + val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS) + if (!opened) throw new TimeoutException("Timeout of %s." format (atMost.toString)) + this + } + + @throws(classOf[Exception]) + def result(atMost: Duration)(implicit permit: CanAwait): Unit = { + ready(atMost) + } + +} diff --git a/test/files/run/if-else2/if-else2.scala b/test/files/run/if-else2/if-else2.scala new file mode 100644 index 0000000..262308c --- /dev/null +++ b/test/files/run/if-else2/if-else2.scala @@ -0,0 +1,50 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +import language.{ reflectiveCalls, postfixOps } +import scala.concurrent.{ Future, ExecutionContext, future, Await } +import scala.concurrent.duration._ +import scala.async.Async.{ async, await } + +object Test extends App { + + IfElse2Spec.check() + +} + +class TestIfElse2Class { + import ExecutionContext.Implicits.global + + def base(x: Int): Future[Int] = future { + Thread.sleep(1000) + x + 2 + } + + def m(y: Int): Future[Int] = async { + val f = base(y) + var z = 0 + if (y > 0) { + val x = await(f) + z = x + 2 + } else { + val x = await(f) + z = x - 2 + } + z + } +} + + +object IfElse2Spec extends MinimalScalaTest { + + "An async method" should { + "allow variables of the same name in different blocks" in { + val o = new TestIfElse2Class + val fut = o.m(10) + val res = Await.result(fut, 2 seconds) + res mustBe(14) + } + } + +} diff --git a/test/files/run/if-else3/MinimalScalaTest.scala b/test/files/run/if-else3/MinimalScalaTest.scala new file mode 100644 index 0000000..91de1fc --- /dev/null +++ b/test/files/run/if-else3/MinimalScalaTest.scala @@ -0,0 +1,102 @@ +import language.reflectiveCalls +import language.postfixOps +import language.implicitConversions + +import scala.reflect.{ ClassTag, classTag } + +import scala.collection.mutable +import scala.concurrent.{ Future, Awaitable, CanAwait } +import java.util.concurrent.{ TimeoutException, CountDownLatch, TimeUnit } +import scala.concurrent.duration.Duration + + + +trait Output { + val buffer = new StringBuilder + + def bufferPrintln(a: Any): Unit = buffer.synchronized { + buffer.append(a.toString + "\n") + } +} + + +trait MinimalScalaTest extends Output { + + val throwables = mutable.ArrayBuffer[Throwable]() + + def check() { + if (throwables.nonEmpty) println(buffer.toString) + } + + implicit def stringops(s: String) = new { + + def should[U](snippets: =>U): U = { + bufferPrintln(s + " should:") + snippets + } + + def in[U](snippet: =>U): Unit = { + try { + bufferPrintln("- " + s) + snippet + bufferPrintln("[OK] Test passed.") + } catch { + case e: Throwable => + bufferPrintln("[FAILED] " + e) + bufferPrintln(e.getStackTrace().mkString("\n")) + throwables += e + } + } + + } + + implicit def objectops(obj: Any) = new { + + def mustBe(other: Any) = assert(obj == other, obj + " is not " + other) + def mustEqual(other: Any) = mustBe(other) + + } + + def intercept[T <: Throwable: ClassTag](body: =>Any): T = { + try { + body + throw new Exception("Exception of type %s was not thrown".format(classTag[T])) + } catch { + case t: Throwable => + if (classTag[T].runtimeClass != t.getClass) throw t + else t.asInstanceOf[T] + } + } + + def checkType[T: ClassTag, S](in: Future[T], refclasstag: ClassTag[S]): Boolean = classTag[T] == refclasstag +} + + +object TestLatch { + val DefaultTimeout = Duration(5, TimeUnit.SECONDS) + + def apply(count: Int = 1) = new TestLatch(count) +} + + +class TestLatch(count: Int = 1) extends Awaitable[Unit] { + private var latch = new CountDownLatch(count) + + def countDown() = latch.countDown() + def isOpen: Boolean = latch.getCount == 0 + def open() = while (!isOpen) countDown() + def reset() = latch = new CountDownLatch(count) + + @throws(classOf[TimeoutException]) + def ready(atMost: Duration)(implicit permit: CanAwait) = { + val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS) + if (!opened) throw new TimeoutException("Timeout of %s." format (atMost.toString)) + this + } + + @throws(classOf[Exception]) + def result(atMost: Duration)(implicit permit: CanAwait): Unit = { + ready(atMost) + } + +} diff --git a/test/files/run/if-else3/if-else3.scala b/test/files/run/if-else3/if-else3.scala new file mode 100644 index 0000000..ad95cea --- /dev/null +++ b/test/files/run/if-else3/if-else3.scala @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com> + */ + +import language.{ reflectiveCalls, postfixOps } +import scala.concurrent.{ Future, ExecutionContext, future, Await } +import scala.concurrent.duration._ +import scala.async.Async.{ async, await } + +object Test extends App { + + IfElse3Spec.check() + +} + +class TestIfElse3Class { + import ExecutionContext.Implicits.global + + def base(x: Int): Future[Int] = future { + Thread.sleep(1000) + x + 2 + } + + def m(y: Int): Future[Int] = async { + val f = base(y) + var z = 0 + if (y > 0) { + val x1 = await(f) + var w = x1 + 2 + z = w + 2 + } else { + val x2 = await(f) + var w = x2 + 2 + z = w - 2 + } + z + } +} + + +object IfElse3Spec extends MinimalScalaTest { + + "An async method" should { + "allow variables of the same name in different blocks" in { + val o = new TestIfElse3Class + val fut = o.m(10) + val res = Await.result(fut, 2 seconds) + res mustBe(16) + } + } + +} |