aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-21 22:48:34 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-21 22:48:34 +0100
commit10aa18736a1d5161f9ad34ebcd9a6a756c904666 (patch)
treee9a23df0dd26e9b588d49be9772d901032629bf1
parenteeb0f5e676e8d9cc44ab886a6225da62dfb5d561 (diff)
downloadscala-async-10aa18736a1d5161f9ad34ebcd9a6a756c904666.tar.gz
scala-async-10aa18736a1d5161f9ad34ebcd9a6a756c904666.tar.bz2
scala-async-10aa18736a1d5161f9ad34ebcd9a6a756c904666.zip
Only transform if/match-s that contain an await.
Accurate reporting of misplaced awaits. Attempt to collect the minimal set of vars to lift.
-rw-r--r--src/main/scala/scala/async/Async.scala3
-rw-r--r--src/main/scala/scala/async/ExprBuilder.scala90
-rw-r--r--src/test/scala/scala/async/TestUtils.scala4
-rw-r--r--src/test/scala/scala/async/neg/AnfTransformNegSpec.scala4
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala72
5 files changed, 167 insertions, 6 deletions
diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala
index 072aea7..30b393e 100644
--- a/src/main/scala/scala/async/Async.scala
+++ b/src/main/scala/scala/async/Async.scala
@@ -81,6 +81,9 @@ abstract class AsyncBase {
c.typeCheck(Block(stats1, expr1))
}
+ val traverser = new builder.LiftableVarTraverser
+ traverser.traverse(btree)
+
AsyncUtils.vprintln(s"In file '${c.macroApplication.pos.source.path}':")
AsyncUtils.vprintln(s"${c.macroApplication}")
AsyncUtils.vprintln(s"ANF transform expands to:\n $btree")
diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala
index 07aa1ee..1ca9e8f 100644
--- a/src/main/scala/scala/async/ExprBuilder.scala
+++ b/src/main/scala/scala/async/ExprBuilder.scala
@@ -310,7 +310,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
// when adding assignment need to take `toRename` into account
stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename)
- case If(cond, thenp, elsep) =>
+ case If(cond, thenp, elsep) if stat exists isAwait =>
checkForUnsupportedAwait(cond)
val ifBudget: Int = remainingBudget / 2
@@ -335,7 +335,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
currState = currState + ifBudget
stateBuilder = new builder.AsyncStateBuilder(currState, toRename)
- case Match(scrutinee, cases) =>
+ case Match(scrutinee, cases) if stat exists isAwait =>
checkForUnsupportedAwait(scrutinee)
val matchBudget: Int = remainingBudget / 2
@@ -395,6 +395,92 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val
}
}
+ private val Boolean_ShortCircuits: Set[Symbol] = {
+ import definitions.BooleanClass
+ def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName)
+ val Boolean_&& = BooleanTermMember("&&")
+ val Boolean_|| = BooleanTermMember("||")
+ Set(Boolean_&&, Boolean_||)
+ }
+
+ def isByName(fun: Tree): (Int => Boolean) = {
+ if (Boolean_ShortCircuits contains fun.symbol) i => true
+ else fun.tpe match {
+ case MethodType(params, _) =>
+ val isByNameParams = params.map(_.asTerm.isByNameParam)
+ (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false)
+ case _ => Map()
+ }
+ }
+
+ private def isAwait(fun: Tree) = {
+ fun.symbol == defn.Async_await
+ }
+
+ private[async] class LiftableVarTraverser extends Traverser {
+ var blockId = 0
+ var valDefBlockId = Map[Symbol, (ValDef, Int)]()
+ val liftable = collection.mutable.Set[ValDef]()
+
+
+ def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) {
+ val badAwaits = tree collect {
+ case rt: RefTree if rt.symbol == Async_await => rt
+ }
+ badAwaits foreach {
+ tree =>
+ c.error(tree.pos, s"await must not be used under a $whyUnsupported.")
+ }
+ }
+
+ override def traverse(tree: Tree) = tree match {
+ case cd: ClassDef =>
+ val kind = if (cd.symbol.asClass.isTrait) "trait" else "class"
+ reportUnsupportedAwait(tree, s"nested ${kind}")
+ case md: ModuleDef =>
+ reportUnsupportedAwait(tree, "nested object")
+ case _: Function =>
+ reportUnsupportedAwait(tree, "nested anonymous function")
+ case If(cond, thenp, elsep) if tree exists isAwait =>
+ traverse(cond)
+ blockId += 1
+ traverse(thenp)
+ blockId += 1
+ traverse(elsep)
+ blockId += 1
+ case Match(selector, cases) if tree exists isAwait =>
+ traverse(selector)
+ blockId += 1
+ cases foreach {c => traverse(c); blockId += 1}
+ case Apply(fun, args) if isAwait(fun) =>
+ traverseTrees(args)
+ traverse(fun)
+ blockId += 1
+ case Apply(fun, args) =>
+ val isInByName = isByName(fun)
+ for ((arg, index) <- args.zipWithIndex) {
+ if (!isInByName(index)) traverse(arg)
+ else reportUnsupportedAwait(arg, "by-name argument")
+ }
+ traverse(fun)
+ case vd: ValDef =>
+ super.traverse(tree)
+ valDefBlockId += (vd.symbol -> (vd, blockId))
+ if (vd.rhs.symbol == Async_await) liftable += vd
+ case as: Assign =>
+ if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1
+ case rt: RefTree =>
+ valDefBlockId.get(rt.symbol) match {
+ case Some((vd, defBlockId)) if defBlockId != blockId =>
+ liftable += vd
+ case _ =>
+ }
+ super.traverse(tree)
+ case _ => super.traverse(tree)
+ }
+ }
+
+
/** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */
private def methodSym(apply: c.Expr[Any]): Symbol = {
val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed?
diff --git a/src/test/scala/scala/async/TestUtils.scala b/src/test/scala/scala/async/TestUtils.scala
index bac22a3..0ae78b8 100644
--- a/src/test/scala/scala/async/TestUtils.scala
+++ b/src/test/scala/scala/async/TestUtils.scala
@@ -50,9 +50,9 @@ trait TestUtils {
m.mkToolBox(options = compileOptions)
}
- def expectError(errorSnippet: String, compileOptions: String = "")(code: String) {
+ def expectError(errorSnippet: String, compileOptions: String = "", baseCompileOptions: String = "-cp target/scala-2.10/classes")(code: String) {
intercept[ToolBoxError] {
- eval(code, compileOptions)
+ eval(code, compileOptions + " " + baseCompileOptions)
}.getMessage mustContain errorSnippet
}
}
diff --git a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala b/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala
index 974a5f1..38790dd 100644
--- a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala
+++ b/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala
@@ -13,7 +13,7 @@ class AnfTransformNegSpec {
@Test
def `inlining block produces duplicate definition`() {
- expectError("x is already defined as value x", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") {
"""
| import scala.concurrent.ExecutionContext.Implicits.global
| import scala.concurrent.Future
@@ -36,7 +36,7 @@ class AnfTransformNegSpec {
@Test
def `inlining block in tail position produces duplicate definition`() {
- expectError("x is already defined as value x", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") {
+ expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") {
"""
| import scala.concurrent.ExecutionContext.Implicits.global
| import scala.concurrent.Future
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index db67f18..66bc947 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -16,4 +16,76 @@ class NakedAwait {
""".stripMargin
}
}
+
+
+ @Test
+ def `await not allowed in by-name argument`() {
+ expectError("await must not be used under a by-name argument.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | def foo(a: Int)(b: => Int) = 0
+ | async { foo(0)(await(0)) }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def `await not allowed in boolean short circuit argument 1`() {
+ expectError("await must not be used under a by-name argument.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { true && await(false) }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def `await not allowed in boolean short circuit argument 2`() {
+ expectError("await must not be used under a by-name argument.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { true || await(false) }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedObject() {
+ expectError("await must not be used under a nested object.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { object Nested { await(false) } }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedTrait() {
+ expectError("await must not be used under a nested trait.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { trait Nested { await(false) } }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedClass() {
+ expectError("await must not be used under a nested class.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { class Nested { await(false) } }
+ """.stripMargin
+ }
+ }
+
+ @Test
+ def nestedFunction() {
+ expectError("await must not be used under a nested anonymous function.") {
+ """
+ | import _root_.scala.async.AsyncId._
+ | async { () => { await(false) } }
+ """.stripMargin
+ }
+ }
}