From 796024c7429a03e974a7d8e1dc5c80b84f82467d Mon Sep 17 00:00:00 2001 From: phaller Date: Fri, 25 May 2012 17:36:19 +0200 Subject: CPS: enable return expressions in CPS code if they are in tail position Adds a stack of context trees to AnnotationChecker(s). Here, it is used to enforce that adaptAnnotations will only adapt the annotation of a return expression if the expected type is a CPS type. The remove-tail-return transform is reasonably general, covering cases such as try-catch-finally. Moreover, an error is thrown if, in a CPS method, a return is encountered which is not in a tail position such that it will be removed subsequently. --- .../tools/selectivecps/CPSAnnotationChecker.scala | 27 ++++++++++++++++ .../plugin/scala/tools/selectivecps/CPSUtils.scala | 36 ++++++++++++++++++++++ .../tools/selectivecps/SelectiveANFTransform.scala | 15 +++++++-- 3 files changed, 76 insertions(+), 2 deletions(-) (limited to 'src/continuations') diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index 862b19d0a4..574a76484c 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -17,6 +17,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { * Checks whether @cps annotations conform */ object checker extends AnnotationChecker { + private var contextStack: List[Tree] = List() + private def addPlusMarker(tp: Type) = tp withAnnotation newPlusMarker() private def addMinusMarker(tp: Type) = tp withAnnotation newMinusMarker() @@ -25,6 +27,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { private def cleanPlusWith(tp: Type)(newAnnots: AnnotationInfo*) = cleanPlus(tp) withAnnotations newAnnots.toList + override def pushAnnotationContext(tree: Tree): Unit = + contextStack = tree :: contextStack + + override def popAnnotationContext(): Unit = + contextStack = contextStack.tail + /** Check annotations to decide whether tpe1 <:< tpe2 */ def annotationsConform(tpe1: Type, tpe2: Type): Boolean = { if (!cpsEnabled) return true @@ -116,6 +124,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { bounds } + private def inReturnContext(tree: Tree): Boolean = !contextStack.isEmpty && (contextStack.head match { + case Return(tree1) => tree1 == tree + case _ => false + }) + override def canAdaptAnnotations(tree: Tree, mode: Int, pt: Type): Boolean = { if (!cpsEnabled) return false vprintln("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) @@ -170,6 +183,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { vprintln("yes we can!! (byval)") return true } + } else if (inReturnContext(tree)) { + vprintln("yes we can!! (return)") + return true } } false @@ -209,6 +225,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val res = tree modifyType addMinusMarker vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) res + } else if (inReturnContext(tree) && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having no annotation + val res = tree modifyType (_ withAnnotations List(newPlusMarker())) + vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) + res } else tree } @@ -464,6 +486,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { } tpe + case ret @ Return(expr) => + if (hasPlusMarker(expr.tpe)) + ret setType expr.tpe + ret.tpe + case _ => tpe } diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala index 3a1dc87a6a..1e4d9f21de 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -3,6 +3,7 @@ package scala.tools.selectivecps import scala.tools.nsc.Global +import scala.collection.mutable.ListBuffer trait CPSUtils { val global: Global @@ -135,4 +136,39 @@ trait CPSUtils { case _ => None } } + + def isTailReturn(retExpr: Tree, body: Tree): Boolean = { + val removedIds = ListBuffer[Int]() + removeTailReturn(body, removedIds) + removedIds contains retExpr.id + } + + def removeTailReturn(tree: Tree, ids: ListBuffer[Int]): Tree = tree match { + case Block(stms, r @ Return(expr)) => + ids += r.id + treeCopy.Block(tree, stms, expr) + + case Block(stms, expr) => + treeCopy.Block(tree, stms, removeTailReturn(expr, ids)) + + case If(cond, thenExpr, elseExpr) => + treeCopy.If(tree, cond, removeTailReturn(thenExpr, ids), removeTailReturn(elseExpr, ids)) + + case Try(block, catches, finalizer) => + treeCopy.Try(tree, + removeTailReturn(block, ids), + (catches map (t => removeTailReturn(t, ids))).asInstanceOf[List[CaseDef]], + removeTailReturn(finalizer, ids)) + + case CaseDef(pat, guard, r @ Return(expr)) => + ids += r.id + treeCopy.CaseDef(tree, pat, guard, expr) + + case CaseDef(pat, guard, body) => + treeCopy.CaseDef(tree, pat, guard, removeTailReturn(body, ids)) + + case _ => + tree + } + } diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 017c8d24fd..e02e02d975 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -9,6 +9,8 @@ import scala.tools.nsc.plugins._ import scala.tools.nsc.ast._ +import scala.collection.mutable.ListBuffer + /** * In methods marked @cps, explicitly name results of calls to other @cps methods */ @@ -46,10 +48,20 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // this would cause infinite recursion. But we could remove the // ValDef case here. - case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => debuglog("transforming " + dd.symbol) atOwner(dd.symbol) { + val tailReturns = ListBuffer[Int]() + val rhs = removeTailReturn(rhs0, tailReturns) + // throw an error if there is a Return tree which is not in tail position + rhs0 foreach { + case r @ Return(_) => + if (!tailReturns.contains(r.id)) + unit.error(r.pos, "return expressions in CPS code must be in tail position") + case _ => /* do nothing */ + } + val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) debuglog("result "+rhs1) @@ -153,7 +165,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } } - def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = { transTailValue(tree, cpsA, cpsR) match { case (Nil, b) => b -- cgit v1.2.3