summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-05-25 17:36:19 +0200
committerphaller <philipp.haller@typesafe.com>2012-06-14 17:42:31 +0200
commit796024c7429a03e974a7d8e1dc5c80b84f82467d (patch)
treeb4c900328a78ec7941530d1ecc27d544284febaf /src
parent4448e7a530626105776997fde04b4af76bf13de1 (diff)
downloadscala-796024c7429a03e974a7d8e1dc5c80b84f82467d.tar.gz
scala-796024c7429a03e974a7d8e1dc5c80b84f82467d.tar.bz2
scala-796024c7429a03e974a7d8e1dc5c80b84f82467d.zip
CPS: enable return expressions in CPS code if they are in tail position
Adds a stack of context trees to AnnotationChecker(s). Here, it is used to enforce that adaptAnnotations will only adapt the annotation of a return expression if the expected type is a CPS type. The remove-tail-return transform is reasonably general, covering cases such as try-catch-finally. Moreover, an error is thrown if, in a CPS method, a return is encountered which is not in a tail position such that it will be removed subsequently.
Diffstat (limited to 'src')
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/Typers.scala4
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala27
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala36
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala15
-rw-r--r--src/reflect/scala/reflect/internal/AnnotationCheckers.scala14
5 files changed, 94 insertions, 2 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
index 2bdae4164a..7e4f50ecd7 100644
--- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala
@@ -3987,7 +3987,11 @@ trait Typers extends Modes with Adaptations with Tags {
ReturnWithoutTypeError(tree, enclMethod.owner)
} else {
context.enclMethod.returnsSeen = true
+ //TODO: also pass enclMethod.tree, so that adaptAnnotations can check whether return is in tail position
+ pushAnnotationContext(tree)
val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe)
+ popAnnotationContext()
+
// Warn about returning a value if no value can be returned.
if (restpt.tpe.typeSymbol == UnitClass) {
// The typing in expr1 says expr is Unit (it has already been coerced if
diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
index 862b19d0a4..574a76484c 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala
@@ -17,6 +17,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
* Checks whether @cps annotations conform
*/
object checker extends AnnotationChecker {
+ private var contextStack: List[Tree] = List()
+
private def addPlusMarker(tp: Type) = tp withAnnotation newPlusMarker()
private def addMinusMarker(tp: Type) = tp withAnnotation newMinusMarker()
@@ -25,6 +27,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
private def cleanPlusWith(tp: Type)(newAnnots: AnnotationInfo*) =
cleanPlus(tp) withAnnotations newAnnots.toList
+ override def pushAnnotationContext(tree: Tree): Unit =
+ contextStack = tree :: contextStack
+
+ override def popAnnotationContext(): Unit =
+ contextStack = contextStack.tail
+
/** Check annotations to decide whether tpe1 <:< tpe2 */
def annotationsConform(tpe1: Type, tpe2: Type): Boolean = {
if (!cpsEnabled) return true
@@ -116,6 +124,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
bounds
}
+ private def inReturnContext(tree: Tree): Boolean = !contextStack.isEmpty && (contextStack.head match {
+ case Return(tree1) => tree1 == tree
+ case _ => false
+ })
+
override def canAdaptAnnotations(tree: Tree, mode: Int, pt: Type): Boolean = {
if (!cpsEnabled) return false
vprintln("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt)
@@ -170,6 +183,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
vprintln("yes we can!! (byval)")
return true
}
+ } else if (inReturnContext(tree)) {
+ vprintln("yes we can!! (return)")
+ return true
}
}
false
@@ -209,6 +225,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
val res = tree modifyType addMinusMarker
vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe)
res
+ } else if (inReturnContext(tree) && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) {
+ // add a marker annotation that will make tree.tpe behave as pt, subtyping wise
+ // tree will look like having no annotation
+ val res = tree modifyType (_ withAnnotations List(newPlusMarker()))
+ vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
+ res
} else tree
}
@@ -464,6 +486,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes {
}
tpe
+ case ret @ Return(expr) =>
+ if (hasPlusMarker(expr.tpe))
+ ret setType expr.tpe
+ ret.tpe
+
case _ =>
tpe
}
diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
index 3a1dc87a6a..1e4d9f21de 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala
@@ -3,6 +3,7 @@
package scala.tools.selectivecps
import scala.tools.nsc.Global
+import scala.collection.mutable.ListBuffer
trait CPSUtils {
val global: Global
@@ -135,4 +136,39 @@ trait CPSUtils {
case _ => None
}
}
+
+ def isTailReturn(retExpr: Tree, body: Tree): Boolean = {
+ val removedIds = ListBuffer[Int]()
+ removeTailReturn(body, removedIds)
+ removedIds contains retExpr.id
+ }
+
+ def removeTailReturn(tree: Tree, ids: ListBuffer[Int]): Tree = tree match {
+ case Block(stms, r @ Return(expr)) =>
+ ids += r.id
+ treeCopy.Block(tree, stms, expr)
+
+ case Block(stms, expr) =>
+ treeCopy.Block(tree, stms, removeTailReturn(expr, ids))
+
+ case If(cond, thenExpr, elseExpr) =>
+ treeCopy.If(tree, cond, removeTailReturn(thenExpr, ids), removeTailReturn(elseExpr, ids))
+
+ case Try(block, catches, finalizer) =>
+ treeCopy.Try(tree,
+ removeTailReturn(block, ids),
+ (catches map (t => removeTailReturn(t, ids))).asInstanceOf[List[CaseDef]],
+ removeTailReturn(finalizer, ids))
+
+ case CaseDef(pat, guard, r @ Return(expr)) =>
+ ids += r.id
+ treeCopy.CaseDef(tree, pat, guard, expr)
+
+ case CaseDef(pat, guard, body) =>
+ treeCopy.CaseDef(tree, pat, guard, removeTailReturn(body, ids))
+
+ case _ =>
+ tree
+ }
+
}
diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
index 017c8d24fd..e02e02d975 100644
--- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
+++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala
@@ -9,6 +9,8 @@ import scala.tools.nsc.plugins._
import scala.tools.nsc.ast._
+import scala.collection.mutable.ListBuffer
+
/**
* In methods marked @cps, explicitly name results of calls to other @cps methods
*/
@@ -46,10 +48,20 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
// this would cause infinite recursion. But we could remove the
// ValDef case here.
- case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) =>
+ case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) =>
debuglog("transforming " + dd.symbol)
atOwner(dd.symbol) {
+ val tailReturns = ListBuffer[Int]()
+ val rhs = removeTailReturn(rhs0, tailReturns)
+ // throw an error if there is a Return tree which is not in tail position
+ rhs0 foreach {
+ case r @ Return(_) =>
+ if (!tailReturns.contains(r.id))
+ unit.error(r.pos, "return expressions in CPS code must be in tail position")
+ case _ => /* do nothing */
+ }
+
val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))
debuglog("result "+rhs1)
@@ -153,7 +165,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with
}
}
-
def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = {
transTailValue(tree, cpsA, cpsR) match {
case (Nil, b) => b
diff --git a/src/reflect/scala/reflect/internal/AnnotationCheckers.scala b/src/reflect/scala/reflect/internal/AnnotationCheckers.scala
index 449b0ca0bc..3848ab51b8 100644
--- a/src/reflect/scala/reflect/internal/AnnotationCheckers.scala
+++ b/src/reflect/scala/reflect/internal/AnnotationCheckers.scala
@@ -47,6 +47,10 @@ trait AnnotationCheckers {
* before. If the implementing class cannot do the adaptiong, it
* should return the tree unchanged.*/
def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = tree
+
+ def pushAnnotationContext(tree: Tree): Unit = {}
+
+ def popAnnotationContext(): Unit = {}
}
// Syncnote: Annotation checkers inaccessible to reflection, so no sync in var necessary.
@@ -118,4 +122,14 @@ trait AnnotationCheckers {
annotationCheckers.foldLeft(tree)((tree, checker) =>
checker.adaptAnnotations(tree, mode, pt))
}
+
+ def pushAnnotationContext(tree: Tree): Unit = {
+ annotationCheckers.foreach(checker =>
+ checker.pushAnnotationContext(tree))
+ }
+
+ def popAnnotationContext(): Unit = {
+ annotationCheckers.foreach(checker =>
+ checker.popAnnotationContext())
+ }
}