diff options
-rw-r--r-- | src/main/scala/scala/async/AnfTransform.scala | 32 | ||||
-rw-r--r-- | src/test/scala/scala/async/run/anf/AnfTransformSpec.scala | 20 |
2 files changed, 39 insertions, 13 deletions
diff --git a/src/main/scala/scala/async/AnfTransform.scala b/src/main/scala/scala/async/AnfTransform.scala index 2dd6f8c..f29d6a1 100644 --- a/src/main/scala/scala/async/AnfTransform.scala +++ b/src/main/scala/scala/async/AnfTransform.scala @@ -16,21 +16,27 @@ class AnfTransform[C <: Context](override val c: C) extends TransformUtils(c) { stats :+ ValDef(NoMods, liftedName, TypeTree(), expr) :+ Ident(liftedName) case If(cond, thenp, elsep) => - val liftedName = c.fresh("ifres$") - val varDef = - ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) - val thenWithAssign = thenp match { - case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) - case _ => Assign(Ident(liftedName), thenp) + // if type of if-else is Unit don't introduce assignment, + // but add Unit value to bring it into form expected by async transform + if (expr.tpe =:= definitions.UnitTpe) { + stats :+ expr :+ Literal(Constant(())) } - val elseWithAssign = elsep match { - case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) - case _ => Assign(Ident(liftedName), elsep) + else { + val liftedName = c.fresh("ifres$") + val varDef = + ValDef(Modifiers(Flag.MUTABLE), liftedName, TypeTree(expr.tpe), defaultValue(expr.tpe)) + val thenWithAssign = thenp match { + case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(liftedName), thenExpr)) + case _ => Assign(Ident(liftedName), thenp) + } + val elseWithAssign = elsep match { + case Block(elseStats, elseExpr) => Block(elseStats, Assign(Ident(liftedName), elseExpr)) + case _ => Assign(Ident(liftedName), elsep) + } + val ifWithAssign = + If(cond, thenWithAssign, elseWithAssign) + stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) } - val ifWithAssign = - If(cond, thenWithAssign, elseWithAssign) - stats :+ varDef :+ ifWithAssign :+ Ident(liftedName) - case _ => stats :+ expr } diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala index f38efa9..f2fc2d7 100644 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala @@ -54,8 +54,20 @@ class AnfTestClass { } z + 1 } + + def futureUnitIfElse(y: Int): Future[Unit] = async { + val f = base(y) + if (y > 0) { + State.result = await(f) + 2 + } else { + State.result = await(f) - 2 + } + } } +object State { + @volatile var result: Int = 0 +} @RunWith(classOf[JUnit4]) class AnfTransformSpec { @@ -91,4 +103,12 @@ class AnfTransformSpec { val res = Await.result(fut, 2 seconds) res mustBe (15) } + + @Test + def `Unit-typed if-else in tail position`() { + val o = new AnfTestClass + val fut = o.futureUnitIfElse(10) + Await.result(fut, 2 seconds) + State.result mustBe (14) + } } |