summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2012-02-22 11:55:31 -0800
committerPaul Phillips <paulp@improving.org>2012-02-22 11:55:31 -0800
commit67ce08e882c7e8c7eeaf117d0cbbddd64455fc68 (patch)
treebce1d1113285e28374edc166f3b4a8b4d8b68186
parent2eeb8d3d40ea243baf7ebeeba981c8cb903a7ab0 (diff)
parent9c486dcc733a9db4298a33b291dcb60680dbcdc2 (diff)
downloadscala-67ce08e882c7e8c7eeaf117d0cbbddd64455fc68.tar.gz
scala-67ce08e882c7e8c7eeaf117d0cbbddd64455fc68.tar.bz2
scala-67ce08e882c7e8c7eeaf117d0cbbddd64455fc68.zip
Merge remote-tracking branch 'szeiger/backport/cps-if-then-else-91dbfb2' into 2.9.x
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala60
-rw-r--r--test/files/continuations-run/ifelse4.check4
-rw-r--r--test/files/continuations-run/ifelse4.scala31
-rw-r--r--test/files/continuations-run/patvirt.check2
-rw-r--r--test/files/continuations-run/patvirt.scala32
5 files changed, 101 insertions, 28 deletions
diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
index 7927efc6d5..e2e31a5d37 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
@@ -173,40 +173,44 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
(Nil, tree1, cpsA)
- case If(cond, thenp, elsep) =>
-
- val (condStats, condVal, spc) = transInlineValue(cond, cpsA)
-
- val (cpsA2, cpsR2) = (spc, linearize(spc, getAnswerTypeAnn(tree.tpe)))
-// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe))
- val thenVal = transExpr(thenp, cpsA2, cpsR2)
- val elseVal = transExpr(elsep, cpsA2, cpsR2)
-
- // check that then and else parts agree (not necessary any more, but left as sanity check)
- if (cpsR.isDefined) {
- if (elsep == EmptyTree)
- unit.error(tree.pos, "always need else part in cps code")
- }
- if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) {
- unit.error(tree.pos, "then and else parts must both be cps code or neither of them")
- }
+ case If(cond, thenp, elsep) =>
+ /* possible situations:
+ cps before (cpsA)
+ cps in condition (spc) <-- synth flag set if *only* here!
+ cps in (one or both) branches */
+ val (condStats, condVal, spc) = transInlineValue(cond, cpsA)
+ val (cpsA2, cpsR2) = if (tree.tpe hasAnnotation MarkerCPSSynth)
+ (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else
+ (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly
+ val thenVal = transExpr(thenp, cpsA2, cpsR2)
+ val elseVal = transExpr(elsep, cpsA2, cpsR2)
+
+ // check that then and else parts agree (not necessary any more, but left as sanity check)
+ if (cpsR.isDefined) {
+ if (elsep == EmptyTree)
+ unit.error(tree.pos, "always need else part in cps code")
+ }
+ if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) {
+ unit.error(tree.pos, "then and else parts must both be cps code or neither of them")
+ }
- (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc)
+ (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc)
- case Match(selector, cases) =>
+ case Match(selector, cases) =>
- val (selStats, selVal, spc) = transInlineValue(selector, cpsA)
- val (cpsA2, cpsR2) = (spc, linearize(spc, getAnswerTypeAnn(tree.tpe)))
-// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe))
+ val (selStats, selVal, spc) = transInlineValue(selector, cpsA)
+ val (cpsA2, cpsR2) = if (tree.tpe hasAnnotation MarkerCPSSynth)
+ (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else
+ (None, getAnswerTypeAnn(tree.tpe))
- val caseVals = for {
- cd @ CaseDef(pat, guard, body) <- cases
+ val caseVals = for {
+ cd @ CaseDef(pat, guard, body) <- cases
val bodyVal = transExpr(body, cpsA2, cpsR2)
- } yield {
- treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
- }
+ } yield {
+ treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
+ }
- (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc)
+ (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc)
case ldef @ LabelDef(name, params, rhs) =>
diff --git a/test/files/continuations-run/ifelse4.check b/test/files/continuations-run/ifelse4.check
new file mode 100644
index 0000000000..2545dd49a0
--- /dev/null
+++ b/test/files/continuations-run/ifelse4.check
@@ -0,0 +1,4 @@
+10
+10
+10
+10
diff --git a/test/files/continuations-run/ifelse4.scala b/test/files/continuations-run/ifelse4.scala
new file mode 100644
index 0000000000..8360375283
--- /dev/null
+++ b/test/files/continuations-run/ifelse4.scala
@@ -0,0 +1,31 @@
+import scala.util.continuations._
+
+object Test {
+ def sh(x1:Int) = shift( (k: Int => Int) => k(k(k(x1))))
+
+ def testA(x1: Int): Int @cps[Int] = {
+ sh(x1)
+ if (x1==42) x1 else sh(x1)
+ }
+
+ def testB(x1: Int): Int @cps[Int] = {
+ if (sh(x1)==43) x1 else x1
+ }
+
+ def testC(x1: Int): Int @cps[Int] = {
+ sh(x1)
+ if (sh(x1)==44) x1 else x1
+ }
+
+ def testD(x1: Int): Int @cps[Int] = {
+ sh(x1)
+ if (sh(x1)==45) x1 else sh(x1)
+ }
+
+ def main(args: Array[String]): Any = {
+ println(reset(1 + testA(7)))
+ println(reset(1 + testB(9)))
+ println(reset(1 + testC(9)))
+ println(reset(1 + testD(7)))
+ }
+} \ No newline at end of file
diff --git a/test/files/continuations-run/patvirt.check b/test/files/continuations-run/patvirt.check
new file mode 100644
index 0000000000..b5fa014ad3
--- /dev/null
+++ b/test/files/continuations-run/patvirt.check
@@ -0,0 +1,2 @@
+10
+11
diff --git a/test/files/continuations-run/patvirt.scala b/test/files/continuations-run/patvirt.scala
new file mode 100644
index 0000000000..5b4d312f20
--- /dev/null
+++ b/test/files/continuations-run/patvirt.scala
@@ -0,0 +1,32 @@
+import scala.util.continuations._
+
+object Test {
+ def sh(x1:Int) = shift( (k: Int => Int) => k(k(k(x1))))
+
+ def test(x1: Int) = {
+ val o7 = {
+ val o6 = {
+ val o3 =
+ if (7 == x1) Some(x1)
+ else None
+
+ if (o3.isEmpty) None
+ else Some(sh(x1))
+ }
+ if (o6.isEmpty) {
+ val o5 =
+ if (8 == x1) Some(x1)
+ else None
+
+ if (o5.isEmpty) None
+ else Some(sh(x1))
+ } else o6
+ }
+ o7.get
+ }
+
+ def main(args: Array[String]): Any = {
+ println(reset(1 + test(7)))
+ println(reset(1 + test(8)))
+ }
+}