aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--build.sbt2
-rw-r--r--src/main/scala/scala/async/internal/AnfTransform.scala12
-rw-r--r--src/main/scala/scala/async/internal/TransformUtils.scala2
-rw-r--r--src/test/scala/scala/async/run/anf/AnfTransformSpec.scala42
4 files changed, 54 insertions, 4 deletions
diff --git a/build.sbt b/build.sbt
index b5f7fad..ff5e956 100644
--- a/build.sbt
+++ b/build.sbt
@@ -8,7 +8,7 @@ organization := "org.scala-lang.modules"
name := "scala-async"
-version := "0.9.5-SNAPSHOT"
+version := "0.9.6-SNAPSHOT"
libraryDependencies <++= (scalaVersion) {
sv => Seq(
diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala
index 3c8c837..e41c64e 100644
--- a/src/main/scala/scala/async/internal/AnfTransform.scala
+++ b/src/main/scala/scala/async/internal/AnfTransform.scala
@@ -70,6 +70,9 @@ private[async] trait AnfTransform {
val stats :+ expr = anf.transformToList(tree)
def statsExprUnit =
stats :+ expr :+ localTyper.typedPos(expr.pos)(Literal(Constant(())))
+ def statsExprThrow =
+ stats :+ expr :+ localTyper.typedPos(expr.pos)(Throw(Apply(Select(New(gen.mkAttributedRef(defn.IllegalStateExceptionClass)), nme.CONSTRUCTOR), Nil)))
+
expr match {
case Apply(fun, args) if isAwait(fun) =>
val valDef = defineVal(name.await, expr, tree.pos)
@@ -81,7 +84,7 @@ private[async] trait AnfTransform {
// TODO avoid creating a ValDef for the result of this await to avoid this tree shape altogether.
// This will require some deeper changes to the later parts of the macro which currently assume regular
// tree structure around `await` calls.
- gen.mkCast(ref, definitions.UnitTpe)
+ localTyper.typedPos(tree.pos)(gen.mkCast(ref, definitions.UnitTpe))
else ref
stats :+ valDef :+ atPos(tree.pos)(ref1)
@@ -90,6 +93,8 @@ private[async] trait AnfTransform {
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
statsExprUnit
+ } else if (expr.tpe =:= definitions.NothingTpe) {
+ statsExprThrow
} else {
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
def branchWithAssign(orig: Tree) = localTyper.typedPos(orig.pos) {
@@ -110,8 +115,9 @@ private[async] trait AnfTransform {
// but add Unit value to bring it into form expected by async transform
if (expr.tpe =:= definitions.UnitTpe) {
statsExprUnit
- }
- else {
+ } else if (expr.tpe =:= definitions.NothingTpe) {
+ statsExprThrow
+ } else {
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
def typedAssign(lhs: Tree) =
localTyper.typedPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, varDef.symbol.tpe)))
diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala
index da76c18..5b5c4d2 100644
--- a/src/main/scala/scala/async/internal/TransformUtils.scala
+++ b/src/main/scala/scala/async/internal/TransformUtils.scala
@@ -83,6 +83,8 @@ private[async] trait TransformUtils {
val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal")
val Async_await = asyncBase.awaitMethod(global)(macroApplication.symbol).ensuring(_ != NoSymbol)
+ val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException")
+
}
def isSafeToInline(tree: Tree) = {
diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
index 757ae0b..bbc1b2b 100644
--- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
+++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala
@@ -405,4 +405,46 @@ class AnfTransformSpec {
val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x }
applyImplicitView.map(_.toString) mustBe List("view(a$1)")
}
+
+ @Test
+ def nothingTypedIf(): Unit = {
+ import scala.async.internal.AsyncId.{async, await}
+ val result = util.Try(async {
+ if (true) {
+ val n = await(1)
+ if (n < 2) {
+ throw new RuntimeException("case a")
+ }
+ else {
+ throw new RuntimeException("case b")
+ }
+ }
+ else {
+ "case c"
+ }
+ })
+
+ assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a")
+ }
+
+ @Test
+ def nothingTypedMatch(): Unit = {
+ import scala.async.internal.AsyncId.{async, await}
+ val result = util.Try(async {
+ 0 match {
+ case _ if "".isEmpty =>
+ val n = await(1)
+ n match {
+ case _ if n < 2 =>
+ throw new RuntimeException("case a")
+ case _ =>
+ throw new RuntimeException("case b")
+ }
+ case _ =>
+ "case c"
+ }
+ })
+
+ assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a")
+ }
}