summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-07-23 14:03:22 +0200
committerphaller <hallerp@gmail.com>2012-11-02 18:26:54 +0100
commit2c00346a9756190aef497cabbb6f77ecc25212c8 (patch)
tree96b369ab54d8a592734096961b084202f7056d09
parent8aeae6231600a9bae14299bc92c8a6109c67c5af (diff)
downloadscala-2c00346a9756190aef497cabbb6f77ecc25212c8.tar.gz
scala-2c00346a9756190aef497cabbb6f77ecc25212c8.tar.bz2
scala-2c00346a9756190aef497cabbb6f77ecc25212c8.zip
SI-5314 - CPS transform of return statement fails
Enable return expressions in CPS code if they are in tail position. Note that tail returns are only removed in methods that do not call `shift` or `reset` (otherwise, an error is reported). Addresses the issues pointed out in a previous pull request: https://github.com/scala/scala/pull/720 - Addresses all issues mentioned here: https://github.com/scala/scala/pull/720#issuecomment-6429705 - Move transformation methods to SelectiveANFTransform.scala: https://github.com/scala/scala/pull/720#commitcomment-1477497 - Do not keep a list of tail returns. Tests: - continuations-neg/t5314-missing-result-type.scala - continuations-neg/t5314-type-error.scala - continuations-neg/t5314-npe.scala - continuations-neg/t5314-return-reset.scala - continuations-run/t5314.scala - continuations-run/t5314-2.scala - continuations-run/t5314-3.scala
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Modes.scala6
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala3
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala45
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala7
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala89
-rw-r--r--test/files/continuations-neg/t5314-missing-result-type.check4
-rw-r--r--test/files/continuations-neg/t5314-missing-result-type.scala13
-rw-r--r--test/files/continuations-neg/t5314-npe.check4
-rw-r--r--test/files/continuations-neg/t5314-npe.scala3
-rw-r--r--test/files/continuations-neg/t5314-return-reset.check4
-rw-r--r--test/files/continuations-neg/t5314-return-reset.scala21
-rw-r--r--test/files/continuations-neg/t5314-type-error.check6
-rw-r--r--test/files/continuations-neg/t5314-type-error.scala17
-rw-r--r--test/files/continuations-run/t5314-2.check5
-rw-r--r--test/files/continuations-run/t5314-2.scala44
-rw-r--r--test/files/continuations-run/t5314-3.check4
-rw-r--r--test/files/continuations-run/t5314-3.scala27
-rw-r--r--test/files/continuations-run/t5314.check4
-rw-r--r--test/files/continuations-run/t5314.scala41
19 files changed, 332 insertions, 15 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/Modes.scala b/src/compiler/scala/tools/nsc/typechecker/Modes.scala
index ad4be4662c..b8308ebabf 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Modes.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Modes.scala
@@ -86,6 +86,10 @@ trait Modes {
*/
final val TYPEPATmode = 0x10000
+ /** RETmode is set when we are typing a return expression.
+ */
+ final val RETmode = 0x20000
+
final private val StickyModes = EXPRmode | PATTERNmode | TYPEmode | ALTmode
final def onlyStickyModes(mode: Int) =
@@ -130,4 +134,4 @@ trait Modes {
def modeString(mode: Int): String =
if (mode == 0) "NOmode"
else (modeNameMap filterKeys (bit => inAllModes(mode, bit))).values mkString " "
-} \ No newline at end of file
+}
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index 1a77ee6a9b..2921bb1434 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -3181,7 +3181,8 @@ trait Typers extends Modes {
errorTree(tree, enclMethod.owner + " has return statement; needs result type")
else {
context.enclMethod.returnsSeen = true
- val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe)
+ val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe)
+
// Warn about returning a value if no value can be returned.
if (restpt.tpe.typeSymbol == UnitClass) {
// The typing in expr1 says expr is Unit (it has already been coerced if
diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
index c4599bebbc..4a6f1df521 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
@@ -154,10 +154,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
if ((mode & global.analyzer.EXPRmode) != 0) {
if ((annots1 corresponds annots2)(_.atp <:< _.atp)) {
vprintln("already same, can't adapt further")
- return false
- }
-
- if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) {
+ false
+ } else if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) {
//println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt)
if (!hasPlusMarker(tree.tpe)) {
// val base = tree.tpe <:< removeAllCPSAnnotations(pt)
@@ -167,17 +165,26 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
// TBD: use same or not?
//if (same) {
vprintln("yes we can!! (unit)")
- return true
+ true
//}
- }
- } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) {
- if (!hasMinusMarker(tree.tpe)) {
+ } else false
+ } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.RETmode) != 0)) {
+ vprintln("checking enclosing method's result type without annotations")
+ tree.tpe <:< pt.withoutAnnotations
+ } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) {
+ val optCpsTypes: Option[(Type, Type)] = cpsParamTypes(tree.tpe)
+ val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt)
+ if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) {
vprintln("yes we can!! (byval)")
- return true
+ true
+ } else { // check cps param types
+ val cpsTpes = optCpsTypes.get
+ val cpsPts = optExpectedCpsTypes.get
+ // class cpsParam[-B,+C], therefore:
+ cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2
}
- }
- }
- false
+ } else false
+ } else false
}
override def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = {
@@ -188,6 +195,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
val patMode = (mode & global.analyzer.PATTERNmode) != 0
val exprMode = (mode & global.analyzer.EXPRmode) != 0
val byValMode = (mode & global.analyzer.BYVALmode) != 0
+ val retMode = (mode & global.analyzer.RETmode) != 0
val annotsTree = cpsParamAnnotation(tree.tpe)
val annotsExpected = cpsParamAnnotation(pt)
@@ -214,6 +222,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
val res = tree modifyType addMinusMarker
vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe)
res
+ } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) {
+ // add a marker annotation that will make tree.tpe behave as pt, subtyping wise
+ // tree will look like having no annotation
+ val res = tree modifyType (_ withAnnotations List(newPlusMarker()))
+ vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
+ res
} else tree
}
@@ -469,6 +483,13 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
}
tpe
+ case ret @ Return(expr) =>
+ // only change type if this return will (a) be removed (in tail position) or (b) cause
+ // an error (not in tail position)
+ if (expr.tpe != null && hasPlusMarker(expr.tpe))
+ ret setType expr.tpe
+ ret.tpe
+
case _ =>
tpe
}
diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
index 545172d407..de06f0c222 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
@@ -29,6 +29,8 @@ trait CPSUtils {
val shiftUnit0 = newTermName("shiftUnit0")
val shiftUnit = newTermName("shiftUnit")
val shiftUnitR = newTermName("shiftUnitR")
+ val reset = newTermName("reset")
+ val reset0 = newTermName("reset0")
}
lazy val MarkerCPSSym = definitions.getClass("scala.util.continuations.cpsSym")
@@ -47,10 +49,15 @@ trait CPSUtils {
lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR)
lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify)
lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR)
+ lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset)
+ lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0)
lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth,
MarkerCPSAdaptPlus, MarkerCPSAdaptMinus)
+ lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR,
+ MethReify, MethReifyR, MethReset, MethReset0)
+
def debuglog(s: => String): Unit = {
//Console.err.println(s)
}
diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
index 1f71eb285b..b1d7f3d48c 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
@@ -32,6 +32,84 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
implicit val _unit = unit // allow code in CPSUtils.scala to report errors
var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet)
+ /* Does not attempt to remove tail returns.
+ * Only checks whether the method contains returns as well as calls to CPS methods.
+ */
+ class CheckCPSMethodsTraverser extends Traverser {
+ var cpsMethodsSeen = false
+ var returnsSeen: Option[Tree] = None
+ override def traverse(tree: Tree): Unit = tree match {
+ case Ident(_) | Select(_, _) =>
+ if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
+ cpsMethodsSeen = true
+ super.traverse(tree)
+ case Return(_) =>
+ returnsSeen = Some(tree)
+ case _ =>
+ super.traverse(tree)
+ }
+ }
+
+ /* Also checks whether the method calls a CPS method in which case an error is produced
+ */
+ class RemoveTailReturnsTransformer extends Transformer {
+ var cpsMethodsSeen = false
+ var returnsSeen: Option[Tree] = None
+ override def transform(tree: Tree): Tree = tree match {
+ case Block(stms, r @ Return(expr)) =>
+ returnsSeen = Some(r)
+ treeCopy.Block(tree, stms, expr)
+
+ case Block(stms, expr) =>
+ treeCopy.Block(tree, stms, transform(expr))
+
+ case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) =>
+ returnsSeen = Some(r1)
+ treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr))
+
+ case If(cond, thenExpr, elseExpr) =>
+ treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr))
+
+ case Try(block, catches, finalizer) =>
+ treeCopy.Try(tree,
+ transform(block),
+ (catches map (t => transform(t))).asInstanceOf[List[CaseDef]],
+ transform(finalizer))
+
+ case CaseDef(pat, guard, r @ Return(expr)) =>
+ returnsSeen = Some(r)
+ treeCopy.CaseDef(tree, pat, guard, expr)
+
+ case CaseDef(pat, guard, body) =>
+ treeCopy.CaseDef(tree, pat, guard, transform(body))
+
+ case Return(_) =>
+ unit.error(tree.pos, "return expressions in CPS code must be in tail position")
+ tree
+
+ case Ident(_) | Select(_, _) =>
+ if (tree.hasSymbol && (allCPSMethods contains tree.symbol))
+ cpsMethodsSeen = true
+ super.transform(tree)
+
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+ def removeTailReturns(body: Tree): Tree = {
+ // support body with single return expression
+ body match {
+ case Return(expr) => expr
+ case _ =>
+ val tr = new RemoveTailReturnsTransformer
+ val res = tr.transform(body)
+ if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen)
+ unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
+ res
+ }
+ }
+
override def transform(tree: Tree): Tree = {
if (!cpsEnabled) return tree
@@ -46,10 +124,19 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// this would cause infinite recursion. But we could remove the
// ValDef case here.
- case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) =>
debuglog("transforming " + dd.symbol)
atOwner(dd.symbol) {
+ val rhs =
+ if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0)
+ else {
+ val checker = new CheckCPSMethodsTraverser
+ checker.traverse(rhs0)
+ if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen)
+ unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method")
+ rhs0
+ }
val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))
debuglog("result "+rhs1)
diff --git a/test/files/continuations-neg/t5314-missing-result-type.check b/test/files/continuations-neg/t5314-missing-result-type.check
new file mode 100644
index 0000000000..341e580cf3
--- /dev/null
+++ b/test/files/continuations-neg/t5314-missing-result-type.check
@@ -0,0 +1,4 @@
+t5314-missing-result-type.scala:6: error: method bar has return statement; needs result type
+ def bar(x:Int) = return foo(x)
+ ^
+one error found
diff --git a/test/files/continuations-neg/t5314-missing-result-type.scala b/test/files/continuations-neg/t5314-missing-result-type.scala
new file mode 100644
index 0000000000..d7c5043a86
--- /dev/null
+++ b/test/files/continuations-neg/t5314-missing-result-type.scala
@@ -0,0 +1,13 @@
+import scala.util.continuations._
+
+object Test extends App {
+ def foo(x:Int): Int @cps[Int] = x
+
+ def bar(x:Int) = return foo(x)
+
+ reset {
+ val res = bar(8)
+ println(res)
+ res
+ }
+}
diff --git a/test/files/continuations-neg/t5314-npe.check b/test/files/continuations-neg/t5314-npe.check
new file mode 100644
index 0000000000..b5f024aa89
--- /dev/null
+++ b/test/files/continuations-neg/t5314-npe.check
@@ -0,0 +1,4 @@
+t5314-npe.scala:2: error: method bar has return statement; needs result type
+ def bar(x:Int) = { return x; x } // NPE
+ ^
+one error found
diff --git a/test/files/continuations-neg/t5314-npe.scala b/test/files/continuations-neg/t5314-npe.scala
new file mode 100644
index 0000000000..2b5966e07c
--- /dev/null
+++ b/test/files/continuations-neg/t5314-npe.scala
@@ -0,0 +1,3 @@
+object Test extends App {
+ def bar(x:Int) = { return x; x } // NPE
+}
diff --git a/test/files/continuations-neg/t5314-return-reset.check b/test/files/continuations-neg/t5314-return-reset.check
new file mode 100644
index 0000000000..e288556546
--- /dev/null
+++ b/test/files/continuations-neg/t5314-return-reset.check
@@ -0,0 +1,4 @@
+t5314-return-reset.scala:14: error: return expressions not allowed, since method calls CPS method
+ if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset`
+ ^
+one error found
diff --git a/test/files/continuations-neg/t5314-return-reset.scala b/test/files/continuations-neg/t5314-return-reset.scala
new file mode 100644
index 0000000000..df9d58e4cb
--- /dev/null
+++ b/test/files/continuations-neg/t5314-return-reset.scala
@@ -0,0 +1,21 @@
+import scala.util.continuations._
+import scala.util.Random
+
+object Test extends App {
+ val rnd = new Random
+
+ def foo(x: Int): Int @cps[Int] = shift { k => k(x) }
+
+ def bar(x: Int): Int @cps[Int] = return foo(x)
+
+ def caller(): Int = {
+ val v: Int = reset {
+ val res: Int = bar(8)
+ if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset`
+ 42
+ }
+ v
+ }
+
+ caller()
+}
diff --git a/test/files/continuations-neg/t5314-type-error.check b/test/files/continuations-neg/t5314-type-error.check
new file mode 100644
index 0000000000..e540f368cd
--- /dev/null
+++ b/test/files/continuations-neg/t5314-type-error.check
@@ -0,0 +1,6 @@
+t5314-type-error.scala:7: error: type mismatch;
+ found : Int @util.continuations.package.cps[Int]
+ required: Int @util.continuations.package.cps[String]
+ def bar(x:Int): Int @cps[String] = return foo(x)
+ ^
+one error found
diff --git a/test/files/continuations-neg/t5314-type-error.scala b/test/files/continuations-neg/t5314-type-error.scala
new file mode 100644
index 0000000000..e36ce6c203
--- /dev/null
+++ b/test/files/continuations-neg/t5314-type-error.scala
@@ -0,0 +1,17 @@
+import scala.util.continuations._
+
+object Test extends App {
+ def foo(x:Int): Int @cps[Int] = shift { k => k(x) }
+
+ // should be a type error
+ def bar(x:Int): Int @cps[String] = return foo(x)
+
+ def caller(): Unit = {
+ val v: String = reset {
+ val res: Int = bar(8)
+ "hello"
+ }
+ }
+
+ caller()
+}
diff --git a/test/files/continuations-run/t5314-2.check b/test/files/continuations-run/t5314-2.check
new file mode 100644
index 0000000000..35b3c93780
--- /dev/null
+++ b/test/files/continuations-run/t5314-2.check
@@ -0,0 +1,5 @@
+8
+hi
+8
+from try
+8
diff --git a/test/files/continuations-run/t5314-2.scala b/test/files/continuations-run/t5314-2.scala
new file mode 100644
index 0000000000..8a896dec2c
--- /dev/null
+++ b/test/files/continuations-run/t5314-2.scala
@@ -0,0 +1,44 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cps[Any] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+ def caller3 = reset { println(p3(3)) }
+
+ def p(i: Int): Int @cps[Any] = {
+ val v= s1 + 3
+ return v
+ }
+
+ def p2(i: Int): Int @cps[Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ println("hi")
+ return v
+ } else {
+ println("hi")
+ return 8
+ }
+ }
+
+ def p3(i: Int): Int @cps[Any] = {
+ val v = s1 + 3
+ try {
+ println("from try")
+ return v
+ } catch {
+ case e: Exception =>
+ println("from catch")
+ return 7
+ }
+ }
+
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+ repro.caller3
+}
diff --git a/test/files/continuations-run/t5314-3.check b/test/files/continuations-run/t5314-3.check
new file mode 100644
index 0000000000..71489f097c
--- /dev/null
+++ b/test/files/continuations-run/t5314-3.check
@@ -0,0 +1,4 @@
+enter return expr
+8
+hi
+8
diff --git a/test/files/continuations-run/t5314-3.scala b/test/files/continuations-run/t5314-3.scala
new file mode 100644
index 0000000000..62c547f5a2
--- /dev/null
+++ b/test/files/continuations-run/t5314-3.scala
@@ -0,0 +1,27 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+
+ def p(i: Int): Int @cpsParam[Unit, Any] = {
+ val v= s1 + 3
+ return { println("enter return expr"); v }
+ }
+
+ def p2(i: Int): Int @cpsParam[Unit, Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ return { println("hi"); v }
+ } else {
+ return { println("hi"); 8 }
+ }
+ }
+}
+
+object Test extends App {
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+}
diff --git a/test/files/continuations-run/t5314.check b/test/files/continuations-run/t5314.check
new file mode 100644
index 0000000000..4951e7caae
--- /dev/null
+++ b/test/files/continuations-run/t5314.check
@@ -0,0 +1,4 @@
+8
+hi
+8
+8
diff --git a/test/files/continuations-run/t5314.scala b/test/files/continuations-run/t5314.scala
new file mode 100644
index 0000000000..0bdea19824
--- /dev/null
+++ b/test/files/continuations-run/t5314.scala
@@ -0,0 +1,41 @@
+import scala.util.continuations._
+
+class ReturnRepro {
+ def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
+ def caller = reset { println(p(3)) }
+ def caller2 = reset { println(p2(3)) }
+
+ def p(i: Int): Int @cpsParam[Unit, Any] = {
+ val v= s1 + 3
+ return v
+ }
+
+ def p2(i: Int): Int @cpsParam[Unit, Any] = {
+ val v = s1 + 3
+ if (v > 0) {
+ println("hi")
+ return v
+ } else {
+ println("hi")
+ return 8
+ }
+ }
+}
+
+object Test extends App {
+ def foo(x:Int): Int @cps[Int] = shift { k => k(x) }
+
+ def bar(x:Int): Int @cps[Int] = return foo(x)
+
+ def nocps(x: Int): Int = { return x; x }
+
+ val repro = new ReturnRepro
+ repro.caller
+ repro.caller2
+
+ reset {
+ val res = bar(8)
+ println(res)
+ res
+ }
+}