From b9c3a3b083833d2166b3d2ca2d1ccbca36b83c71 Mon Sep 17 00:00:00 2001 From: phaller Date: Wed, 8 Aug 2012 16:20:07 +0200 Subject: Replace CheckCPSMethodTraverser with additional parameter on transformer methods Other fixes: - remove CPSUtils.allCPSMethods - add clarifying comment about adding a plus marker to a return expression --- .../tools/selectivecps/CPSAnnotationChecker.scala | 3 + .../plugin/scala/tools/selectivecps/CPSUtils.scala | 7 -- .../tools/selectivecps/SelectiveANFTransform.scala | 84 +++++++--------------- 3 files changed, 28 insertions(+), 66 deletions(-) (limited to 'src/continuations/plugin') diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index dbbc428428..122fe32ed3 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -221,6 +221,9 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { // add a marker annotation that will make tree.tpe behave as pt, subtyping wise // tree will look like having no annotation + + // note that we are only adding a plus marker if the method's result type is a CPS type + // (annotsExpected.nonEmpty == cpsParamAnnotation(pt).nonEmpty) val res = tree modifyType (_ withAnnotations List(newPlusMarker())) vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) res diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala index db7ef4b540..46c644bcd6 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -27,8 +27,6 @@ trait CPSUtils { val shiftUnit0 = newTermName("shiftUnit0") val shiftUnit = newTermName("shiftUnit") val shiftUnitR = newTermName("shiftUnitR") - val reset = newTermName("reset") - val reset0 = newTermName("reset0") } lazy val MarkerCPSSym = rootMirror.getRequiredClass("scala.util.continuations.cpsSym") @@ -47,15 +45,10 @@ trait CPSUtils { lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR) lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify) lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR) - lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset) - lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0) lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth, MarkerCPSAdaptPlus, MarkerCPSAdaptMinus) - lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR, - MethReify, MethReifyR, MethReset, MethReset0) - // TODO - needed? Can these all use the same annotation info? protected def newSynthMarker() = newMarker(MarkerCPSSynth) protected def newPlusMarker() = newMarker(MarkerCPSAdaptPlus) diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index ed313342ec..1783264e5c 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -32,39 +32,15 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with implicit val _unit = unit // allow code in CPSUtils.scala to report errors var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) - /* Does not attempt to remove tail returns. - * Only checks whether the method contains returns as well as calls to CPS methods. - */ - class CheckCPSMethodsTraverser extends Traverser { - var cpsMethodsSeen = false - var returnsSeen: Option[Tree] = None - override def traverse(tree: Tree): Unit = tree match { - case Ident(_) | Select(_, _) => - if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) - cpsMethodsSeen = true - super.traverse(tree) - case Return(_) => - returnsSeen = Some(tree) - case _ => - super.traverse(tree) - } - } - - /* Also checks whether the method calls a CPS method in which case an error is produced - */ - class RemoveTailReturnsTransformer extends Transformer { - var cpsMethodsSeen = false - var returnsSeen: Option[Tree] = None + object RemoveTailReturnsTransformer extends Transformer { override def transform(tree: Tree): Tree = tree match { case Block(stms, r @ Return(expr)) => - returnsSeen = Some(r) treeCopy.Block(tree, stms, expr) case Block(stms, expr) => treeCopy.Block(tree, stms, transform(expr)) case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => - returnsSeen = Some(r1) treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) case If(cond, thenExpr, elseExpr) => @@ -77,7 +53,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with transform(finalizer)) case CaseDef(pat, guard, r @ Return(expr)) => - returnsSeen = Some(r) treeCopy.CaseDef(tree, pat, guard, expr) case CaseDef(pat, guard, body) => @@ -87,11 +62,6 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with unit.error(tree.pos, "return expressions in CPS code must be in tail position") tree - case Ident(_) | Select(_, _) => - if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) - cpsMethodsSeen = true - super.transform(tree) - case _ => super.transform(tree) } @@ -101,12 +71,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // support body with single return expression body match { case Return(expr) => expr - case _ => - val tr = new RemoveTailReturnsTransformer - val res = tr.transform(body) - if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen) - unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") - res + case _ => RemoveTailReturnsTransformer.transform(body) } } @@ -130,14 +95,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with atOwner(dd.symbol) { val rhs = if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0) - else { - val checker = new CheckCPSMethodsTraverser - checker.traverse(rhs0) - if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen) - unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") - rhs0 - } - val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) + else rhs0 + val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))(getExternalAnswerTypeAnn(tpt.tpe).isDefined) debuglog("result "+rhs1) debuglog("result is of type "+rhs1.tpe) @@ -162,6 +121,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val ext = getExternalAnswerTypeAnn(body.tpe) val pureBody = getAnswerTypeAnn(body.tpe).isEmpty + implicit val isParentImpure = ext.isDefined def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = { val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => @@ -241,8 +201,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } - def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = { - transTailValue(tree, cpsA, cpsR) match { + def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = { + transTailValue(tree, cpsA, cpsR)(cpsR.isDefined || isAnyParentImpure) match { case (Nil, b) => b case (a, b) => treeCopy.Block(tree, a,b) @@ -250,7 +210,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } - def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo): (List[List[Tree]], List[Tree], CPSInfo) = { + def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = { val formals = fun.tpe.paramTypes val overshoot = args.length - formals.length @@ -259,7 +219,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val (stm,expr) = (for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield { tp match { case TypeRef(_, ByNameParamClass, List(elemtp)) => - (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))) + // note that we're not passing just isAnyParentImpure + (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))(getAnswerTypeAnn(elemtp).isDefined || isAnyParentImpure)) case _ => val (valStm, valExpr, valSpc) = transInlineValue(a, spc) spc = valSpc @@ -271,7 +232,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } - def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree, CPSInfo) = { + // precondition: cpsR.isDefined "implies" isAnyParentImpure + def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { // return value: (stms, expr, spc), where spc is CPSInfo after stms but *before* expr implicit val pos = tree.pos tree match { @@ -279,7 +241,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd // val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) - val (a, b) = transBlock(stms, expr, cpsA2, cpsR2) + val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!! (Nil, tree1, cpsA) @@ -293,8 +255,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe)) (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly - val thenVal = transExpr(thenp, cpsA2, cpsR2) - val elseVal = transExpr(elsep, cpsA2, cpsR2) + val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) + val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) // check that then and else parts agree (not necessary any more, but left as sanity check) if (cpsR.isDefined) { @@ -314,7 +276,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with else (None, getAnswerTypeAnn(tree.tpe)) val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => - val bodyVal = transExpr(body, cpsA2, cpsR2) + val bodyVal = transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) } @@ -332,7 +294,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info val sym = ldef.symbol resetFlag Flags.LABEL val rhs1 = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs) - val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe)) changeOwner (currentOwner -> sym) + val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym) val stm1 = localTyper.typed(DefDef(sym, rhsVal)) // since virtpatmat does not rely on fall-through, don't call the labels it emits @@ -371,6 +333,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with (stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc) case Return(expr0) => + if (isAnyParentImpure) + unit.error(tree.pos, "return expression not allowed, since method calls CPS method") val (stms, expr, spc) = transInlineValue(expr0, cpsA) (stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc) @@ -408,7 +372,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } } - def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = { + // precondition: cpsR.isDefined "implies" isAnyParentImpure + def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { val (stms, expr, spc) = transValue(tree, cpsA, cpsR) @@ -485,7 +450,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with (stms, expr) } - def transInlineValue(tree: Tree, cpsA: CPSInfo): (List[Tree], Tree, CPSInfo) = { + def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { val (stms, expr, spc) = transValue(tree, cpsA, None) // never required to be cps @@ -512,7 +477,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with - def transInlineStm(stm: Tree, cpsA: CPSInfo): (List[Tree], CPSInfo) = { + def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = { stm match { // TODO: what about DefDefs? @@ -542,7 +507,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } } - def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = { + // precondition: cpsR.isDefined "implies" isAnyParentImpure + def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { def rec(currStats: List[Tree], currAns: CPSInfo, accum: List[Tree]): (List[Tree], Tree) = currStats match { case Nil => -- cgit v1.2.3