diff options
19 files changed, 332 insertions, 15 deletions
diff --git a/src/compiler/scala/tools/nsc/typechecker/Modes.scala b/src/compiler/scala/tools/nsc/typechecker/Modes.scala index e9ea99faab..d942d080cb 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Modes.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Modes.scala @@ -86,6 +86,10 @@ trait Modes { */ final val TYPEPATmode = 0x10000 + /** RETmode is set when we are typing a return expression. + */ + final val RETmode = 0x20000 + final private val StickyModes = EXPRmode | PATTERNmode | TYPEmode | ALTmode final def onlyStickyModes(mode: Int) = @@ -133,4 +137,4 @@ trait Modes { def modeString(mode: Int): String = if (mode == 0) "NOmode" else (modeNameMap filterKeys (bit => inAllModes(mode, bit))).values mkString " " -}
\ No newline at end of file +} diff --git a/src/compiler/scala/tools/nsc/typechecker/Typers.scala b/src/compiler/scala/tools/nsc/typechecker/Typers.scala index a6d7424837..21bc890a20 100644 --- a/src/compiler/scala/tools/nsc/typechecker/Typers.scala +++ b/src/compiler/scala/tools/nsc/typechecker/Typers.scala @@ -4087,7 +4087,8 @@ trait Typers extends Modes with Adaptations with Tags { ReturnWithoutTypeError(tree, enclMethod.owner) } else { context.enclMethod.returnsSeen = true - val expr1: Tree = typed(expr, EXPRmode | BYVALmode, restpt.tpe) + val expr1: Tree = typed(expr, EXPRmode | BYVALmode | RETmode, restpt.tpe) + // Warn about returning a value if no value can be returned. if (restpt.tpe.typeSymbol == UnitClass) { // The typing in expr1 says expr is Unit (it has already been coerced if diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala index a20ff1667b..dbbc428428 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -150,10 +150,8 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { if ((mode & global.analyzer.EXPRmode) != 0) { if ((annots1 corresponds annots2)(_.atp <:< _.atp)) { vprintln("already same, can't adapt further") - return false - } - - if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { + false + } else if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { //println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) if (!hasPlusMarker(tree.tpe)) { // val base = tree.tpe <:< removeAllCPSAnnotations(pt) @@ -163,17 +161,26 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { // TBD: use same or not? //if (same) { vprintln("yes we can!! (unit)") - return true + true //} - } - } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { - if (!hasMinusMarker(tree.tpe)) { + } else false + } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.RETmode) != 0)) { + vprintln("checking enclosing method's result type without annotations") + tree.tpe <:< pt.withoutAnnotations + } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { + val optCpsTypes: Option[(Type, Type)] = cpsParamTypes(tree.tpe) + val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt) + if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) { vprintln("yes we can!! (byval)") - return true + true + } else { // check cps param types + val cpsTpes = optCpsTypes.get + val cpsPts = optExpectedCpsTypes.get + // class cpsParam[-B,+C], therefore: + cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2 } - } - } - false + } else false + } else false } override def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = { @@ -184,6 +191,7 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val patMode = (mode & global.analyzer.PATTERNmode) != 0 val exprMode = (mode & global.analyzer.EXPRmode) != 0 val byValMode = (mode & global.analyzer.BYVALmode) != 0 + val retMode = (mode & global.analyzer.RETmode) != 0 val annotsTree = cpsParamAnnotation(tree.tpe) val annotsExpected = cpsParamAnnotation(pt) @@ -210,6 +218,12 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { val res = tree modifyType addMinusMarker vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) res + } else if (retMode && !hasPlusMarker(tree.tpe) && annotsTree.isEmpty && annotsExpected.nonEmpty) { + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having no annotation + val res = tree modifyType (_ withAnnotations List(newPlusMarker())) + vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) + res } else tree } @@ -466,6 +480,13 @@ abstract class CPSAnnotationChecker extends CPSUtils with Modes { } tpe + case ret @ Return(expr) => + // only change type if this return will (a) be removed (in tail position) or (b) cause + // an error (not in tail position) + if (expr.tpe != null && hasPlusMarker(expr.tpe)) + ret setType expr.tpe + ret.tpe + case _ => tpe } diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala index 46c644bcd6..db7ef4b540 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -27,6 +27,8 @@ trait CPSUtils { val shiftUnit0 = newTermName("shiftUnit0") val shiftUnit = newTermName("shiftUnit") val shiftUnitR = newTermName("shiftUnitR") + val reset = newTermName("reset") + val reset0 = newTermName("reset0") } lazy val MarkerCPSSym = rootMirror.getRequiredClass("scala.util.continuations.cpsSym") @@ -45,10 +47,15 @@ trait CPSUtils { lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR) lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify) lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR) + lazy val MethReset = definitions.getMember(ModCPS, cpsNames.reset) + lazy val MethReset0 = definitions.getMember(ModCPS, cpsNames.reset0) lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth, MarkerCPSAdaptPlus, MarkerCPSAdaptMinus) + lazy val allCPSMethods = List(MethShiftUnit, MethShiftUnit0, MethShiftUnitR, MethShift, MethShiftR, + MethReify, MethReifyR, MethReset, MethReset0) + // TODO - needed? Can these all use the same annotation info? protected def newSynthMarker() = newMarker(MarkerCPSSynth) protected def newPlusMarker() = newMarker(MarkerCPSAdaptPlus) diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala index 51760d2807..ed313342ec 100644 --- a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -32,6 +32,84 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with implicit val _unit = unit // allow code in CPSUtils.scala to report errors var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) + /* Does not attempt to remove tail returns. + * Only checks whether the method contains returns as well as calls to CPS methods. + */ + class CheckCPSMethodsTraverser extends Traverser { + var cpsMethodsSeen = false + var returnsSeen: Option[Tree] = None + override def traverse(tree: Tree): Unit = tree match { + case Ident(_) | Select(_, _) => + if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) + cpsMethodsSeen = true + super.traverse(tree) + case Return(_) => + returnsSeen = Some(tree) + case _ => + super.traverse(tree) + } + } + + /* Also checks whether the method calls a CPS method in which case an error is produced + */ + class RemoveTailReturnsTransformer extends Transformer { + var cpsMethodsSeen = false + var returnsSeen: Option[Tree] = None + override def transform(tree: Tree): Tree = tree match { + case Block(stms, r @ Return(expr)) => + returnsSeen = Some(r) + treeCopy.Block(tree, stms, expr) + + case Block(stms, expr) => + treeCopy.Block(tree, stms, transform(expr)) + + case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => + returnsSeen = Some(r1) + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case If(cond, thenExpr, elseExpr) => + treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) + + case Try(block, catches, finalizer) => + treeCopy.Try(tree, + transform(block), + (catches map (t => transform(t))).asInstanceOf[List[CaseDef]], + transform(finalizer)) + + case CaseDef(pat, guard, r @ Return(expr)) => + returnsSeen = Some(r) + treeCopy.CaseDef(tree, pat, guard, expr) + + case CaseDef(pat, guard, body) => + treeCopy.CaseDef(tree, pat, guard, transform(body)) + + case Return(_) => + unit.error(tree.pos, "return expressions in CPS code must be in tail position") + tree + + case Ident(_) | Select(_, _) => + if (tree.hasSymbol && (allCPSMethods contains tree.symbol)) + cpsMethodsSeen = true + super.transform(tree) + + case _ => + super.transform(tree) + } + } + + def removeTailReturns(body: Tree): Tree = { + // support body with single return expression + body match { + case Return(expr) => expr + case _ => + val tr = new RemoveTailReturnsTransformer + val res = tr.transform(body) + if (tr.returnsSeen.nonEmpty && tr.cpsMethodsSeen) + unit.error(tr.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") + res + } + } + override def transform(tree: Tree): Tree = { if (!cpsEnabled) return tree @@ -46,10 +124,19 @@ abstract class SelectiveANFTransform extends PluginComponent with Transform with // this would cause infinite recursion. But we could remove the // ValDef case here. - case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => debuglog("transforming " + dd.symbol) atOwner(dd.symbol) { + val rhs = + if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0) + else { + val checker = new CheckCPSMethodsTraverser + checker.traverse(rhs0) + if (checker.returnsSeen.nonEmpty && checker.cpsMethodsSeen) + unit.error(checker.returnsSeen.get.pos, "return expressions not allowed, since method calls CPS method") + rhs0 + } val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) debuglog("result "+rhs1) diff --git a/test/files/continuations-neg/t5314-missing-result-type.check b/test/files/continuations-neg/t5314-missing-result-type.check new file mode 100644 index 0000000000..341e580cf3 --- /dev/null +++ b/test/files/continuations-neg/t5314-missing-result-type.check @@ -0,0 +1,4 @@ +t5314-missing-result-type.scala:6: error: method bar has return statement; needs result type + def bar(x:Int) = return foo(x) + ^ +one error found diff --git a/test/files/continuations-neg/t5314-missing-result-type.scala b/test/files/continuations-neg/t5314-missing-result-type.scala new file mode 100644 index 0000000000..d7c5043a86 --- /dev/null +++ b/test/files/continuations-neg/t5314-missing-result-type.scala @@ -0,0 +1,13 @@ +import scala.util.continuations._ + +object Test extends App { + def foo(x:Int): Int @cps[Int] = x + + def bar(x:Int) = return foo(x) + + reset { + val res = bar(8) + println(res) + res + } +} diff --git a/test/files/continuations-neg/t5314-npe.check b/test/files/continuations-neg/t5314-npe.check new file mode 100644 index 0000000000..b5f024aa89 --- /dev/null +++ b/test/files/continuations-neg/t5314-npe.check @@ -0,0 +1,4 @@ +t5314-npe.scala:2: error: method bar has return statement; needs result type + def bar(x:Int) = { return x; x } // NPE + ^ +one error found diff --git a/test/files/continuations-neg/t5314-npe.scala b/test/files/continuations-neg/t5314-npe.scala new file mode 100644 index 0000000000..2b5966e07c --- /dev/null +++ b/test/files/continuations-neg/t5314-npe.scala @@ -0,0 +1,3 @@ +object Test extends App { + def bar(x:Int) = { return x; x } // NPE +} diff --git a/test/files/continuations-neg/t5314-return-reset.check b/test/files/continuations-neg/t5314-return-reset.check new file mode 100644 index 0000000000..e288556546 --- /dev/null +++ b/test/files/continuations-neg/t5314-return-reset.check @@ -0,0 +1,4 @@ +t5314-return-reset.scala:14: error: return expressions not allowed, since method calls CPS method + if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset` + ^ +one error found diff --git a/test/files/continuations-neg/t5314-return-reset.scala b/test/files/continuations-neg/t5314-return-reset.scala new file mode 100644 index 0000000000..df9d58e4cb --- /dev/null +++ b/test/files/continuations-neg/t5314-return-reset.scala @@ -0,0 +1,21 @@ +import scala.util.continuations._ +import scala.util.Random + +object Test extends App { + val rnd = new Random + + def foo(x: Int): Int @cps[Int] = shift { k => k(x) } + + def bar(x: Int): Int @cps[Int] = return foo(x) + + def caller(): Int = { + val v: Int = reset { + val res: Int = bar(8) + if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset` + 42 + } + v + } + + caller() +} diff --git a/test/files/continuations-neg/t5314-type-error.check b/test/files/continuations-neg/t5314-type-error.check new file mode 100644 index 0000000000..1f4e46a7f2 --- /dev/null +++ b/test/files/continuations-neg/t5314-type-error.check @@ -0,0 +1,6 @@ +t5314-type-error.scala:7: error: type mismatch; + found : Int @util.continuations.cps[Int] + required: Int @util.continuations.cps[String] + def bar(x:Int): Int @cps[String] = return foo(x) + ^ +one error found diff --git a/test/files/continuations-neg/t5314-type-error.scala b/test/files/continuations-neg/t5314-type-error.scala new file mode 100644 index 0000000000..e36ce6c203 --- /dev/null +++ b/test/files/continuations-neg/t5314-type-error.scala @@ -0,0 +1,17 @@ +import scala.util.continuations._ + +object Test extends App { + def foo(x:Int): Int @cps[Int] = shift { k => k(x) } + + // should be a type error + def bar(x:Int): Int @cps[String] = return foo(x) + + def caller(): Unit = { + val v: String = reset { + val res: Int = bar(8) + "hello" + } + } + + caller() +} diff --git a/test/files/continuations-run/t5314-2.check b/test/files/continuations-run/t5314-2.check new file mode 100644 index 0000000000..35b3c93780 --- /dev/null +++ b/test/files/continuations-run/t5314-2.check @@ -0,0 +1,5 @@ +8 +hi +8 +from try +8 diff --git a/test/files/continuations-run/t5314-2.scala b/test/files/continuations-run/t5314-2.scala new file mode 100644 index 0000000000..8a896dec2c --- /dev/null +++ b/test/files/continuations-run/t5314-2.scala @@ -0,0 +1,44 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cps[Any] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + def caller3 = reset { println(p3(3)) } + + def p(i: Int): Int @cps[Any] = { + val v= s1 + 3 + return v + } + + def p2(i: Int): Int @cps[Any] = { + val v = s1 + 3 + if (v > 0) { + println("hi") + return v + } else { + println("hi") + return 8 + } + } + + def p3(i: Int): Int @cps[Any] = { + val v = s1 + 3 + try { + println("from try") + return v + } catch { + case e: Exception => + println("from catch") + return 7 + } + } + +} + +object Test extends App { + val repro = new ReturnRepro + repro.caller + repro.caller2 + repro.caller3 +} diff --git a/test/files/continuations-run/t5314-3.check b/test/files/continuations-run/t5314-3.check new file mode 100644 index 0000000000..71489f097c --- /dev/null +++ b/test/files/continuations-run/t5314-3.check @@ -0,0 +1,4 @@ +enter return expr +8 +hi +8 diff --git a/test/files/continuations-run/t5314-3.scala b/test/files/continuations-run/t5314-3.scala new file mode 100644 index 0000000000..62c547f5a2 --- /dev/null +++ b/test/files/continuations-run/t5314-3.scala @@ -0,0 +1,27 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + + def p(i: Int): Int @cpsParam[Unit, Any] = { + val v= s1 + 3 + return { println("enter return expr"); v } + } + + def p2(i: Int): Int @cpsParam[Unit, Any] = { + val v = s1 + 3 + if (v > 0) { + return { println("hi"); v } + } else { + return { println("hi"); 8 } + } + } +} + +object Test extends App { + val repro = new ReturnRepro + repro.caller + repro.caller2 +} diff --git a/test/files/continuations-run/t5314.check b/test/files/continuations-run/t5314.check new file mode 100644 index 0000000000..4951e7caae --- /dev/null +++ b/test/files/continuations-run/t5314.check @@ -0,0 +1,4 @@ +8 +hi +8 +8 diff --git a/test/files/continuations-run/t5314.scala b/test/files/continuations-run/t5314.scala new file mode 100644 index 0000000000..0bdea19824 --- /dev/null +++ b/test/files/continuations-run/t5314.scala @@ -0,0 +1,41 @@ +import scala.util.continuations._ + +class ReturnRepro { + def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } + def caller = reset { println(p(3)) } + def caller2 = reset { println(p2(3)) } + + def p(i: Int): Int @cpsParam[Unit, Any] = { + val v= s1 + 3 + return v + } + + def p2(i: Int): Int @cpsParam[Unit, Any] = { + val v = s1 + 3 + if (v > 0) { + println("hi") + return v + } else { + println("hi") + return 8 + } + } +} + +object Test extends App { + def foo(x:Int): Int @cps[Int] = shift { k => k(x) } + + def bar(x:Int): Int @cps[Int] = return foo(x) + + def nocps(x: Int): Int = { return x; x } + + val repro = new ReturnRepro + repro.caller + repro.caller2 + + reset { + val res = bar(8) + println(res) + res + } +} |