diff options
author | Josh Suereth <Joshua.Suereth@gmail.com> | 2012-06-18 09:45:37 -0700 |
---|---|---|
committer | Josh Suereth <Joshua.Suereth@gmail.com> | 2012-06-18 09:45:37 -0700 |
commit | bf4b982389dcd573e81155bfe233f70f1471b017 (patch) | |
tree | 97239b357d7eb6549ca5bb2867e21956c2c43e40 /src | |
parent | 33901013cdd8aef50c28006a5cf52aa3137f747f (diff) | |
parent | 0ada0706746c9c603bf5bc8a0e6780e5783297cf (diff) | |
download | scala-bf4b982389dcd573e81155bfe233f70f1471b017.tar.gz scala-bf4b982389dcd573e81155bfe233f70f1471b017.tar.bz2 scala-bf4b982389dcd573e81155bfe233f70f1471b017.zip |
Merge pull request #720 from phaller/cps-ticket-1681
CPS: enable return expressions in CPS code if they are in tail position
Diffstat (limited to 'src')
5 files changed, 74 insertions, 3 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/Modes.scala b/src/compiler/scala/tools/nsc/typechecker/Modes.scala index 3eff5ef024..bde3ad98c9 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Modes.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Modes.scala @@ -86,6 +86,10 @@ trait Modes { */ final val TYPEPATmode = 0x10000 + /** RETmode is set when we are typing a return expression. + */ + final val RETmode = 0x20000 + final private val StickyModes = EXPRmode | PATTERNmode | TYPEmode | ALTmode final def onlyStickyModes(mode: Int) = diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 2bdae4164a..1193d3013a 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -3987,7 +3987,8 @@ trait Typers extends Modes with Adaptations with Tags { ReturnWithoutTypeError(tree, enclMethod.owner) } else { context.enclMethod.returnsSeen = true - val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe) + val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe) + // Warn about returning a value if no value can be returned. if (restpt.tpe.typeSymbol == UnitClass) { // The typing in expr1 says expr is Unit (it has already been coerced if diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index 862b19d0a4..5b8e9baa21 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -170,6 +170,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { vprintln("yes we can!! (byval)") return true } + } else if ((mode & global.analyzer.RETmode) != 0) { + vprintln("yes we can!! (return)") + return true } } false @@ -183,6 +186,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val patMode = (mode & global.analyzer.PATTERNmode) != 0 val exprMode = (mode & global.analyzer.EXPRmode) != 0 val byValMode = (mode & global.analyzer.BYVALmode) != 0 + val retMode = (mode & global.analyzer.RETmode) != 0 val annotsTree = cpsParamAnnotation(tree.tpe) val annotsExpected = cpsParamAnnotation(pt) @@ -209,6 +213,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val res = tree modifyType addMinusMarker vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) res + } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having no annotation + val res = tree modifyType (_ withAnnotations List(newPlusMarker())) + vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) + res } else tree } @@ -464,6 +474,11 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { } tpe + case ret @ Return(expr) => + if (hasPlusMarker(expr.tpe)) + ret setType expr.tpe + ret.tpe + case _ => tpe } diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala index 3a1dc87a6a..765cde5a81 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -3,6 +3,7 @@ package scala.tools.selectivecps import scala.tools.nsc.Global +import scala.collection.mutable.ListBuffer trait CPSUtils { val global: Global @@ -135,4 +136,43 @@ trait CPSUtils { case _ => None } } + + def isTailReturn(retExpr: Tree, body: Tree): Boolean = { + val removed = ListBuffer[Tree]() + removeTailReturn(body, removed) + removed contains retExpr + } + + def removeTailReturn(tree: Tree, removed: ListBuffer[Tree]): Tree = tree match { + case Block(stms, r @ Return(expr)) => + removed += r + treeCopy.Block(tree, stms, expr) + + case Block(stms, expr) => + treeCopy.Block(tree, stms, removeTailReturn(expr, removed)) + + case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => + removed ++= Seq(r1, r2) + treeCopy.If(tree, cond, removeTailReturn(thenExpr, removed), removeTailReturn(elseExpr, removed)) + + case If(cond, thenExpr, elseExpr) => + treeCopy.If(tree, cond, removeTailReturn(thenExpr, removed), removeTailReturn(elseExpr, removed)) + + case Try(block, catches, finalizer) => + treeCopy.Try(tree, + removeTailReturn(block, removed), + (catches map (t => removeTailReturn(t, removed))).asInstanceOf[List[CaseDef]], + removeTailReturn(finalizer, removed)) + + case CaseDef(pat, guard, r @ Return(expr)) => + removed += r + treeCopy.CaseDef(tree, pat, guard, expr) + + case CaseDef(pat, guard, body) => + treeCopy.CaseDef(tree, pat, guard, removeTailReturn(body, removed)) + + case _ => + tree + } + } diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 017c8d24fd..fe465aad0d 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -9,6 +9,8 @@ import scala.tools.nsc.plugins._ import scala.tools.nsc.ast._ +import scala.collection.mutable.ListBuffer + /** * In methods marked @cps, explicitly name results of calls to other @cps methods */ @@ -46,10 +48,20 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // this would cause infinite recursion. But we could remove the // ValDef case here. - case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => debuglog("transforming " + dd.symbol) atOwner(dd.symbol) { + val tailReturns = ListBuffer[Tree]() + val rhs = removeTailReturn(rhs0, tailReturns) + // throw an error if there is a Return tree which is not in tail position + rhs0 foreach { + case r @ Return(_) => + if (!tailReturns.contains(r)) + unit.error(r.pos, "return expressions in CPS code must be in tail position") + case _ => /* do nothing */ + } + val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) debuglog("result "+rhs1) @@ -153,7 +165,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } } - def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = { transTailValue(tree, cpsA, cpsR) match { case (Nil, b) => b |