aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJason Zaugg <jzaugg@gmail.com>2012-11-25 11:51:34 +0100
committerJason Zaugg <jzaugg@gmail.com>2012-11-26 16:13:42 +0100
commit26038aebf1555b582dba35e8bfc3698f126705c5 (patch)
tree89517670afcf4ff41648bddecf41603661a111dc
parenta5cab2959067bc7f9d3884064fbf7bf7ec0b7285 (diff)
downloadscala-async-26038aebf1555b582dba35e8bfc3698f126705c5.tar.gz
scala-async-26038aebf1555b582dba35e8bfc3698f126705c5.tar.bz2
scala-async-26038aebf1555b582dba35e8bfc3698f126705c5.zip
Fix await in if condition / match scrutinee.
The type-checking performed in ANF transform is precarious, and needed to use the original condition/ scrutinee in a throwaway tree to get things to work.
-rw-r--r--src/main/scala/scala/async/AnfTransform.scala18
-rw-r--r--src/main/scala/scala/async/AsyncAnalysis.scala3
-rw-r--r--src/main/scala/scala/async/AsyncUtils.scala4
-rw-r--r--src/test/scala/scala/async/TreeInterrogation.scala44
-rw-r--r--src/test/scala/scala/async/neg/NakedAwait.scala11
-rw-r--r--src/test/scala/scala/async/run/ifelse0/IfElse0.scala8
-rw-r--r--src/test/scala/scala/async/run/match0/Match0.scala12
7 files changed, 59 insertions, 41 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala
index 5080ecf..055676d 100644
--- a/src/main/scala/scala/async/AnfTransform.scala
+++ b/src/main/scala/scala/async/AnfTransform.scala
@@ -195,11 +195,18 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
stats :+ attachCopy(tree)(Assign(lhs, expr))
case If(cond, thenp, elsep) if containsAwait =>
- val stats :+ expr = inline.transformToList(cond)
+ val condStats :+ condExpr = inline.transformToList(cond)
val thenBlock = inline.transformToBlock(thenp)
val elseBlock = inline.transformToBlock(elsep)
- stats :+
- c.typeCheck(attachCopy(tree)(If(expr, thenBlock, elseBlock)))
+ // Typechecking with `condExpr` as the condition fails if the condition
+ // contains an await. `ifTree.setType(tree.tpe)` also fails; it seems
+ // we rely on this call to `typeCheck` descending into the branches.
+ // But, we can get away with typechecking a throwaway `If` tree with the
+ // original scrutinee and the new branches, and setting that type on
+ // the real `If` tree.
+ val ifType = c.typeCheck(If(cond, thenBlock, elseBlock)).tpe
+ condStats :+
+ attachCopy(tree)(If(condExpr, thenBlock, elseBlock)).setType(ifType)
case Match(scrut, cases) if containsAwait =>
val scrutStats :+ scrutExpr = inline.transformToList(scrut)
@@ -218,7 +225,10 @@ private[async] final case class AnfTransform[C <: Context](c: C) {
val Block(stats1, expr1) = utils.substituteNames(block, mappings.toMap).asInstanceOf[Block]
attachCopy(tree)(CaseDef(pat, guard, Block(valDefs ++ stats1, expr1)))
}
- scrutStats :+ c.typeCheck(attachCopy(tree)(Match(scrutExpr, caseDefs)))
+ // Refer to comments the translation of `If` above.
+ val matchType = c.typeCheck(Match(scrut, caseDefs)).tpe
+ val typedMatch = attachCopy(tree)(Match(scrutExpr, caseDefs)).setType(tree.tpe)
+ scrutStats :+ typedMatch
case LabelDef(name, params, rhs) if containsAwait =>
List(LabelDef(name, params, Block(inline.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
diff --git a/src/main/scala/scala/async/AsyncAnalysis.scala b/src/main/scala/scala/async/AsyncAnalysis.scala
index f0d4511..ecd5054 100644
--- a/src/main/scala/scala/async/AsyncAnalysis.scala
+++ b/src/main/scala/scala/async/AsyncAnalysis.scala
@@ -71,9 +71,6 @@ private[async] final case class AsyncAnalysis[C <: Context](c: C) {
case Try(_, _, _) if containsAwait =>
reportUnsupportedAwait(tree, "try/catch")
super.traverse(tree)
- case If(cond, _, _) if containsAwait =>
- reportUnsupportedAwait(cond, "condition")
- super.traverse(tree)
case Return(_) =>
c.abort(tree.pos, "return is illegal within a async block")
case _ =>
diff --git a/src/main/scala/scala/async/AsyncUtils.scala b/src/main/scala/scala/async/AsyncUtils.scala
index 87a63d7..999cb95 100644
--- a/src/main/scala/scala/async/AsyncUtils.scala
+++ b/src/main/scala/scala/async/AsyncUtils.scala
@@ -10,8 +10,8 @@ object AsyncUtils {
private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true")
- private val verbose = enabled("debug")
- private val trace = enabled("trace")
+ var verbose = enabled("debug")
+ var trace = enabled("trace")
private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s")
diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala
index f005b8a..e3012c7 100644
--- a/src/test/scala/scala/async/TreeInterrogation.scala
+++ b/src/test/scala/scala/async/TreeInterrogation.scala
@@ -41,26 +41,28 @@ class TreeInterrogation {
}
object TreeInterrogation extends App {
- sys.props("scala.async.debug") = true.toString
- sys.props("scala.async.trace") = true.toString
+ def withDebug[T](t: => T) {
+ AsyncUtils.trace = true
+ AsyncUtils.verbose = true
+ try t
+ finally {
+ AsyncUtils.trace = false
+ AsyncUtils.verbose = false
+ }
+ }
- val cm = reflect.runtime.currentMirror
- val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all")
- val tree = tb.parse(
- """ import _root_.scala.async.AsyncId._
- | async {
- | val x = 1
- | Option(x) match {
- | case op @ Some(x) =>
- | assert(op != null)
- | println((op, x))
- | x + await(x)
- | case None => await(0)
- | }
- | }
- | """.stripMargin)
- println(tree)
- val tree1 = tb.typeCheck(tree.duplicate)
- println(cm.universe.show(tree1))
- println(tb.eval(tree))
+ withDebug {
+ val cm = reflect.runtime.currentMirror
+ val tb = mkToolbox("-cp target/scala-2.10/classes -Xprint:all")
+ val tree = tb.parse(
+ """ import _root_.scala.async.AsyncId._
+ | async {
+ | await(0) match { case _ => 0 }
+ | }
+ | """.stripMargin)
+ println(tree)
+ val tree1 = tb.typeCheck(tree.duplicate)
+ println(cm.universe.show(tree1))
+ println(tb.eval(tree))
+ }
} \ No newline at end of file
diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala
index f4cfca2..ecc84f9 100644
--- a/src/test/scala/scala/async/neg/NakedAwait.scala
+++ b/src/test/scala/scala/async/neg/NakedAwait.scala
@@ -143,15 +143,4 @@ class NakedAwait {
|""".stripMargin
}
}
-
- // TODO Anf transform if to have a simple condition.
- @Test
- def ifCondition() {
- expectError("await must not be used under a condition.") {
- """
- | import _root_.scala.async.AsyncId._
- | async { if (await(true)) () }
- |""".stripMargin
- }
- }
}
diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
index 0a72f1e..e2b1ca6 100644
--- a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
+++ b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala
@@ -47,4 +47,12 @@ class IfElseSpec {
val res = Await.result(fut, 2 seconds)
res mustBe (14)
}
+
+ @Test def `await in condition`() {
+ import AsyncId.{async, await}
+ val result = async {
+ if ({await(true); await(true)}) await(1) else ???
+ }
+ result mustBe (1)
+ }
}
diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala
index 5237629..8263e72 100644
--- a/src/test/scala/scala/async/run/match0/Match0.scala
+++ b/src/test/scala/scala/async/run/match0/Match0.scala
@@ -100,4 +100,16 @@ class MatchSpec {
}
result mustBe ((Some(""), true))
}
+
+ @Test def `await in scrutinee`() {
+ import AsyncId.{async, await}
+ val result = async {
+ await(if ("".isEmpty) await(1) else ???) match {
+ case x if x < 0 => ???
+ case y: Int => y * await(3)
+ case _ => ???
+ }
+ }
+ result mustBe (3)
+ }
}