aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/scala/async/internal/AsyncAnalysis.scala20
-rw-r--r--src/main/scala/scala/async/internal/AsyncBase.scala5
-rw-r--r--src/main/scala/scala/async/internal/AsyncMacro.scala4
-rw-r--r--src/main/scala/scala/async/internal/AsyncTransform.scala3
-rw-r--r--src/main/scala/scala/async/internal/ExprBuilder.scala6
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala22
-rw-r--r--src/test/scala/scala/async/run/WarningsSpec.scala20
7 files changed, 63 insertions, 17 deletions
diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala
index 6540bdb..6b75493 100644
--- a/src/main/scala/scala/async/internal/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/internal/AsyncAnalysis.scala
@@ -4,6 +4,7 @@
package scala.async.internal
+import scala.collection.mutable.ListBuffer
import scala.reflect.macros.Context
import scala.collection.mutable
@@ -53,14 +54,13 @@ trait AsyncAnalysis {
}
override def traverse(tree: Tree) {
- def containsAwait = tree exists isAwait
tree match {
- case Try(_, _, _) if containsAwait =>
+ case Try(_, _, _) if containsAwait(tree) =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
- case DefDef(mods, _, _, _, _, _) if mods.hasFlag(Flag.LAZY) && containsAwait =>
+ case DefDef(mods, _, _, _, _, _) if mods.hasFlag(Flag.LAZY) && containsAwait(tree) =>
reportUnsupportedAwait(tree, "lazy val initializer")
case CaseDef(_, guard, _) if guard exists isAwait =>
// TODO lift this restriction
@@ -74,9 +74,19 @@ trait AsyncAnalysis {
* @return true, if the tree contained an unsupported await.
*/
private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = {
- val badAwaits: List[RefTree] = tree collect {
- case rt: RefTree if isAwait(rt) => rt
+ val badAwaits = ListBuffer[Tree]()
+ object traverser extends Traverser {
+ override def traverse(tree: Tree): Unit = {
+ if (!isAsync(tree))
+ super.traverse(tree)
+ tree match {
+ case rt: RefTree if isAwait(rt) =>
+ badAwaits += rt
+ case _ =>
+ }
+ }
}
+ traverser(tree)
badAwaits foreach {
tree =>
reportError(tree.pos, s"await must not be used under a $whyUnsupported.")
diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala
index 7a1e274..4853d2b 100644
--- a/src/main/scala/scala/async/internal/AsyncBase.scala
+++ b/src/main/scala/scala/async/internal/AsyncBase.scala
@@ -53,6 +53,11 @@ abstract class AsyncBase {
c.Expr[futureSystem.Fut[T]](code)
}
+ protected[async] def asyncMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = {
+ import u._
+ asyncMacroSymbol.owner.typeSignature.member(newTermName("async"))
+ }
+
protected[async] def awaitMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = {
import u._
asyncMacroSymbol.owner.typeSignature.member(newTermName("await"))
diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala
index e22407d..57c73fc 100644
--- a/src/main/scala/scala/async/internal/AsyncMacro.scala
+++ b/src/main/scala/scala/async/internal/AsyncMacro.scala
@@ -11,7 +11,7 @@ object AsyncMacro {
// These members are required by `ExprBuilder`:
val futureSystem: FutureSystem = base.futureSystem
val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c)
- val containsAwait: c.Tree => Boolean = containsAwaitCached(body0)
+ var containsAwait: c.Tree => Boolean = containsAwaitCached(body0)
}
}
}
@@ -22,7 +22,7 @@ private[async] trait AsyncMacro
val c: scala.reflect.macros.Context
val body: c.Tree
- val containsAwait: c.Tree => Boolean
+ var containsAwait: c.Tree => Boolean
lazy val macroPos = c.macroApplication.pos.makeTransparent
def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t)
diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala
index af290e4..2e8dcf9 100644
--- a/src/main/scala/scala/async/internal/AsyncTransform.scala
+++ b/src/main/scala/scala/async/internal/AsyncTransform.scala
@@ -26,6 +26,9 @@ trait AsyncTransform {
val anfTree = futureSystemOps.postAnfTransform(anfTree0)
+ cleanupContainsAwaitAttachments(anfTree)
+ containsAwait = containsAwaitCached(anfTree)
+
val applyDefDefDummyBody: DefDef = {
val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree)))
DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), literalUnit)
diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala
index 16b9207..ce2345d 100644
--- a/src/main/scala/scala/async/internal/ExprBuilder.scala
+++ b/src/main/scala/scala/async/internal/ExprBuilder.scala
@@ -237,10 +237,8 @@ trait ExprBuilder {
var stateBuilder = new AsyncStateBuilder(startState, symLookup)
var currState = startState
- def checkForUnsupportedAwait(tree: Tree) = if (tree exists {
- case Apply(fun, _) if isAwait(fun) => true
- case _ => false
- }) c.abort(tree.pos, "await must not be used in this position")
+ def checkForUnsupportedAwait(tree: Tree) = if (containsAwait(tree))
+ c.abort(tree.pos, "await must not be used in this position")
def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = {
val (nestedStats, nestedExpr) = statsAndExpr(nestedTree)
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index 90419d3..e15ef1b 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -38,6 +38,9 @@ private[async] trait TransformUtils {
def fresh(name: String): String = c.freshName(name)
}
+ def isAsync(fun: Tree) =
+ fun.symbol == defn.Async_async
+
def isAwait(fun: Tree) =
fun.symbol == defn.Async_await
@@ -164,6 +167,7 @@ private[async] trait TransformUtils {
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
val ThrowableClass = rootMirror.staticClass("java.lang.Throwable")
+ val Async_async = asyncBase.asyncMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol).ensuring(_ != NoSymbol)
val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
}
@@ -281,6 +285,8 @@ private[async] trait TransformUtils {
override def traverse(tree: Tree) {
tree match {
+ case _ if isAsync(tree) =>
+ // Under -Ymacro-expand:discard, used in the IDE, nested async blocks will be visible to the outer blocks
case cd: ClassDef => nestedClass(cd)
case md: ModuleDef => nestedModule(md)
case dd: DefDef => nestedMethod(dd)
@@ -398,7 +404,7 @@ private[async] trait TransformUtils {
final def containsAwaitCached(t: Tree): Tree => Boolean = {
def treeCannotContainAwait(t: Tree) = t match {
case _: Ident | _: TypeTree | _: Literal => true
- case _ => false
+ case _ => isAsync(t)
}
def shouldAttach(t: Tree) = !treeCannotContainAwait(t)
val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable]
@@ -417,11 +423,15 @@ private[async] trait TransformUtils {
override def traverse(tree: Tree): Unit = {
stack ::= tree
try {
- if (isAwait(tree))
- stack.foreach(attachContainsAwait)
- else
- attachNoAwait(tree)
- super.traverse(tree)
+ if (isAsync(tree)) {
+ ;
+ } else {
+ if (isAwait(tree))
+ stack.foreach(attachContainsAwait)
+ else
+ attachNoAwait(tree)
+ super.traverse(tree)
+ }
} finally stack = stack.tail
}
}
diff --git a/src/test/scala/scala/async/run/WarningsSpec.scala b/src/test/scala/scala/async/run/WarningsSpec.scala
index 00c6466..c80bf9e 100644
--- a/src/test/scala/scala/async/run/WarningsSpec.scala
+++ b/src/test/scala/scala/async/run/WarningsSpec.scala
@@ -74,4 +74,24 @@ class WarningsSpec {
run.compileSources(sourceFile :: Nil)
assert(!global.reporter.hasErrors, global.reporter.asInstanceOf[StoreReporter].infos)
}
+
+ @Test
+ def ignoreNestedAwaitsInIDE_t1002561() {
+ // https://www.assembla.com/spaces/scala-ide/tickets/1002561
+ val global = mkGlobal("-cp ${toolboxClasspath} -Yrangepos -Ystop-after:typer ")
+ val source = """
+ | class Test {
+ | def test = {
+ | import scala.async.Async._, scala.concurrent._, ExecutionContext.Implicits.global
+ | async {
+ | 1 + await({def foo = (async(await(async(2)))); foo})
+ | }
+ | }
+ |}
+ """.stripMargin
+ val run = new global.Run
+ val sourceFile = global.newSourceFile(source)
+ run.compileSources(sourceFile :: Nil)
+ assert(!global.reporter.hasErrors, global.reporter.asInstanceOf[StoreReporter].infos)
+ }
}