diff options
Diffstat (limited to 'src')
5 files changed, 158 insertions, 31 deletions
diff --git a/src/compiler/scala/tools/nsc/symtab/AnnotationCheckers.scala b/src/compiler/scala/tools/nsc/symtab/AnnotationCheckers.scala index d26c5d65a0..be2e520554 100644 --- a/src/compiler/scala/tools/nsc/symtab/AnnotationCheckers.scala +++ b/src/compiler/scala/tools/nsc/symtab/AnnotationCheckers.scala @@ -47,6 +47,13 @@ trait AnnotationCheckers { * before. If the implementing class cannot do the adaptiong, it * should return the tree unchanged.*/ def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = tree + + /** Adapt the type of a return expression. The decision of an annotation checker + * whether the type should be adapted is based on the type of the expression + * which is returned, as well as the result type of the method (pt). + * By default, this method simply returns the passed `default` type. + */ + def adaptTypeOfReturn(tree: Tree, pt: Type, default: => Type): Type = default } /** The list of annotation checkers that have been registered */ @@ -117,4 +124,23 @@ trait AnnotationCheckers { annotationCheckers.foldLeft(tree)((tree, checker) => checker.adaptAnnotations(tree, mode, pt)) } + + /** Let a registered annotation checker adapt the type of a return expression. + * Annotation checkers that cannot do the adaptation should simply return + * the `default` argument. + * + * Note that the result is undefined if more than one annotation checker + * returns an adapted type which is not a subtype of `default`. + */ + def adaptTypeOfReturn(tree: Tree, pt: Type, default: => Type): Type = { + val adaptedTypes = annotationCheckers flatMap { checker => + val adapted = checker.adaptTypeOfReturn(tree, pt, default) + if (!(adapted <:< default)) List(adapted) + else List() + } + adaptedTypes match { + case fst :: _ => fst + case List() => default + } + } } diff --git a/src/compiler/scala/tools/nsc/typechecker/Modes.scala b/src/compiler/scala/tools/nsc/typechecker/Modes.scala index ad4be4662c..b8308ebabf 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Modes.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Modes.scala @@ -86,6 +86,10 @@ trait Modes { */ final val TYPEPATmode = 0x10000 + /** RETmode is set when we are typing a return expression. + */ + final val RETmode = 0x20000 + final private val StickyModes = EXPRmode | PATTERNmode | TYPEmode | ALTmode final def onlyStickyModes(mode: Int) = @@ -130,4 +134,4 @@ trait Modes { def modeString(mode: Int): String = if (mode == 0) "NOmode" else (modeNameMap filterKeys (bit => inAllModes(mode, bit))).values mkString " " -}
\ No newline at end of file +} diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index 1a77ee6a9b..05b36d128e 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -3181,7 +3181,7 @@ trait Typers extends Modes { errorTree(tree, enclMethod.owner + " has return statement; needs result type") else { context.enclMethod.returnsSeen = true - val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe) + val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe) // Warn about returning a value if no value can be returned. if (restpt.tpe.typeSymbol == UnitClass) { // The typing in expr1 says expr is Unit (it has already been coerced if @@ -3190,7 +3190,8 @@ trait Typers extends Modes { if (typed(expr).tpe.typeSymbol != UnitClass) unit.warning(tree.pos, "enclosing method " + name + " has result type Unit: return value discarded") } - treeCopy.Return(tree, checkDead(expr1)) setSymbol enclMethod.owner setType NothingClass.tpe + treeCopy.Return(tree, checkDead(expr1)).setSymbol(enclMethod.owner) + .setType(adaptTypeOfReturn(expr1, restpt.tpe, NothingClass.tpe)) } } } diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index c4599bebbc..5f3cb22c44 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -154,10 +154,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { if ((mode & global.analyzer.EXPRmode) != 0) { if ((annots1 corresponds annots2)(_.atp <:< _.atp)) { vprintln("already same, can't adapt further") - return false - } - - if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { + false + } else if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { //println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) if (!hasPlusMarker(tree.tpe)) { // val base = tree.tpe <:< removeAllCPSAnnotations(pt) @@ -167,17 +165,26 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { // TBD: use same or not? //if (same) { vprintln("yes we can!! (unit)") - return true + true //} - } - } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { - if (!hasMinusMarker(tree.tpe)) { + } else false + } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.RETmode) != 0)) { + vprintln("checking enclosing method's result type without annotations") + tree.tpe <:< pt.withoutAnnotations + } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { + val optCpsTypes: Option[(Type, Type)] = cpsParamTypes(tree.tpe) + val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt) + if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) { vprintln("yes we can!! (byval)") - return true + true + } else { // check cps param types + val cpsTpes = optCpsTypes.get + val cpsPts = optExpectedCpsTypes.get + // class cpsParam[-B,+C], therefore: + cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2 } - } - } - false + } else false + } else false } override def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = { @@ -188,6 +195,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val patMode = (mode & global.analyzer.PATTERNmode) != 0 val exprMode = (mode & global.analyzer.EXPRmode) != 0 val byValMode = (mode & global.analyzer.BYVALmode) != 0 + val retMode = (mode & global.analyzer.RETmode) != 0 val annotsTree = cpsParamAnnotation(tree.tpe) val annotsExpected = cpsParamAnnotation(pt) @@ -214,9 +222,38 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val res = tree modifyType addMinusMarker vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) res + } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having any possible annotation + + // note 1: we are only adding a plus marker if the method's result type is a cps type + // (annotsExpected.nonEmpty == cpsParamAnnotation(pt).nonEmpty) + // note 2: we are not adding the expected cps annotations, since they will be added + // by adaptTypeOfReturn (see below). + val res = tree modifyType (_ withAnnotations List(newPlusMarker())) + vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) + res } else tree } + /** Returns an adapted type for a return expression if the method's result type (pt) is a CPS type. + * Otherwise, it returns the `default` type (`typedReturn` passes `NothingClass.tpe`). + * + * A return expression in a method that has a CPS result type is an error unless the return + * is in tail position. Therefore, we are making sure that only the types of return expressions + * are adapted which will either be removed, or lead to an error. + */ + override def adaptTypeOfReturn(tree: Tree, pt: Type, default: => Type): Type = { + // only adapt if method's result type (pt) is cps type + val annots = cpsParamAnnotation(pt) + if (annots.nonEmpty) { + // return type of `tree` without plus marker, but only if it doesn't have other cps annots + if (hasPlusMarker(tree.tpe) && !hasCpsParamTypes(tree.tpe)) + tree.setType(removeAttribs(tree.tpe, MarkerCPSAdaptPlus)) + tree.tpe + } else default + } + def updateAttributesFromChildren(tpe: Type, childAnnots: List[AnnotationInfo], byName: List[Tree]): Type = { tpe match { // Would need to push annots into each alternative of overloaded type diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 1f71eb285b..501396b31c 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -32,6 +32,55 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with implicit val _unit = unit // allow code in CPSUtils.scala to report errors var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) + object RemoveTailReturnsTransformer extends Transformer { + override def transform(tree: Tree): Tree = tree match { + case Block(stms, r @ Return(expr)) => + treeCopy.Block(tree, stms, expr) + + case Block(stms, expr) => + treeCopy.Block(tree, stms, transform(expr)) + + case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case If(cond, r1 @ Return(thenExpr), elseExpr) => + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case If(cond, thenExpr, r2 @ Return(elseExpr)) => + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case If(cond, thenExpr, elseExpr) => + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case Try(block, catches, finalizer) => + treeCopy.Try(tree, + transform(block), + (catches map (t => transform(t))).asInstanceOf[List[CaseDef]], + transform(finalizer)) + + case CaseDef(pat, guard, r @ Return(expr)) => + treeCopy.CaseDef(tree, pat, guard, expr) + + case CaseDef(pat, guard, body) => + treeCopy.CaseDef(tree, pat, guard, transform(body)) + + case Return(_) => + unit.error(tree.pos, "return expressions in CPS code must be in tail position") + tree + + case _ => + super.transform(tree) + } + } + + def removeTailReturns(body: Tree): Tree = { + // support body with single return expression + body match { + case Return(expr) => expr + case _ => RemoveTailReturnsTransformer.transform(body) + } + } + override def transform(tree: Tree): Tree = { if (!cpsEnabled) return tree @@ -46,11 +95,14 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // this would cause infinite recursion. But we could remove the // ValDef case here. - case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => debuglog("transforming " + dd.symbol) atOwner(dd.symbol) { - val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) + val rhs = + if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0) + else rhs0 + val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))(getExternalAnswerTypeAnn(tpt.tpe).isDefined) debuglog("result "+rhs1) debuglog("result is of type "+rhs1.tpe) @@ -75,6 +127,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val ext = getExternalAnswerTypeAnn(body.tpe) val pureBody = getAnswerTypeAnn(body.tpe).isEmpty + implicit val isParentImpure = ext.isDefined def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = { val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => @@ -154,8 +207,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } - def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = { - transTailValue(tree, cpsA, cpsR) match { + def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = { + transTailValue(tree, cpsA, cpsR)(cpsR.isDefined || isAnyParentImpure) match { case (Nil, b) => b case (a, b) => treeCopy.Block(tree, a,b) @@ -163,7 +216,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } - def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo): (List[List[Tree]], List[Tree], CPSInfo) = { + def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = { val formals = fun.tpe.paramTypes val overshoot = args.length - formals.length @@ -172,7 +225,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val (stm,expr) = (for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield { tp match { case TypeRef(_, ByNameParamClass, List(elemtp)) => - (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))) + // note that we're not passing just isAnyParentImpure + (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))(getAnswerTypeAnn(elemtp).isDefined || isAnyParentImpure)) case _ => val (valStm, valExpr, valSpc) = transInlineValue(a, spc) spc = valSpc @@ -184,7 +238,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } - def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree, CPSInfo) = { + // precondition: cpsR.isDefined "implies" isAnyParentImpure + def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { // return value: (stms, expr, spc), where spc is CPSInfo after stms but *before* expr implicit val pos = tree.pos tree match { @@ -192,7 +247,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd // val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) - val (a, b) = transBlock(stms, expr, cpsA2, cpsR2) + val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!! (Nil, tree1, cpsA) @@ -206,8 +261,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe)) (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly - val thenVal = transExpr(thenp, cpsA2, cpsR2) - val elseVal = transExpr(elsep, cpsA2, cpsR2) + val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) + val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) // check that then and else parts agree (not necessary any more, but left as sanity check) if (cpsR.isDefined) { @@ -227,7 +282,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with else (None, getAnswerTypeAnn(tree.tpe)) val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => - val bodyVal = transExpr(body, cpsA2, cpsR2) + val bodyVal = transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) } @@ -245,7 +300,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info val sym = ldef.symbol resetFlag Flags.LABEL val rhs1 = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs) - val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe)) changeOwner (currentOwner -> sym) + val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym) val stm1 = localTyper.typed(DefDef(sym, rhsVal)) // since virtpatmat does not rely on fall-through, don't call the labels it emits @@ -284,6 +339,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with (stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc) case Return(expr0) => + if (isAnyParentImpure) + unit.error(tree.pos, "return expression not allowed, since method calls CPS method") val (stms, expr, spc) = transInlineValue(expr0, cpsA) (stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc) @@ -321,7 +378,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } } - def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = { + // precondition: cpsR.isDefined "implies" isAnyParentImpure + def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { val (stms, expr, spc) = transValue(tree, cpsA, cpsR) @@ -398,7 +456,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with (stms, expr) } - def transInlineValue(tree: Tree, cpsA: CPSInfo): (List[Tree], Tree, CPSInfo) = { + def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { val (stms, expr, spc) = transValue(tree, cpsA, None) // never required to be cps @@ -426,7 +484,7 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with - def transInlineStm(stm: Tree, cpsA: CPSInfo): (List[Tree], CPSInfo) = { + def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = { stm match { // TODO: what about DefDefs? @@ -456,7 +514,8 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with } } - def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = { + // precondition: cpsR.isDefined "implies" isAnyParentImpure + def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { def rec(currStats: List[Tree], currAns: CPSInfo, accum: List[Tree]): (List[Tree], Tree) = currStats match { case Nil => |