diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 35 | ||||
-rw-r--r-- | src/main/scala/scala/async/Async.scala | 7 | ||||
-rw-r--r-- | src/test/scala/scala/async/TreeInterrogation.scala | 2 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/anf/AnfTransformSpec.scala | 33 |
4 files changed, 67 insertions, 10 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 6bf6e9a..a69fd4e 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -3,11 +3,21 @@ package scala.async import scala.reflect.macros.Context +object AnfTranform { + def apply[C <: Context](c: C): AnfTransform[c.type] = new AnfTransform[c.type](c) +} + class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { import c.universe._ - def uniqueNames(tree: Tree): Tree = { + def apply(tree: Tree): List[Tree] = { + val unique = uniqueNames(tree) + // Must prepend the () for issue #31. + anf.transformToList(Block(List(c.literalUnit.tree), unique)) + } + + private def uniqueNames(tree: Tree): Tree = { new UniqueNames(tree).transform(tree) } @@ -65,8 +75,25 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { } } + final class Indent { + private var indent = -1 + def indentString = " " * indent + def apply[T](prefix: String, args: Any)(t: => T): T = { + indent += 1 + try { + //println(s"${indentString}$prefix($args)") + val result = t + //println(s"${indentString}= $result") + result + } finally { + indent -= 1 + } + } + } + private val indent = new Indent + object inline { - def transformToList(tree: Tree): List[Tree] = { + def transformToList(tree: Tree): List[Tree] = indent("inline", tree) { val stats :+ expr = anf.transformToList(tree) expr match { case Apply(fun, args) if isAwait(fun) => @@ -133,9 +160,9 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { vd } } - object anf { - def transformToList(tree: Tree): List[Tree] = { + + private[AnfTransform] def transformToList(tree: Tree): List[Tree] = indent("anf", tree) { def containsAwait = tree exists isAwait tree match { case Select(qual, sel) if containsAwait => diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 0b77c78..ed79d0b 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -78,12 +78,9 @@ abstract class AsyncBase { // - no await calls in qualifiers or arguments, // - if/match only used in statement position. val anfTree: Block = { - val transform = new AnfTransform[c.type](c) - val unique = transform.uniqueNames(body.tree) - val stats1 :+ expr1 = transform.anf.transformToList(unique) - + val anf = AnfTranform(c) + val stats1 :+ expr1 = anf(body.tree) val block = Block(stats1, expr1) - c.typeCheck(block).asInstanceOf[Block] } diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index 9e68005..16cbcf7 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -11,7 +11,7 @@ class TreeInterrogation { def `a minimal set of vals are lifted to vars`() { val cm = reflect.runtime.currentMirror val tb = mkToolbox("-cp target/scala-2.10/classes") - val tree = mkToolbox().parse( + val tree = tb.parse( """| import _root_.scala.async.AsyncId._ | async { | val x = await(1) diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index 872e44d..8bdb80f 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -172,4 +172,37 @@ class AnfTransformSpec { } result mustBe (103) } + + @Test + def nestedAwaitAsBareExpression() { + import ExecutionContext.Implicits.global + import _root_.scala.async.AsyncId.{async, await} + val result = async { + await(await("").isEmpty) + } + result mustBe (true) + } + + @Test + def nestedAwaitInBlock() { + import ExecutionContext.Implicits.global + import _root_.scala.async.AsyncId.{async, await} + val result = async { + () + await(await("").isEmpty) + } + result mustBe (true) + } + + @Test + def nestedAwaitInIf() { + import ExecutionContext.Implicits.global + import _root_.scala.async.AsyncId.{async, await} + val result = async { + if ("".isEmpty) + await(await("").isEmpty) + else 0 + } + result mustBe (true) + } } |