From 75e36233a0ea290cee98a35bc295feed4b18237e Mon Sep 17 00:00:00 2001 From: phaller Date: Mon, 23 Jul 2012 14:03:22 +0200 Subject: SI-5314 - CPS transform of return statement fails Enable return expressions in CPS code if they are in tail position. Note that tail returns are only removed in methods that do not call `shift` or `reset` (otherwise, an error is reported). Addresses the issues pointed out in a previous pull request: https://github.com/scala/scala/pull/720 - Addresses all issues mentioned here: https://github.com/scala/scala/pull/720#issuecomment-6429705 - Move transformation methods to SelectiveANFTransform.scala: https://github.com/scala/scala/pull/720#commitcomment-1477497 - Do not keep a list of tail returns. Tests: - continuations-neg/t5314-missing-result-type.scala - continuations-neg/t5314-type-error.scala - continuations-neg/t5314-npe.scala - continuations-neg/t5314-return-reset.scala - continuations-run/t5314.scala - continuations-run/t5314-2.scala - continuations-run/t5314-3.scala --- .../scala/tools/nsc/typechecker/Modes.scala | 6 +- .../scala/tools/nsc/typechecker/Typers.scala | 3 +- .../tools/selectivecps/CPSAnnotationChecker.scala | 45 ++++++++--- .../plugin/scala/tools/selectivecps/CPSUtils.scala | 7 ++ .../tools/selectivecps/SelectiveANFTransform.scala | 89 +++++++++++++++++++++- .../t5314-missing-result-type.check | 4 + .../t5314-missing-result-type.scala | 13 ++++ test/files/continuations-neg/t5314-npe.check | 4 + test/files/continuations-neg/t5314-npe.scala | 3 + .../continuations-neg/t5314-return-reset.check | 4 + .../continuations-neg/t5314-return-reset.scala | 21 +++++ .../files/continuations-neg/t5314-type-error.check | 6 ++ .../files/continuations-neg/t5314-type-error.scala | 17 +++++ test/files/continuations-run/t5314-2.check | 5 ++ test/files/continuations-run/t5314-2.scala | 44 +++++++++++ test/files/continuations-run/t5314-3.check | 4 + test/files/continuations-run/t5314-3.scala | 27 +++++++ test/files/continuations-run/t5314.check | 4 + test/files/continuations-run/t5314.scala | 41 ++++++++++ 19 files changed, 332 insertions(+), 15 deletions(-) create mode 100644 test/files/continuations-neg/t5314-missing-result-type.check create mode 100644 test/files/continuations-neg/t5314-missing-result-type.scala create mode 100644 test/files/continuations-neg/t5314-npe.check create mode 100644 test/files/continuations-neg/t5314-npe.scala create mode 100644 test/files/continuations-neg/t5314-return-reset.check create mode 100644 test/files/continuations-neg/t5314-return-reset.scala create mode 100644 test/files/continuations-neg/t5314-type-error.check create mode 100644 test/files/continuations-neg/t5314-type-error.scala create mode 100644 test/files/continuations-run/t5314-2.check create mode 100644 test/files/continuations-run/t5314-2.scala create mode 100644 test/files/continuations-run/t5314-3.check create mode 100644 test/files/continuations-run/t5314-3.scala create mode 100644 test/files/continuations-run/t5314.check create mode 100644 test/files/continuations-run/t5314.scala diff --git a/src/compiler/scala/tools/nsc/typechecker/Modes.scala b/src/compiler/scala/tools/nsc/typechecker/Modes.scala index e9ea99faab..d942d080cb 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Modes.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Modes.scala @@ -86,6 +86,10 @@ trait Modes { */ final val TYPEPATmode = 0x10000 + /** RETmode is set when we are typing a return expression. + */ + final val RETmode = 0x20000 + final private val StickyModes = EXPRmode | PATTERNmode | TYPEmode | ALTmode final def onlyStickyModes(mode: Int) = @@ -133,4 +137,4 @@ trait Modes { def modeString(mode: Int): String = if (mode == 0) "NOmode" else (modeNameMap filterKeys (bit => inAllModes(mode, bit))).values mkString " " -} \ No newline at end of file +} diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index a6d7424837..21bc890a20 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -4087,7 +4087,8 @@ trait Typers extends Modes with Adaptations with Tags { ReturnWithoutTypeError(tree, enclMethod.owner) } else { context.enclMethod.returnsSeen = true - val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe) + val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe) + // Warn about returning a value if no value can be returned. if (restpt.tpe.typeSymbol == UnitClass) { // The typing in expr1 says expr is Unit (it has already been coerced if diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index a20ff1667b..dbbc428428 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -150,10 +150,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { if ((mode & global.analyzer.EXPRmode) != 0) { if ((annots1 corresponds annots2)(_.atp <:< _.atp)) { vprintln("already same, can't adapt further") - return false - } - - if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { + false + } else if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { //println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) if (!hasPlusMarker(tree.tpe)) { // val base = tree.tpe <:< removeAllCPSAnnotations(pt) @@ -163,17 +161,26 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { // TBD: use same or not? //if (same) { vprintln("yes we can!! (unit)") - return true + true //} - } - } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { - if (!hasMinusMarker(tree.tpe)) { + } else false + } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.RETmode) != 0)) { + vprintln("checking enclosing method's result type without annotations") + tree.tpe <:< pt.withoutAnnotations + } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { + val optCpsTypes: Option[(Type, Type)] = cpsParamTypes(tree.tpe) + val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt) + if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) { vprintln("yes we can!! (byval)") - return true + true + } else { // check cps param types + val cpsTpes = optCpsTypes.get + val cpsPts = optExpectedCpsTypes.get + // class cpsParam[-B,+C], therefore: + cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2 } - } - } - false + } else false + } else false } override def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = { @@ -184,6 +191,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val patMode = (mode & global.analyzer.PATTERNmode) != 0 val exprMode = (mode & global.analyzer.EXPRmode) != 0 val byValMode = (mode & global.analyzer.BYVALmode) != 0 + val retMode = (mode & global.analyzer.RETmode) != 0 val annotsTree = cpsParamAnnotation(tree.tpe) val annotsExpected = cpsParamAnnotation(pt) @@ -210,6 +218,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val res = tree modifyType addMinusMarker vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) res + } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having no annotation + val res = tree modifyType (_ withAnnotations List(newPlusMarker())) + vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) + res } else tree } @@ -466,6 +480,13 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { } tpe + case ret @ Return(expr) => + // only change type if this return will (a) be removed (in tail position) or (b) cause + // an error (not in tail position) + if (expr.tpe != null && hasPlusMarker(expr.tpe)) + ret setType expr.tpe + ret.tpe + case _ => tpe } diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala index 46c644bcd6..db7ef4b540 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -27,6 +27,8 @@ trait CPSUtils { val shiftUnit0 = newTermName("shiftUnit0") val shiftUnit = newTermName("shiftUnit") val shiftUnitR = newTermName("shiftUnitR") + val reset = newTermName("reset") + val reset0 = newTermName("reset0") } lazy val MarkerCPSSym = rootMirror.getRequiredClass("scala.util.continuations.cpsSym") @@ -45,10 +47,15 @@ trait CPSUtils { lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR) lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify) lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR) + lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset) + lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0) lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth, MarkerCPSAdaptPlus, MarkerCPSAdaptMinus) + lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR, + MethReify, MethReifyR, MethReset, MethReset0) + // TODO - needed? Can these all use the same annotation info? protected def newSynthMarker() = newMarker(MarkerCPSSynth) protected def newPlusMarker() = newMarker(MarkerCPSAdaptPlus) diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 51760d2807..ed313342ec 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -32,6 +32,84 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with implicit val _unit = unit // allow code in CPSUtils.scala to report errors var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) + /* Does not attempt to remove tail returns. + * Only checks whether the method contains returns as well as calls to CPS methods. + */ + class CheckCPSMethodsTraverser extends Traverser { + var cpsMethodsSeen = false + var returnsSeen: Option[Tree] = None + override def traverse(tree: Tree): Unit = tree match { + case Ident(_) | Select(_, _) => + if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) + cpsMethodsSeen = true + super.traverse(tree) + case Return(_) => + returnsSeen = Some(tree) + case _ => + super.traverse(tree) + } + } + + /* Also checks whether the method calls a CPS method in which case an error is produced + */ + class RemoveTailReturnsTransformer extends Transformer { + var cpsMethodsSeen = false + var returnsSeen: Option[Tree] = None + override def transform(tree: Tree): Tree = tree match { + case Block(stms, r @ Return(expr)) => + returnsSeen = Some(r) + treeCopy.Block(tree, stms, expr) + + case Block(stms, expr) => + treeCopy.Block(tree, stms, transform(expr)) + + case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => + returnsSeen = Some(r1) + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case If(cond, thenExpr, elseExpr) => + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case Try(block, catches, finalizer) => + treeCopy.Try(tree, + transform(block), + (catches map (t => transform(t))).asInstanceOf[List[CaseDef]], + transform(finalizer)) + + case CaseDef(pat, guard, r @ Return(expr)) => + returnsSeen = Some(r) + treeCopy.CaseDef(tree, pat, guard, expr) + + case CaseDef(pat, guard, body) => + treeCopy.CaseDef(tree, pat, guard, transform(body)) + + case Return(_) => + unit.error(tree.pos, "return expressions in CPS code must be in tail position") + tree + + case Ident(_) | Select(_, _) => + if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) + cpsMethodsSeen = true + super.transform(tree) + + case _ => + super.transform(tree) + } + } + + def removeTailReturns(body: Tree): Tree = { + // support body with single return expression + body match { + case Return(expr) => expr + case _ => + val tr = new RemoveTailReturnsTransformer + val res = tr.transform(body) + if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen) + unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") + res + } + } + override def transform(tree: Tree): Tree = { if (!cpsEnabled) return tree @@ -46,10 +124,19 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // this would cause infinite recursion. But we could remove the // ValDef case here. - case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => debuglog("transforming " + dd.symbol) atOwner(dd.symbol) { + val rhs = + if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0) + else { + val checker = new CheckCPSMethodsTraverser + checker.traverse(rhs0) + if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen) + unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") + rhs0 + } val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) debuglog("result "+rhs1) diff --git a/test/files/continuations-neg/t5314-missing-result-type.check b/test/files/continuations-neg/t5314-missing-result-type.check new file mode 100644 index 0000000000..341e580cf3 --- /dev/null +++ b/test/files/continuations-neg/t5314-missing-result-type.check @@ -0,0 +1,4 @@ +t5314-missing-result-type.scala:6: error: method bar has return statement; needs result type + def bar(x:Int) = return foo(x) + ^ +one error found diff --git a/test/files/continuations-neg/t5314-missing-result-type.scala b/test/files/continuations-neg/t5314-missing-result-type.scala new file mode 100644 index 0000000000..d7c5043a86 --- /dev/null +++ b/test/files/continuations-neg/t5314-missing-result-type.scala @@ -0,0 +1,13 @@ +import scala.util.continuations._ + +object Test extends App { + def foo(x:Int): Int @cps[Int] = x + + def bar(x:Int) = return foo(x) + + reset { + val res = bar(8) + println(res) + res + } +} diff --git a/test/files/continuations-neg/t5314-npe.check b/test/files/continuations-neg/t5314-npe.check new file mode 100644 index 0000000000..b5f024aa89 --- /dev/null +++ b/test/files/continuations-neg/t5314-npe.check @@ -0,0 +1,4 @@ +t5314-npe.scala:2: error: method bar has return statement; needs result type + def bar(x:Int) = { return x; x } // NPE + ^ +one error found diff --git a/test/files/continuations-neg/t5314-npe.scala b/test/files/continuations-neg/t5314-npe.scala new file mode 100644 index 0000000000..2b5966e07c --- /dev/null +++ b/test/files/continuations-neg/t5314-npe.scala @@ -0,0 +1,3 @@ +object Test extends App { + def bar(x:Int) = { return x; x } // NPE +} diff --git a/test/files/continuations-neg/t5314-return-reset.check b/test/files/continuations-neg/t5314-return-reset.check new file mode 100644 index 0000000000..e288556546 --- /dev/null +++ b/test/files/continuations-neg/t5314-return-reset.check @@ -0,0 +1,4 @@ +t5314-return-reset.scala:14: error: return expressions not allowed, since method calls CPS method + if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset` + ^ +one error found diff --git a/test/files/continuations-neg/t5314-return-reset.scala b/test/files/continuations-neg/t5314-return-reset.scala new file mode 100644 index 0000000000..df9d58e4cb --- /dev/null +++ b/test/files/continuations-neg/t5314-return-reset.scala @@ -0,0 +1,21 @@ +import scala.util.continuations._ +import scala.util.Random + +object Test extends App { + val rnd = new Random + + def foo(x: Int): Int @cps[Int] = shift { k => k(x) } + + def bar(x: Int): Int @cps[Int] = return foo(x) + + def caller(): Int = { + val v: Int = reset { + val res: Int = bar(8) + if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset` + 42 + } + v + } + + caller() +} diff --git a/test/files/continuations-neg/t5314-type-error.check b/test/files/continuations-neg/t5314-type-error.check new file mode 100644 index 0000000000..1f4e46a7f2 --- /dev/null +++ b/test/files/continuations-neg/t5314-type-error.check @@ -0,0 +1,6 @@ +t5314-type-error.scala:7: error: type mismatch; + found : Int @util.continuations.cps[Int] + required: Int @util.continuations.cps[String] + def bar(x:Int): Int @cps[String] = return foo(x) + ^ +one error found diff --git a/test/files/continuations-neg/t5314-type-error.scala b/test/files/continuations-neg/t5314-type-error.scala new file mode 100644 index 0000000000..e36ce6c203 --- /dev/null +++ b/test/files/continuations-neg/t5314-type-error.scala @@ -0,0 +1,17 @@ +import scala.util.continuations._ + +object Test extends App { + def foo(x:Int): Int @cps[Int] = shift { k => k(x) } + + // should be a type error + def bar(x:Int): Int @cps[String] = return foo(x) + + def caller(): Unit = { + val v: String = reset { + val res: Int = bar(8) + "hello" + } + } + + caller() +} diff --git a/test/files/continuations-run/t5314-2.check b/test/files/continuations-run/t5314-2.check new file mode 100644 index 0000000000..35b3c93780 --- /dev/null +++ b/test/files/continuations-run/t5314-2.check @@ -0,0 +1,5 @@ +8 +hi +8 +from try +8 diff --git a/test/files/continuations-run/t5314-2.scala b/test/files/continuations-run/t5314-2.scala new file mode 100644 index 0000000000..8a896dec2c --- /dev/null +++ b/test/files/continuations-run/t5314-2.scala @@ -0,0 +1,44 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cps[Any] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + def caller3 = reset { println(p3(3)) } + + def p(i: Int): Int @cps[Any] = { + val v= s1 + 3 + return v + } + + def p2(i: Int): Int @cps[Any] = { + val v = s1 + 3 + if (v > 0) { + println("hi") + return v + } else { + println("hi") + return 8 + } + } + + def p3(i: Int): Int @cps[Any] = { + val v = s1 + 3 + try { + println("from try") + return v + } catch { + case e: Exception => + println("from catch") + return 7 + } + } + +} + +object Test extends App { + val repro = new ReturnRepro + repro.caller + repro.caller2 + repro.caller3 +} diff --git a/test/files/continuations-run/t5314-3.check b/test/files/continuations-run/t5314-3.check new file mode 100644 index 0000000000..71489f097c --- /dev/null +++ b/test/files/continuations-run/t5314-3.check @@ -0,0 +1,4 @@ +enter return expr +8 +hi +8 diff --git a/test/files/continuations-run/t5314-3.scala b/test/files/continuations-run/t5314-3.scala new file mode 100644 index 0000000000..62c547f5a2 --- /dev/null +++ b/test/files/continuations-run/t5314-3.scala @@ -0,0 +1,27 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + + def p(i: Int): Int @cpsParam[Unit, Any] = { + val v= s1 + 3 + return { println("enter return expr"); v } + } + + def p2(i: Int): Int @cpsParam[Unit, Any] = { + val v = s1 + 3 + if (v > 0) { + return { println("hi"); v } + } else { + return { println("hi"); 8 } + } + } +} + +object Test extends App { + val repro = new ReturnRepro + repro.caller + repro.caller2 +} diff --git a/test/files/continuations-run/t5314.check b/test/files/continuations-run/t5314.check new file mode 100644 index 0000000000..4951e7caae --- /dev/null +++ b/test/files/continuations-run/t5314.check @@ -0,0 +1,4 @@ +8 +hi +8 +8 diff --git a/test/files/continuations-run/t5314.scala b/test/files/continuations-run/t5314.scala new file mode 100644 index 0000000000..0bdea19824 --- /dev/null +++ b/test/files/continuations-run/t5314.scala @@ -0,0 +1,41 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + + def p(i: Int): Int @cpsParam[Unit, Any] = { + val v= s1 + 3 + return v + } + + def p2(i: Int): Int @cpsParam[Unit, Any] = { + val v = s1 + 3 + if (v > 0) { + println("hi") + return v + } else { + println("hi") + return 8 + } + } +} + +object Test extends App { + def foo(x:Int): Int @cps[Int] = shift { k => k(x) } + + def bar(x:Int): Int @cps[Int] = return foo(x) + + def nocps(x: Int): Int = { return x; x } + + val repro = new ReturnRepro + repro.caller + repro.caller2 + + reset { + val res = bar(8) + println(res) + res + } +} -- cgit v1.2.3