diff options
Diffstat (limited to 'src/continuations')
3 files changed, 128 insertions, 13 deletions
diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index a20ff1667b..dbbc428428 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -150,10 +150,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { if ((mode & global.analyzer.EXPRmode) != 0) { if ((annots1 corresponds annots2)(_.atp <:< _.atp)) { vprintln("already same, can't adapt further") - return false - } - - if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { + false + } else if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { //println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) if (!hasPlusMarker(tree.tpe)) { // val base = tree.tpe <:< removeAllCPSAnnotations(pt) @@ -163,17 +161,26 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { // TBD: use same or not? //if (same) { vprintln("yes we can!! (unit)") - return true + true //} - } - } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { - if (!hasMinusMarker(tree.tpe)) { + } else false + } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.RETmode) != 0)) { + vprintln("checking enclosing method's result type without annotations") + tree.tpe <:< pt.withoutAnnotations + } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { + val optCpsTypes: Option[(Type, Type)] = cpsParamTypes(tree.tpe) + val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt) + if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) { vprintln("yes we can!! (byval)") - return true + true + } else { // check cps param types + val cpsTpes = optCpsTypes.get + val cpsPts = optExpectedCpsTypes.get + // class cpsParam[-B,+C], therefore: + cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2 } - } - } - false + } else false + } else false } override def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = { @@ -184,6 +191,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val patMode = (mode & global.analyzer.PATTERNmode) != 0 val exprMode = (mode & global.analyzer.EXPRmode) != 0 val byValMode = (mode & global.analyzer.BYVALmode) != 0 + val retMode = (mode & global.analyzer.RETmode) != 0 val annotsTree = cpsParamAnnotation(tree.tpe) val annotsExpected = cpsParamAnnotation(pt) @@ -210,6 +218,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val res = tree modifyType addMinusMarker vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) res + } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having no annotation + val res = tree modifyType (_ withAnnotations List(newPlusMarker())) + vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) + res } else tree } @@ -466,6 +480,13 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { } tpe + case ret @ Return(expr) => + // only change type if this return will (a) be removed (in tail position) or (b) cause + // an error (not in tail position) + if (expr.tpe != null && hasPlusMarker(expr.tpe)) + ret setType expr.tpe + ret.tpe + case _ => tpe } diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala index 46c644bcd6..db7ef4b540 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -27,6 +27,8 @@ trait CPSUtils { val shiftUnit0 = newTermName("shiftUnit0") val shiftUnit = newTermName("shiftUnit") val shiftUnitR = newTermName("shiftUnitR") + val reset = newTermName("reset") + val reset0 = newTermName("reset0") } lazy val MarkerCPSSym = rootMirror.getRequiredClass("scala.util.continuations.cpsSym") @@ -45,10 +47,15 @@ trait CPSUtils { lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR) lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify) lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR) + lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset) + lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0) lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth, MarkerCPSAdaptPlus, MarkerCPSAdaptMinus) + lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR, + MethReify, MethReifyR, MethReset, MethReset0) + // TODO - needed? Can these all use the same annotation info? protected def newSynthMarker() = newMarker(MarkerCPSSynth) protected def newPlusMarker() = newMarker(MarkerCPSAdaptPlus) diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 51760d2807..ed313342ec 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -32,6 +32,84 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with implicit val _unit = unit // allow code in CPSUtils.scala to report errors var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) + /* Does not attempt to remove tail returns. + * Only checks whether the method contains returns as well as calls to CPS methods. + */ + class CheckCPSMethodsTraverser extends Traverser { + var cpsMethodsSeen = false + var returnsSeen: Option[Tree] = None + override def traverse(tree: Tree): Unit = tree match { + case Ident(_) | Select(_, _) => + if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) + cpsMethodsSeen = true + super.traverse(tree) + case Return(_) => + returnsSeen = Some(tree) + case _ => + super.traverse(tree) + } + } + + /* Also checks whether the method calls a CPS method in which case an error is produced + */ + class RemoveTailReturnsTransformer extends Transformer { + var cpsMethodsSeen = false + var returnsSeen: Option[Tree] = None + override def transform(tree: Tree): Tree = tree match { + case Block(stms, r @ Return(expr)) => + returnsSeen = Some(r) + treeCopy.Block(tree, stms, expr) + + case Block(stms, expr) => + treeCopy.Block(tree, stms, transform(expr)) + + case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => + returnsSeen = Some(r1) + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case If(cond, thenExpr, elseExpr) => + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case Try(block, catches, finalizer) => + treeCopy.Try(tree, + transform(block), + (catches map (t => transform(t))).asInstanceOf[List[CaseDef]], + transform(finalizer)) + + case CaseDef(pat, guard, r @ Return(expr)) => + returnsSeen = Some(r) + treeCopy.CaseDef(tree, pat, guard, expr) + + case CaseDef(pat, guard, body) => + treeCopy.CaseDef(tree, pat, guard, transform(body)) + + case Return(_) => + unit.error(tree.pos, "return expressions in CPS code must be in tail position") + tree + + case Ident(_) | Select(_, _) => + if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) + cpsMethodsSeen = true + super.transform(tree) + + case _ => + super.transform(tree) + } + } + + def removeTailReturns(body: Tree): Tree = { + // support body with single return expression + body match { + case Return(expr) => expr + case _ => + val tr = new RemoveTailReturnsTransformer + val res = tr.transform(body) + if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen) + unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") + res + } + } + override def transform(tree: Tree): Tree = { if (!cpsEnabled) return tree @@ -46,10 +124,19 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // this would cause infinite recursion. But we could remove the // ValDef case here. - case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => debuglog("transforming " + dd.symbol) atOwner(dd.symbol) { + val rhs = + if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0) + else { + val checker = new CheckCPSMethodsTraverser + checker.traverse(rhs0) + if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen) + unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") + rhs0 + } val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) debuglog("result "+rhs1) |