diff options
author | phaller <hallerp@gmail.com> | 2012-05-25 17:36:19 +0200 |
---|---|---|
committer | phaller <philipp.haller@typesafe.com> | 2012-06-14 17:42:31 +0200 |
commit | 796024c7429a03e974a7d8e1dc5c80b84f82467d (patch) | |
tree | b4c900328a78ec7941530d1ecc27d544284febaf | |
parent | 4448e7a530626105776997fde04b4af76bf13de1 (diff) | |
download | scala-796024c7429a03e974a7d8e1dc5c80b84f82467d.tar.gz scala-796024c7429a03e974a7d8e1dc5c80b84f82467d.tar.bz2 scala-796024c7429a03e974a7d8e1dc5c80b84f82467d.zip |
CPS: enable return expressions in CPS code if they are in tail position
Adds a stack of context trees to AnnotationChecker(s). Here, it is used to
enforce that adaptAnnotations will only adapt the annotation of a return
expression if the expected type is a CPS type.
The remove-tail-return transform is reasonably general, covering cases such as
try-catch-finally. Moreover, an error is thrown if, in a CPS method, a return
is encountered which is not in a tail position such that it will be removed
subsequently.
11 files changed, 197 insertions, 2 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 2bdae4164a..7e4f50ecd7 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -3987,7 +3987,11 @@ trait Typers extends Modes with Adaptations with Tags { ReturnWithoutTypeError(tree, enclMethod.owner) } else { context.enclMethod.returnsSeen = true + //TODO: also pass enclMethod.tree, so that adaptAnnotations can check whether return is in tail position + pushAnnotationContext(tree) val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe) + popAnnotationContext() + // Warn about returning a value if no value can be returned. if (restpt.tpe.typeSymbol == UnitClass) { // The typing in expr1 says expr is Unit (it has already been coerced if diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index 862b19d0a4..574a76484c 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -17,6 +17,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { * Checks whether @cps annotations conform */ object checker extends AnnotationChecker { + private var contextStack: List[Tree] = List() + private def addPlusMarker(tp: Type) = tp withAnnotation newPlusMarker() private def addMinusMarker(tp: Type) = tp withAnnotation newMinusMarker() @@ -25,6 +27,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { private def cleanPlusWith(tp: Type)(newAnnots: AnnotationInfo*) = cleanPlus(tp) withAnnotations newAnnots.toList + override def pushAnnotationContext(tree: Tree): Unit = + contextStack = tree :: contextStack + + override def popAnnotationContext(): Unit = + contextStack = contextStack.tail + /** Check annotations to decide whether tpe1 <:< tpe2 */ def annotationsConform(tpe1: Type, tpe2: Type): Boolean = { if (!cpsEnabled) return true @@ -116,6 +124,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { bounds } + private def inReturnContext(tree: Tree): Boolean = !contextStack.isEmpty && (contextStack.head match { + case Return(tree1) => tree1 == tree + case _ => false + }) + override def canAdaptAnnotations(tree: Tree, mode: Int, pt: Type): Boolean = { if (!cpsEnabled) return false vprintln("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) @@ -170,6 +183,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { vprintln("yes we can!! (byval)") return true } + } else if (inReturnContext(tree)) { + vprintln("yes we can!! (return)") + return true } } false @@ -209,6 +225,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val res = tree modifyType addMinusMarker vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) res + } else if (inReturnContext(tree) && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having no annotation + val res = tree modifyType (_ withAnnotations List(newPlusMarker())) + vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) + res } else tree } @@ -464,6 +486,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { } tpe + case ret @ Return(expr) => + if (hasPlusMarker(expr.tpe)) + ret setType expr.tpe + ret.tpe + case _ => tpe } diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala index 3a1dc87a6a..1e4d9f21de 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -3,6 +3,7 @@ package scala.tools.selectivecps import scala.tools.nsc.Global +import scala.collection.mutable.ListBuffer trait CPSUtils { val global: Global @@ -135,4 +136,39 @@ trait CPSUtils { case _ => None } } + + def isTailReturn(retExpr: Tree, body: Tree): Boolean = { + val removedIds = ListBuffer[Int]() + removeTailReturn(body, removedIds) + removedIds contains retExpr.id + } + + def removeTailReturn(tree: Tree, ids: ListBuffer[Int]): Tree = tree match { + case Block(stms, r @ Return(expr)) => + ids += r.id + treeCopy.Block(tree, stms, expr) + + case Block(stms, expr) => + treeCopy.Block(tree, stms, removeTailReturn(expr, ids)) + + case If(cond, thenExpr, elseExpr) => + treeCopy.If(tree, cond, removeTailReturn(thenExpr, ids), removeTailReturn(elseExpr, ids)) + + case Try(block, catches, finalizer) => + treeCopy.Try(tree, + removeTailReturn(block, ids), + (catches map (t => removeTailReturn(t, ids))).asInstanceOf[List[CaseDef]], + removeTailReturn(finalizer, ids)) + + case CaseDef(pat, guard, r @ Return(expr)) => + ids += r.id + treeCopy.CaseDef(tree, pat, guard, expr) + + case CaseDef(pat, guard, body) => + treeCopy.CaseDef(tree, pat, guard, removeTailReturn(body, ids)) + + case _ => + tree + } + } diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 017c8d24fd..e02e02d975 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -9,6 +9,8 @@ import scala.tools.nsc.plugins._ import scala.tools.nsc.ast._ +import scala.collection.mutable.ListBuffer + /** * In methods marked @cps, explicitly name results of calls to other @cps methods */ @@ -46,10 +48,20 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // this would cause infinite recursion. But we could remove the // ValDef case here. - case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => debuglog("transforming " + dd.symbol) atOwner(dd.symbol) { + val tailReturns = ListBuffer[Int]() + val rhs = removeTailReturn(rhs0, tailReturns) + // throw an error if there is a Return tree which is not in tail position + rhs0 foreach { + case r @ Return(_) => + if (!tailReturns.contains(r.id)) + unit.error(r.pos, "return expressions in CPS code must be in tail position") + case _ => /* do nothing */ + } + val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) debuglog("result "+rhs1) @@ -153,7 +165,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } } - def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = { transTailValue(tree, cpsA, cpsR) match { case (Nil, b) => b diff --git a/src/reflect/scala/reflect/internal/AnnotationCheckers.scala b/src/reflect/scala/reflect/internal/AnnotationCheckers.scala index 449b0ca0bc..3848ab51b8 100644 --- a/src/reflect/scala/reflect/internal/AnnotationCheckers.scala +++ b/src/reflect/scala/reflect/internal/AnnotationCheckers.scala @@ -47,6 +47,10 @@ trait AnnotationCheckers { * before. If the implementing class cannot do the adaptiong, it * should return the tree unchanged.*/ def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = tree + + def pushAnnotationContext(tree: Tree): Unit = {} + + def popAnnotationContext(): Unit = {} } // Syncnote: Annotation checkers inaccessible to reflection, so no sync in var necessary. @@ -118,4 +122,14 @@ trait AnnotationCheckers { annotationCheckers.foldLeft(tree)((tree, checker) => checker.adaptAnnotations(tree, mode, pt)) } + + def pushAnnotationContext(tree: Tree): Unit = { + annotationCheckers.foreach(checker => + checker.pushAnnotationContext(tree)) + } + + def popAnnotationContext(): Unit = { + annotationCheckers.foreach(checker => + checker.popAnnotationContext()) + } } diff --git a/test/files/continuations-neg/ts-1681-nontail-return.check b/test/files/continuations-neg/ts-1681-nontail-return.check new file mode 100644 index 0000000000..8fe15f154b --- /dev/null +++ b/test/files/continuations-neg/ts-1681-nontail-return.check @@ -0,0 +1,4 @@ +ts-1681-nontail-return.scala:10: error: return expressions in CPS code must be in tail position + return v + ^ +one error found diff --git a/test/files/continuations-neg/ts-1681-nontail-return.scala b/test/files/continuations-neg/ts-1681-nontail-return.scala new file mode 100644 index 0000000000..af86ad304f --- /dev/null +++ b/test/files/continuations-neg/ts-1681-nontail-return.scala @@ -0,0 +1,18 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } + def caller = reset { println(p(3)) } + + def p(i: Int): Int @cpsParam[Unit, Any] = { + val v= s1 + 3 + if (v == 8) + return v + v + 1 + } +} + +object Test extends App { + val repro = new ReturnRepro + repro.caller +} diff --git a/test/files/continuations-run/ts-1681-2.check b/test/files/continuations-run/ts-1681-2.check new file mode 100644 index 0000000000..35b3c93780 --- /dev/null +++ b/test/files/continuations-run/ts-1681-2.check @@ -0,0 +1,5 @@ +8 +hi +8 +from try +8 diff --git a/test/files/continuations-run/ts-1681-2.scala b/test/files/continuations-run/ts-1681-2.scala new file mode 100644 index 0000000000..8a896dec2c --- /dev/null +++ b/test/files/continuations-run/ts-1681-2.scala @@ -0,0 +1,44 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cps[Any] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + def caller3 = reset { println(p3(3)) } + + def p(i: Int): Int @cps[Any] = { + val v= s1 + 3 + return v + } + + def p2(i: Int): Int @cps[Any] = { + val v = s1 + 3 + if (v > 0) { + println("hi") + return v + } else { + println("hi") + return 8 + } + } + + def p3(i: Int): Int @cps[Any] = { + val v = s1 + 3 + try { + println("from try") + return v + } catch { + case e: Exception => + println("from catch") + return 7 + } + } + +} + +object Test extends App { + val repro = new ReturnRepro + repro.caller + repro.caller2 + repro.caller3 +} diff --git a/test/files/continuations-run/ts-1681.check b/test/files/continuations-run/ts-1681.check new file mode 100644 index 0000000000..85176d8e66 --- /dev/null +++ b/test/files/continuations-run/ts-1681.check @@ -0,0 +1,3 @@ +8 +hi +8 diff --git a/test/files/continuations-run/ts-1681.scala b/test/files/continuations-run/ts-1681.scala new file mode 100644 index 0000000000..efb1abae15 --- /dev/null +++ b/test/files/continuations-run/ts-1681.scala @@ -0,0 +1,29 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + + def p(i: Int): Int @cpsParam[Unit, Any] = { + val v= s1 + 3 + return v + } + + def p2(i: Int): Int @cpsParam[Unit, Any] = { + val v = s1 + 3 + if (v > 0) { + println("hi") + return v + } else { + println("hi") + return 8 + } + } +} + +object Test extends App { + val repro = new ReturnRepro + repro.caller + repro.caller2 +} |