diff options
Diffstat (limited to 'src/main/scala/scala/async/internal/TransformUtils.scala')
-rw-r--r-- | src/main/scala/scala/async/internal/TransformUtils.scala | 22 |
1 files changed, 16 insertions, 6 deletions
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala index 90419d3..e15ef1b 100644 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ b/src/main/scala/scala/async/internal/TransformUtils.scala @@ -38,6 +38,9 @@ private[async] trait TransformUtils { def fresh(name: String): String = c.freshName(name) } + def isAsync(fun: Tree) = + fun.symbol == defn.Async_async + def isAwait(fun: Tree) = fun.symbol == defn.Async_await @@ -164,6 +167,7 @@ private[async] trait TransformUtils { val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") val ThrowableClass = rootMirror.staticClass("java.lang.Throwable") + val Async_async = asyncBase.asyncMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol) val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol) val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException") } @@ -281,6 +285,8 @@ private[async] trait TransformUtils { override def traverse(tree: Tree) { tree match { + case _ if isAsync(tree) => + // Under -Ymacro-expand:discard, used in the IDE, nested async blocks will be visible to the outer blocks case cd: ClassDef => nestedClass(cd) case md: ModuleDef => nestedModule(md) case dd: DefDef => nestedMethod(dd) @@ -398,7 +404,7 @@ private[async] trait TransformUtils { final def containsAwaitCached(t: Tree): Tree => Boolean = { def treeCannotContainAwait(t: Tree) = t match { case _: Ident | _: TypeTree | _: Literal => true - case _ => false + case _ => isAsync(t) } def shouldAttach(t: Tree) = !treeCannotContainAwait(t) val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] @@ -417,11 +423,15 @@ private[async] trait TransformUtils { override def traverse(tree: Tree): Unit = { stack ::= tree try { - if (isAwait(tree)) - stack.foreach(attachContainsAwait) - else - attachNoAwait(tree) - super.traverse(tree) + if (isAsync(tree)) { + ; + } else { + if (isAwait(tree)) + stack.foreach(attachContainsAwait) + else + attachNoAwait(tree) + super.traverse(tree) + } } finally stack = stack.tail } } |