From 26038aebf1555b582dba35e8bfc3698f126705c5 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Sun, 25 Nov 2012 11:51:34 +0100 Subject: Fix await in if condition / match scrutinee. The type-checking performed in ANF transform is precarious, and needed to use the original condition/ scrutinee in a throwaway tree to get things to work. --- src/main/scala/scala/async/AnfTransform.scala | 18 +++++++-- src/main/scala/scala/async/AsyncAnalysis.scala | 3 -- src/main/scala/scala/async/AsyncUtils.scala | 4 +- src/test/scala/scala/async/TreeInterrogation.scala | 44 +++++++++++----------- src/test/scala/scala/async/neg/NakedAwait.scala | 11 ------ .../scala/scala/async/run/ifelse0/IfElse0.scala | 8 ++++ src/test/scala/scala/async/run/match0/Match0.scala | 12 ++++++ 7 files changed, 59 insertions(+), 41 deletions(-) diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 5080ecf..055676d 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -195,11 +195,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) { stats :+ attachCopy(tree)(Assign(lhs, expr)) case If(cond, thenp, elsep) if containsAwait => - val stats :+ expr = inline.transformToList(cond) + val condStats :+ condExpr = inline.transformToList(cond) val thenBlock = inline.transformToBlock(thenp) val elseBlock = inline.transformToBlock(elsep) - stats :+ - c.typeCheck(attachCopy(tree)(If(expr, thenBlock, elseBlock))) + // Typechecking with `condExpr` as the condition fails if the condition + // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems + // we rely on this call to `typeCheck` descending into the branches. + // But, we can get away with typechecking a throwaway `If` tree with the + // original scrutinee and the new branches, and setting that type on + // the real `If` tree. + val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe + condStats :+ + attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType) case Match(scrut, cases) if containsAwait => val scrutStats :+ scrutExpr = inline.transformToList(scrut) @@ -218,7 +225,10 @@ private[async] final case class AnfTransform[C <: Context](c: C) { val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block] attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1))) } - scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs))) + // Refer to comments the translation of `If` above. + val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe + val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe) + scrutStats :+ typedMatch case LabelDef(name, params, rhs) if containsAwait => List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol)) diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala index f0d4511..ecd5054 100644 --- a/src/main/scala/scala/async/AsyncAnalysis.scala +++ b/src/main/scala/scala/async/AsyncAnalysis.scala @@ -71,9 +71,6 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) { case Try(_, _, _) if containsAwait => reportUnsupportedAwait(tree, "try/catch") super.traverse(tree) - case If(cond, _, _) if containsAwait => - reportUnsupportedAwait(cond, "condition") - super.traverse(tree) case Return(_) => c.abort(tree.pos, "return is illegal within a async block") case _ => diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala index 87a63d7..999cb95 100644 --- a/src/main/scala/scala/async/AsyncUtils.scala +++ b/src/main/scala/scala/async/AsyncUtils.scala @@ -10,8 +10,8 @@ object AsyncUtils { private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true") - private val verbose = enabled("debug") - private val trace = enabled("trace") + var verbose = enabled("debug") + var trace = enabled("trace") private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s") diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala index f005b8a..e3012c7 100644 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ b/src/test/scala/scala/async/TreeInterrogation.scala @@ -41,26 +41,28 @@ class TreeInterrogation { } object TreeInterrogation extends App { - sys.props("scala.async.debug") = true.toString - sys.props("scala.async.trace") = true.toString + def withDebug[T](t: => T) { + AsyncUtils.trace = true + AsyncUtils.verbose = true + try t + finally { + AsyncUtils.trace = false + AsyncUtils.verbose = false + } + } - val cm = reflect.runtime.currentMirror - val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") - val tree = tb.parse( - """ import _root_.scala.async.AsyncId._ - | async { - | val x = 1 - | Option(x) match { - | case op @ Some(x) => - | assert(op != null) - | println((op, x)) - | x + await(x) - | case None => await(0) - | } - | } - | """.stripMargin) - println(tree) - val tree1 = tb.typeCheck(tree.duplicate) - println(cm.universe.show(tree1)) - println(tb.eval(tree)) + withDebug { + val cm = reflect.runtime.currentMirror + val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all") + val tree = tb.parse( + """ import _root_.scala.async.AsyncId._ + | async { + | await(0) match { case _ => 0 } + | } + | """.stripMargin) + println(tree) + val tree1 = tb.typeCheck(tree.duplicate) + println(cm.universe.show(tree1)) + println(tb.eval(tree)) + } } \ No newline at end of file diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala index f4cfca2..ecc84f9 100644 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ b/src/test/scala/scala/async/neg/NakedAwait.scala @@ -143,15 +143,4 @@ class NakedAwait { |""".stripMargin } } - - // TODO Anf transform if to have a simple condition. - @Test - def ifCondition() { - expectError("await must not be used under a condition.") { - """ - | import _root_.scala.async.AsyncId._ - | async { if (await(true)) () } - |""".stripMargin - } - } } diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala index 0a72f1e..e2b1ca6 100644 --- a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala +++ b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala @@ -47,4 +47,12 @@ class IfElseSpec { val res = Await.result(fut, 2 seconds) res mustBe (14) } + + @Test def `await in condition`() { + import AsyncId.{async, await} + val result = async { + if ({await(true); await(true)}) await(1) else ??? + } + result mustBe (1) + } } diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala index 5237629..8263e72 100644 --- a/src/test/scala/scala/async/run/match0/Match0.scala +++ b/src/test/scala/scala/async/run/match0/Match0.scala @@ -100,4 +100,16 @@ class MatchSpec { } result mustBe ((Some(""), true)) } + + @Test def `await in scrutinee`() { + import AsyncId.{async, await} + val result = async { + await(if ("".isEmpty) await(1) else ???) match { + case x if x < 0 => ??? + case y: Int => y * await(3) + case _ => ??? + } + } + result mustBe (3) + } } -- cgit v1.2.3