diff options
author | Paul Phillips <paulp@improving.org> | 2012-02-22 11:55:31 -0800 |
---|---|---|
committer | Paul Phillips <paulp@improving.org> | 2012-02-22 11:55:31 -0800 |
commit | 67ce08e882c7e8c7eeaf117d0cbbddd64455fc68 (patch) | |
tree | bce1d1113285e28374edc166f3b4a8b4d8b68186 | |
parent | 2eeb8d3d40ea243baf7ebeeba981c8cb903a7ab0 (diff) | |
parent | 9c486dcc733a9db4298a33b291dcb60680dbcdc2 (diff) | |
download | scala-67ce08e882c7e8c7eeaf117d0cbbddd64455fc68.tar.gz scala-67ce08e882c7e8c7eeaf117d0cbbddd64455fc68.tar.bz2 scala-67ce08e882c7e8c7eeaf117d0cbbddd64455fc68.zip |
Merge remote-tracking branch 'szeiger/backport/cps-if-then-else-91dbfb2' into 2.9.x
-rw-r--r-- | src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala | 60 | ||||
-rw-r--r-- | test/files/continuations-run/ifelse4.check | 4 | ||||
-rw-r--r-- | test/files/continuations-run/ifelse4.scala | 31 | ||||
-rw-r--r-- | test/files/continuations-run/patvirt.check | 2 | ||||
-rw-r--r-- | test/files/continuations-run/patvirt.scala | 32 |
5 files changed, 101 insertions, 28 deletions
diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 7927efc6d5..e2e31a5d37 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -173,40 +173,44 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with (Nil, tree1, cpsA) - case If(cond, thenp, elsep) => - - val (condStats, condVal, spc) = transInlineValue(cond, cpsA) - - val (cpsA2, cpsR2) = (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) -// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) - val thenVal = transExpr(thenp, cpsA2, cpsR2) - val elseVal = transExpr(elsep, cpsA2, cpsR2) - - // check that then and else parts agree (not necessary any more, but left as sanity check) - if (cpsR.isDefined) { - if (elsep == EmptyTree) - unit.error(tree.pos, "always need else part in cps code") - } - if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) { - unit.error(tree.pos, "then and else parts must both be cps code or neither of them") - } + case If(cond, thenp, elsep) => + /* possible situations: + cps before (cpsA) + cps in condition (spc) <-- synth flag set if *only* here! + cps in (one or both) branches */ + val (condStats, condVal, spc) = transInlineValue(cond, cpsA) + val (cpsA2, cpsR2) = if (tree.tpe hasAnnotation MarkerCPSSynth) + (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else + (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly + val thenVal = transExpr(thenp, cpsA2, cpsR2) + val elseVal = transExpr(elsep, cpsA2, cpsR2) + + // check that then and else parts agree (not necessary any more, but left as sanity check) + if (cpsR.isDefined) { + if (elsep == EmptyTree) + unit.error(tree.pos, "always need else part in cps code") + } + if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) { + unit.error(tree.pos, "then and else parts must both be cps code or neither of them") + } - (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc) + (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc) - case Match(selector, cases) => + case Match(selector, cases) => - val (selStats, selVal, spc) = transInlineValue(selector, cpsA) - val (cpsA2, cpsR2) = (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) -// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) + val (selStats, selVal, spc) = transInlineValue(selector, cpsA) + val (cpsA2, cpsR2) = if (tree.tpe hasAnnotation MarkerCPSSynth) + (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else + (None, getAnswerTypeAnn(tree.tpe)) - val caseVals = for { - cd @ CaseDef(pat, guard, body) <- cases + val caseVals = for { + cd @ CaseDef(pat, guard, body) <- cases val bodyVal = transExpr(body, cpsA2, cpsR2) - } yield { - treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) - } + } yield { + treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) + } - (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc) + (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc) case ldef @ LabelDef(name, params, rhs) => diff --git a/test/files/continuations-run/ifelse4.check b/test/files/continuations-run/ifelse4.check new file mode 100644 index 0000000000..2545dd49a0 --- /dev/null +++ b/test/files/continuations-run/ifelse4.check @@ -0,0 +1,4 @@ +10 +10 +10 +10 diff --git a/test/files/continuations-run/ifelse4.scala b/test/files/continuations-run/ifelse4.scala new file mode 100644 index 0000000000..8360375283 --- /dev/null +++ b/test/files/continuations-run/ifelse4.scala @@ -0,0 +1,31 @@ +import scala.util.continuations._ + +object Test { + def sh(x1:Int) = shift( (k: Int => Int) => k(k(k(x1)))) + + def testA(x1: Int): Int @cps[Int] = { + sh(x1) + if (x1==42) x1 else sh(x1) + } + + def testB(x1: Int): Int @cps[Int] = { + if (sh(x1)==43) x1 else x1 + } + + def testC(x1: Int): Int @cps[Int] = { + sh(x1) + if (sh(x1)==44) x1 else x1 + } + + def testD(x1: Int): Int @cps[Int] = { + sh(x1) + if (sh(x1)==45) x1 else sh(x1) + } + + def main(args: Array[String]): Any = { + println(reset(1 + testA(7))) + println(reset(1 + testB(9))) + println(reset(1 + testC(9))) + println(reset(1 + testD(7))) + } +}
\ No newline at end of file diff --git a/test/files/continuations-run/patvirt.check b/test/files/continuations-run/patvirt.check new file mode 100644 index 0000000000..b5fa014ad3 --- /dev/null +++ b/test/files/continuations-run/patvirt.check @@ -0,0 +1,2 @@ +10 +11 diff --git a/test/files/continuations-run/patvirt.scala b/test/files/continuations-run/patvirt.scala new file mode 100644 index 0000000000..5b4d312f20 --- /dev/null +++ b/test/files/continuations-run/patvirt.scala @@ -0,0 +1,32 @@ +import scala.util.continuations._ + +object Test { + def sh(x1:Int) = shift( (k: Int => Int) => k(k(k(x1)))) + + def test(x1: Int) = { + val o7 = { + val o6 = { + val o3 = + if (7 == x1) Some(x1) + else None + + if (o3.isEmpty) None + else Some(sh(x1)) + } + if (o6.isEmpty) { + val o5 = + if (8 == x1) Some(x1) + else None + + if (o5.isEmpty) None + else Some(sh(x1)) + } else o6 + } + o7.get + } + + def main(args: Array[String]): Any = { + println(reset(1 + test(7))) + println(reset(1 + test(8))) + } +} |