summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Suereth <Joshua.Suereth@gmail.com>2012-06-18 09:45:37 -0700
committerJosh Suereth <Joshua.Suereth@gmail.com>2012-06-18 09:45:37 -0700
commitbf4b982389dcd573e81155bfe233f70f1471b017 (patch)
tree97239b357d7eb6549ca5bb2867e21956c2c43e40
parent33901013cdd8aef50c28006a5cf52aa3137f747f (diff)
parent0ada0706746c9c603bf5bc8a0e6780e5783297cf (diff)
downloadscala-bf4b982389dcd573e81155bfe233f70f1471b017.tar.gz
scala-bf4b982389dcd573e81155bfe233f70f1471b017.tar.bz2
scala-bf4b982389dcd573e81155bfe233f70f1471b017.zip
Merge pull request #720 from phaller/cps-ticket-1681
CPS: enable return expressions in CPS code if they are in tail position
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Modes.scala4
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala3
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala15
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala40
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala15
-rw-r--r--test/files/continuations-neg/ts-1681-nontail-return.check4
-rw-r--r--test/files/continuations-neg/ts-1681-nontail-return.scala18
-rw-r--r--test/files/continuations-run/ts-1681-2.check5
-rw-r--r--test/files/continuations-run/ts-1681-2.scala44
-rw-r--r--test/files/continuations-run/ts-1681-3.check4
-rw-r--r--test/files/continuations-run/ts-1681-3.scala27
-rw-r--r--test/files/continuations-run/ts-1681.check3
-rw-r--r--test/files/continuations-run/ts-1681.scala29
13 files changed, 208 insertions, 3 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/Modes.scala b/src/compiler/scala/tools/nsc/typechecker/Modes.scala
index 3eff5ef024..bde3ad98c9 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Modes.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Modes.scala
@@ -86,6 +86,10 @@ trait Modes {
*/
final val TYPEPATmode = 0x10000
+ /** RETmode is set when we are typing a return expression.
+ */
+ final val RETmode = 0x20000
+
final private val StickyModes = EXPRmode | PATTERNmode | TYPEmode | ALTmode
final def onlyStickyModes(mode: Int) =
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index 2bdae4164a..1193d3013a 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -3987,7 +3987,8 @@ trait Typers extends Modes with Adaptations with Tags {
ReturnWithoutTypeError(tree, enclMethod.owner)
} else {
context.enclMethod.returnsSeen = true
- val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe)
+ val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe)
+
// Warn about returning a value if no value can be returned.
if (restpt.tpe.typeSymbol == UnitClass) {
// The typing in expr1 says expr is Unit (it has already been coerced if
diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
index 862b19d0a4..5b8e9baa21 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
@@ -170,6 +170,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
vprintln("yes we can!! (byval)")
return true
}
+ } else if ((mode & global.analyzer.RETmode) != 0) {
+ vprintln("yes we can!! (return)")
+ return true
}
}
false
@@ -183,6 +186,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
val patMode = (mode & global.analyzer.PATTERNmode) != 0
val exprMode = (mode & global.analyzer.EXPRmode) != 0
val byValMode = (mode & global.analyzer.BYVALmode) != 0
+ val retMode = (mode & global.analyzer.RETmode) != 0
val annotsTree = cpsParamAnnotation(tree.tpe)
val annotsExpected = cpsParamAnnotation(pt)
@@ -209,6 +213,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
val res = tree modifyType addMinusMarker
vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe)
res
+ } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) {
+ // add a marker annotation that will make tree.tpe behave as pt, subtyping wise
+ // tree will look like having no annotation
+ val res = tree modifyType (_ withAnnotations List(newPlusMarker()))
+ vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
+ res
} else tree
}
@@ -464,6 +474,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
}
tpe
+ case ret @ Return(expr) =>
+ if (hasPlusMarker(expr.tpe))
+ ret setType expr.tpe
+ ret.tpe
+
case _ =>
tpe
}
diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
index 3a1dc87a6a..765cde5a81 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
@@ -3,6 +3,7 @@
package scala.tools.selectivecps
import scala.tools.nsc.Global
+import scala.collection.mutable.ListBuffer
trait CPSUtils {
val global: Global
@@ -135,4 +136,43 @@ trait CPSUtils {
case _ => None
}
}
+
+ def isTailReturn(retExpr: Tree, body: Tree): Boolean = {
+ val removed = ListBuffer[Tree]()
+ removeTailReturn(body, removed)
+ removed contains retExpr
+ }
+
+ def removeTailReturn(tree: Tree, removed: ListBuffer[Tree]): Tree = tree match {
+ case Block(stms, r @ Return(expr)) =>
+ removed += r
+ treeCopy.Block(tree, stms, expr)
+
+ case Block(stms, expr) =>
+ treeCopy.Block(tree, stms, removeTailReturn(expr, removed))
+
+ case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) =>
+ removed ++= Seq(r1, r2)
+ treeCopy.If(tree, cond, removeTailReturn(thenExpr, removed), removeTailReturn(elseExpr, removed))
+
+ case If(cond, thenExpr, elseExpr) =>
+ treeCopy.If(tree, cond, removeTailReturn(thenExpr, removed), removeTailReturn(elseExpr, removed))
+
+ case Try(block, catches, finalizer) =>
+ treeCopy.Try(tree,
+ removeTailReturn(block, removed),
+ (catches map (t => removeTailReturn(t, removed))).asInstanceOf[List[CaseDef]],
+ removeTailReturn(finalizer, removed))
+
+ case CaseDef(pat, guard, r @ Return(expr)) =>
+ removed += r
+ treeCopy.CaseDef(tree, pat, guard, expr)
+
+ case CaseDef(pat, guard, body) =>
+ treeCopy.CaseDef(tree, pat, guard, removeTailReturn(body, removed))
+
+ case _ =>
+ tree
+ }
+
}
diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
index 017c8d24fd..fe465aad0d 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
@@ -9,6 +9,8 @@ import scala.tools.nsc.plugins._
import scala.tools.nsc.ast._
+import scala.collection.mutable.ListBuffer
+
/**
* In methods marked @cps, explicitly name results of calls to other @cps methods
*/
@@ -46,10 +48,20 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// this would cause infinite recursion. But we could remove the
// ValDef case here.
- case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) =>
debuglog("transforming " + dd.symbol)
atOwner(dd.symbol) {
+ val tailReturns = ListBuffer[Tree]()
+ val rhs = removeTailReturn(rhs0, tailReturns)
+ // throw an error if there is a Return tree which is not in tail position
+ rhs0 foreach {
+ case r @ Return(_) =>
+ if (!tailReturns.contains(r))
+ unit.error(r.pos, "return expressions in CPS code must be in tail position")
+ case _ => /* do nothing */
+ }
+
val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))
debuglog("result "+rhs1)
@@ -153,7 +165,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
}
-
def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = {
transTailValue(tree, cpsA, cpsR) match {
case (Nil, b) => b
diff --git a/test/files/continuations-neg/ts-1681-nontail-return.check b/test/files/continuations-neg/ts-1681-nontail-return.check
new file mode 100644
index 0000000000..8fe15f154b
--- /dev/null
+++ b/test/files/continuations-neg/ts-1681-nontail-return.check
@@ -0,0 +1,4 @@
+ts-1681-nontail-return.scala:10: error: return expressions in CPS code must be in tail position
+ return v
+ ^
+one error found
diff --git a/test/files/continuations-neg/ts-1681-nontail-return.scala b/test/files/continuations-neg/ts-1681-nontail-return.scala
new file mode 100644
index 0000000000..af86ad304f
--- /dev/null
+++ b/test/files/continuations-neg/ts-1681-nontail-return.scala
@@ -0,0 +1,18 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+
+ def p(i: Int): Int @cpsParam[Unit, Any] = {
+ val v= s1 + 3
+ if (v == 8)
+ return v
+ v + 1
+ }
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+}
diff --git a/test/files/continuations-run/ts-1681-2.check b/test/files/continuations-run/ts-1681-2.check
new file mode 100644
index 0000000000..35b3c93780
--- /dev/null
+++ b/test/files/continuations-run/ts-1681-2.check
@@ -0,0 +1,5 @@
+8
+hi
+8
+from try
+8
diff --git a/test/files/continuations-run/ts-1681-2.scala b/test/files/continuations-run/ts-1681-2.scala
new file mode 100644
index 0000000000..8a896dec2c
--- /dev/null
+++ b/test/files/continuations-run/ts-1681-2.scala
@@ -0,0 +1,44 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cps[Any] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+ def caller3 = reset { println(p3(3)) }
+
+ def p(i: Int): Int @cps[Any] = {
+ val v= s1 + 3
+ return v
+ }
+
+ def p2(i: Int): Int @cps[Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ println("hi")
+ return v
+ } else {
+ println("hi")
+ return 8
+ }
+ }
+
+ def p3(i: Int): Int @cps[Any] = {
+ val v = s1 + 3
+ try {
+ println("from try")
+ return v
+ } catch {
+ case e: Exception =>
+ println("from catch")
+ return 7
+ }
+ }
+
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+ repro.caller3
+}
diff --git a/test/files/continuations-run/ts-1681-3.check b/test/files/continuations-run/ts-1681-3.check
new file mode 100644
index 0000000000..71489f097c
--- /dev/null
+++ b/test/files/continuations-run/ts-1681-3.check
@@ -0,0 +1,4 @@
+enter return expr
+8
+hi
+8
diff --git a/test/files/continuations-run/ts-1681-3.scala b/test/files/continuations-run/ts-1681-3.scala
new file mode 100644
index 0000000000..62c547f5a2
--- /dev/null
+++ b/test/files/continuations-run/ts-1681-3.scala
@@ -0,0 +1,27 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+
+ def p(i: Int): Int @cpsParam[Unit, Any] = {
+ val v= s1 + 3
+ return { println("enter return expr"); v }
+ }
+
+ def p2(i: Int): Int @cpsParam[Unit, Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ return { println("hi"); v }
+ } else {
+ return { println("hi"); 8 }
+ }
+ }
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+}
diff --git a/test/files/continuations-run/ts-1681.check b/test/files/continuations-run/ts-1681.check
new file mode 100644
index 0000000000..85176d8e66
--- /dev/null
+++ b/test/files/continuations-run/ts-1681.check
@@ -0,0 +1,3 @@
+8
+hi
+8
diff --git a/test/files/continuations-run/ts-1681.scala b/test/files/continuations-run/ts-1681.scala
new file mode 100644
index 0000000000..efb1abae15
--- /dev/null
+++ b/test/files/continuations-run/ts-1681.scala
@@ -0,0 +1,29 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+
+ def p(i: Int): Int @cpsParam[Unit, Any] = {
+ val v= s1 + 3
+ return v
+ }
+
+ def p2(i: Int): Int @cpsParam[Unit, Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ println("hi")
+ return v
+ } else {
+ println("hi")
+ return 8
+ }
+ }
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+}