From 10aa18736a1d5161f9ad34ebcd9a6a756c904666 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Wed, 21 Nov 2012 22:48:34 +0100 Subject: Only transform if/match-s that contain an await. Accurate reporting of misplaced awaits. Attempt to collect the minimal set of vars to lift. --- src/main/scala/scala/async/Async.scala | 3 + src/main/scala/scala/async/ExprBuilder.scala | 90 +++++++++++++++++++++- src/test/scala/scala/async/TestUtils.scala | 4 +- .../scala/async/neg/AnfTransformNegSpec.scala | 4 +- src/test/scala/scala/async/neg/NakedAwait.scala | 72 +++++++++++++++++ 5 files changed, 167 insertions(+), 6 deletions(-) diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index 072aea7..30b393e 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -81,6 +81,9 @@ abstract class AsyncBase { c.typeCheck(Block(stats1, expr1)) } + val traverser = new builder.LiftableVarTraverser + traverser.traverse(btree) + AsyncUtils.vprintln(s"In file '${c.macroApplication.pos.source.path}':") AsyncUtils.vprintln(s"${c.macroApplication}") AsyncUtils.vprintln(s"ANF transform expands to:\n $btree") diff --git a/src/main/scala/scala/async/ExprBuilder.scala b/src/main/scala/scala/async/ExprBuilder.scala index 07aa1ee..1ca9e8f 100644 --- a/src/main/scala/scala/async/ExprBuilder.scala +++ b/src/main/scala/scala/async/ExprBuilder.scala @@ -310,7 +310,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val // when adding assignment need to take `toRename` into account stateBuilder.addVarDef(mods, newName, tpt, rhs, toRename) - case If(cond, thenp, elsep) => + case If(cond, thenp, elsep) if stat exists isAwait => checkForUnsupportedAwait(cond) val ifBudget: Int = remainingBudget / 2 @@ -335,7 +335,7 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val currState = currState + ifBudget stateBuilder = new builder.AsyncStateBuilder(currState, toRename) - case Match(scrutinee, cases) => + case Match(scrutinee, cases) if stat exists isAwait => checkForUnsupportedAwait(scrutinee) val matchBudget: Int = remainingBudget / 2 @@ -395,6 +395,92 @@ final class ExprBuilder[C <: Context, FS <: FutureSystem](override val c: C, val } } + private val Boolean_ShortCircuits: Set[Symbol] = { + import definitions.BooleanClass + def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(newTermName(name).encodedName) + val Boolean_&& = BooleanTermMember("&&") + val Boolean_|| = BooleanTermMember("||") + Set(Boolean_&&, Boolean_||) + } + + def isByName(fun: Tree): (Int => Boolean) = { + if (Boolean_ShortCircuits contains fun.symbol) i => true + else fun.tpe match { + case MethodType(params, _) => + val isByNameParams = params.map(_.asTerm.isByNameParam) + (i: Int) => isByNameParams.applyOrElse(i, (_: Int) => false) + case _ => Map() + } + } + + private def isAwait(fun: Tree) = { + fun.symbol == defn.Async_await + } + + private[async] class LiftableVarTraverser extends Traverser { + var blockId = 0 + var valDefBlockId = Map[Symbol, (ValDef, Int)]() + val liftable = collection.mutable.Set[ValDef]() + + + def reportUnsupportedAwait(tree: Tree, whyUnsupported: String) { + val badAwaits = tree collect { + case rt: RefTree if rt.symbol == Async_await => rt + } + badAwaits foreach { + tree => + c.error(tree.pos, s"await must not be used under a $whyUnsupported.") + } + } + + override def traverse(tree: Tree) = tree match { + case cd: ClassDef => + val kind = if (cd.symbol.asClass.isTrait) "trait" else "class" + reportUnsupportedAwait(tree, s"nested ${kind}") + case md: ModuleDef => + reportUnsupportedAwait(tree, "nested object") + case _: Function => + reportUnsupportedAwait(tree, "nested anonymous function") + case If(cond, thenp, elsep) if tree exists isAwait => + traverse(cond) + blockId += 1 + traverse(thenp) + blockId += 1 + traverse(elsep) + blockId += 1 + case Match(selector, cases) if tree exists isAwait => + traverse(selector) + blockId += 1 + cases foreach {c => traverse(c); blockId += 1} + case Apply(fun, args) if isAwait(fun) => + traverseTrees(args) + traverse(fun) + blockId += 1 + case Apply(fun, args) => + val isInByName = isByName(fun) + for ((arg, index) <- args.zipWithIndex) { + if (!isInByName(index)) traverse(arg) + else reportUnsupportedAwait(arg, "by-name argument") + } + traverse(fun) + case vd: ValDef => + super.traverse(tree) + valDefBlockId += (vd.symbol -> (vd, blockId)) + if (vd.rhs.symbol == Async_await) liftable += vd + case as: Assign => + if (as.rhs.symbol == Async_await) liftable += valDefBlockId.getOrElse(as.symbol, c.abort(as.pos, "await may only be assigned to a var/val defined in the async block."))._1 + case rt: RefTree => + valDefBlockId.get(rt.symbol) match { + case Some((vd, defBlockId)) if defBlockId != blockId => + liftable += vd + case _ => + } + super.traverse(tree) + case _ => super.traverse(tree) + } + } + + /** `termSym( (_: Foo).bar(null: A, null: B)` will return the symbol of `bar`, after overload resolution. */ private def methodSym(apply: c.Expr[Any]): Symbol = { val tree2: Tree = c.typeCheck(apply.tree) // TODO why is this needed? diff --git a/src/test/scala/scala/async/TestUtils.scala b/src/test/scala/scala/async/TestUtils.scala index bac22a3..0ae78b8 100644 --- a/src/test/scala/scala/async/TestUtils.scala +++ b/src/test/scala/scala/async/TestUtils.scala @@ -50,9 +50,9 @@ trait TestUtils { m.mkToolBox(options = compileOptions) } - def expectError(errorSnippet: String, compileOptions: String = "")(code: String) { + def expectError(errorSnippet: String, compileOptions: String = "", baseCompileOptions: String = "-cp target/scala-2.10/classes")(code: String) { intercept[ToolBoxError] { - eval(code, compileOptions) + eval(code, compileOptions + " " + baseCompileOptions) }.getMessage mustContain errorSnippet } } diff --git a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala b/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala index 974a5f1..38790dd 100644 --- a/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala +++ b/src/test/scala/scala/async/neg/AnfTransformNegSpec.scala @@ -13,7 +13,7 @@ class AnfTransformNegSpec { @Test def `inlining block produces duplicate definition`() { - expectError("x is already defined as value x", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") { """ | import scala.concurrent.ExecutionContext.Implicits.global | import scala.concurrent.Future @@ -36,7 +36,7 @@ class AnfTransformNegSpec { @Test def `inlining block in tail position produces duplicate definition`() { - expectError("x is already defined as value x", "-cp target/scala-2.10/classes -deprecation -Xfatal-warnings") { + expectError("x is already defined as value x", "-deprecation -Xfatal-warnings") { """ | import scala.concurrent.ExecutionContext.Implicits.global | import scala.concurrent.Future diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index db67f18..66bc947 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -16,4 +16,76 @@ class NakedAwait { """.stripMargin } } + + + @Test + def `await not allowed in by-name argument`() { + expectError("await must not be used under a by-name argument.") { + """ + | import _root_.scala.async.AsyncId._ + | def foo(a: Int)(b: => Int) = 0 + | async { foo(0)(await(0)) } + """.stripMargin + } + } + + @Test + def `await not allowed in boolean short circuit argument 1`() { + expectError("await must not be used under a by-name argument.") { + """ + | import _root_.scala.async.AsyncId._ + | async { true && await(false) } + """.stripMargin + } + } + + @Test + def `await not allowed in boolean short circuit argument 2`() { + expectError("await must not be used under a by-name argument.") { + """ + | import _root_.scala.async.AsyncId._ + | async { true || await(false) } + """.stripMargin + } + } + + @Test + def nestedObject() { + expectError("await must not be used under a nested object.") { + """ + | import _root_.scala.async.AsyncId._ + | async { object Nested { await(false) } } + """.stripMargin + } + } + + @Test + def nestedTrait() { + expectError("await must not be used under a nested trait.") { + """ + | import _root_.scala.async.AsyncId._ + | async { trait Nested { await(false) } } + """.stripMargin + } + } + + @Test + def nestedClass() { + expectError("await must not be used under a nested class.") { + """ + | import _root_.scala.async.AsyncId._ + | async { class Nested { await(false) } } + """.stripMargin + } + } + + @Test + def nestedFunction() { + expectError("await must not be used under a nested anonymous function.") { + """ + | import _root_.scala.async.AsyncId._ + | async { () => { await(false) } } + """.stripMargin + } + } } -- cgit v1.2.3