diff options
Diffstat (limited to 'src/continuations')
8 files changed, 1682 insertions, 0 deletions
diff --git a/src/continuations/library/scala/util/continuations/ControlContext.scala b/src/continuations/library/scala/util/continuations/ControlContext.scala new file mode 100644 index 0000000000..87a9bf1fc5 --- /dev/null +++ b/src/continuations/library/scala/util/continuations/ControlContext.scala @@ -0,0 +1,161 @@ +// $Id$ + +package scala.util.continuations + + +class cpsParam[-B,+C] extends StaticAnnotation with TypeConstraint + +private class cpsSym[B] extends Annotation // implementation detail + +private class cpsSynth extends Annotation // implementation detail + +private class cpsPlus extends StaticAnnotation with TypeConstraint // implementation detail +private class cpsMinus extends Annotation // implementation detail + + + +@serializable final class ControlContext[+A,-B,+C](val fun: (A => B, Exception => B) => C, val x: A) { + + /* + final def map[A1](f: A => A1): ControlContext[A1,B,C] = { + new ControlContext((k:(A1 => B)) => fun((x:A) => k(f(x))), null.asInstanceOf[A1]) + } + + final def flatMap[A1,B1<:B](f: (A => ControlContext[A1,B1,B])): ControlContext[A1,B1,C] = { + new ControlContext((k:(A1 => B1)) => fun((x:A) => f(x).fun(k))) + } + */ + + + @noinline final def map[A1](f: A => A1): ControlContext[A1,B,C] = { + if (fun eq null) + try { + new ControlContext(null, f(x)) // TODO: only alloc if f(x) != x + } catch { + case ex: Exception => + new ControlContext((k: A1 => B, thr: Exception => B) => thr(ex).asInstanceOf[C], null.asInstanceOf[A1]) + } + else + new ControlContext({ (k: A1 => B, thr: Exception => B) => + fun( { (x:A) => + var done = false + try { + val res = f(x) + done = true + k(res) + } catch { + case ex: Exception if !done => + thr(ex) + } + }, thr) + }, null.asInstanceOf[A1]) + } + + + // it would be nice if @inline would turn the trivial path into a tail call. + // unfortunately it doesn't, so we do it ourselves in SelectiveCPSTransform + + @noinline final def flatMap[A1,B1,C1<:B](f: (A => ControlContext[A1,B1,C1])): ControlContext[A1,B1,C] = { + if (fun eq null) + try { + f(x).asInstanceOf[ControlContext[A1,B1,C]] + } catch { + case ex: Exception => + new ControlContext((k: A1 => B1, thr: Exception => B1) => thr(ex).asInstanceOf[C], null.asInstanceOf[A1]) + } + else + new ControlContext({ (k: A1 => B1, thr: Exception => B1) => + fun( { (x:A) => + var done = false + try { + val ctxR = f(x) + done = true + val res: C1 = ctxR.foreachFull(k, thr) // => B1 + res + } catch { + case ex: Exception if !done => + thr(ex).asInstanceOf[B] // => B NOTE: in general this is unsafe! + } // However, the plugin will not generate offending code + }, thr.asInstanceOf[Exception=>B]) // => B + }, null.asInstanceOf[A1]) + } + + final def foreach(f: A => B) = foreachFull(f, throw _) + + def foreachFull(f: A => B, g: Exception => B): C = { + if (fun eq null) + f(x).asInstanceOf[C] + else + fun(f, g) + } + + + final def isTrivial = fun eq null + final def getTrivialValue = x.asInstanceOf[A] + + // need filter or other functions? + + final def flatMapCatch[A1>:A,B1<:B,C1>:C<:B1](pf: PartialFunction[Exception, ControlContext[A1,B1,C1]]): ControlContext[A1,B1,C1] = { + if (fun eq null) + this + else { + val fun1 = (ret1: A1 => B1, thr1: Exception => B1) => { + val thr: Exception => B1 = { t: Exception => + var captureExceptions = true + try { + if (pf.isDefinedAt(t)) { + val cc1 = pf(t) + captureExceptions = false + cc1.foreachFull(ret1, thr1) // Throw => B + } else { + captureExceptions = false + thr1(t) // Throw => B1 + } + } catch { + case t1: Exception if captureExceptions => thr1(t1) // => E2 + } + } + fun(ret1, thr)// fun(ret1, thr) // => B + } + new ControlContext(fun1, null.asInstanceOf[A1]) + } + } + + final def mapFinally(f: () => Unit): ControlContext[A,B,C] = { + if (fun eq null) { + try { + f() + this + } catch { + case ex: Exception => + new ControlContext((k: A => B, thr: Exception => B) => thr(ex).asInstanceOf[C], null.asInstanceOf[A]) + } + } else { + val fun1 = (ret1: A => B, thr1: Exception => B) => { + val ret: A => B = { x: A => + var captureExceptions = true + try { + f() + captureExceptions = false + ret1(x) + } catch { + case t1: Exception if captureExceptions => thr1(t1) + } + } + val thr: Exception => B = { t: Exception => + var captureExceptions = true + try { + f() + captureExceptions = false + thr1(t) + } catch { + case t1: Exception if captureExceptions => thr1(t1) + } + } + fun(ret, thr1) + } + new ControlContext(fun1, null.asInstanceOf[A]) + } + } + +} diff --git a/src/continuations/library/scala/util/continuations/package.scala b/src/continuations/library/scala/util/continuations/package.scala new file mode 100644 index 0000000000..aa4681a0cc --- /dev/null +++ b/src/continuations/library/scala/util/continuations/package.scala @@ -0,0 +1,65 @@ +// $Id$ + + +// TODO: scaladoc + +package scala.util + +package object continuations { + + type cps[A] = cpsParam[A,A] + + type suspendable = cps[Unit] + + + def shift[A,B,C](fun: (A => B) => C): A @cpsParam[B,C] = { + throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled") + } + + def reset[A,C](ctx: =>(A @cpsParam[A,C])): C = { + val ctxR = reify[A,A,C](ctx) + if (ctxR.isTrivial) + ctxR.getTrivialValue.asInstanceOf[C] + else + ctxR.foreach((x:A) => x) + } + + def reset0[A](ctx: =>(A @cpsParam[A,A])): A = reset(ctx) + + def run[A](ctx: =>(Any @cpsParam[Unit,A])): A = { + val ctxR = reify[Any,Unit,A](ctx) + if (ctxR.isTrivial) + ctxR.getTrivialValue.asInstanceOf[A] + else + ctxR.foreach((x:Any) => ()) + } + + + // methods below are primarily implementation details and are not + // needed frequently in client code + + def shiftUnit0[A,B](x: A): A @cpsParam[B,B] = { + shiftUnit[A,B,B](x) + } + + def shiftUnit[A,B,C>:B](x: A): A @cpsParam[B,C] = { + throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled") + } + + def reify[A,B,C](ctx: =>(A @cpsParam[B,C])): ControlContext[A,B,C] = { + throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled") + } + + def shiftUnitR[A,B](x: A): ControlContext[A,B,B] = { + new ControlContext(null, x) + } + + def shiftR[A,B,C](fun: (A => B) => C): ControlContext[A,B,C] = { + new ControlContext((f:A=>B,g:Exception=>B) => fun(f), null.asInstanceOf[A]) + } + + def reifyR[A,B,C](ctx: => ControlContext[A,B,C]): ControlContext[A,B,C] = { + ctx + } + +} diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala new file mode 100644 index 0000000000..0c124c9c19 --- /dev/null +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSAnnotationChecker.scala @@ -0,0 +1,462 @@ +// $Id$ + +package scala.tools.selectivecps + +import scala.tools.nsc.Global + +import scala.collection.mutable.{Map, HashMap} + +import java.io.{StringWriter, PrintWriter} + +abstract class CPSAnnotationChecker extends CPSUtils { + val global: Global + import global._ + import definitions._ + + //override val verbose = true + + /** + * Checks whether @cps annotations conform + */ + object checker extends AnnotationChecker { + + /** Check annotations to decide whether tpe1 <:< tpe2 */ + def annotationsConform(tpe1: Type, tpe2: Type): Boolean = { + if (!cpsEnabled) return true + + vprintln("check annotations: " + tpe1 + " <:< " + tpe2) + + // Nothing is least element, but Any is not the greatest + if (tpe1.typeSymbol eq NothingClass) + return true + + val annots1 = filterAttribs(tpe1,MarkerCPSTypes) + val annots2 = filterAttribs(tpe2,MarkerCPSTypes) + + // @plus and @minus should only occur at the left, and never together + // TODO: insert check + val adaptPlusAnnots1 = filterAttribs(tpe1,MarkerCPSAdaptPlus) + val adaptMinusAnnots1 = filterAttribs(tpe1,MarkerCPSAdaptMinus) + + // @minus @cps is the same as no annotations + if (!adaptMinusAnnots1.isEmpty) + return annots2.isEmpty + + // to handle answer type modification, we must make @plus <:< @cps + if (!adaptPlusAnnots1.isEmpty && annots1.isEmpty) + return true + + // @plus @cps will fall through and compare the @cps type args + + // @cps parameters must match exactly + if ((annots1 corresponds annots2) { _.atp <:< _.atp }) + return true + + false + } + + + /** Refine the computed least upper bound of a list of types. + * All this should do is add annotations. */ + override def annotationsLub(tpe: Type, ts: List[Type]): Type = { + if (!cpsEnabled) return tpe + + val annots1 = filterAttribs(tpe, MarkerCPSTypes) + val annots2 = ts flatMap (filterAttribs(_, MarkerCPSTypes)) + + if (annots2.nonEmpty) { + val cpsLub = AnnotationInfo(global.lub(annots1:::annots2 map (_.atp)), Nil, Nil) + val tpe1 = if (annots1.nonEmpty) removeAttribs(tpe, MarkerCPSTypes) else tpe + tpe1.withAnnotation(cpsLub) + } else tpe + } + + /** Refine the bounds on type parameters to the given type arguments. */ + override def adaptBoundsToAnnotations(bounds: List[TypeBounds], tparams: List[Symbol], targs: List[Type]): List[TypeBounds] = { + if (!cpsEnabled) return bounds + + val anyAtCPS = AnnotationInfo(appliedType(MarkerCPSTypes.tpe, List(NothingClass.tpe, AnyClass.tpe)), Nil, Nil) + if (isFunctionType(tparams.head.owner.tpe) || tparams.head.owner == PartialFunctionClass) { + vprintln("function bound: " + tparams.head.owner.tpe + "/"+bounds+"/"+targs) + if (targs.last.hasAnnotation(MarkerCPSTypes)) + bounds.reverse match { + case res::b if !res.hi.hasAnnotation(MarkerCPSTypes) => + (TypeBounds(res.lo, res.hi.withAnnotation(anyAtCPS))::b).reverse + case _ => bounds + } + else + bounds + } else if (tparams.head.owner == ByNameParamClass) { + vprintln("byname bound: " + tparams.head.owner.tpe + "/"+bounds+"/"+targs) + if (targs.head.hasAnnotation(MarkerCPSTypes) && !bounds.head.hi.hasAnnotation(MarkerCPSTypes)) + TypeBounds(bounds.head.lo, bounds.head.hi.withAnnotation(anyAtCPS))::Nil + else bounds + } else + bounds + } + + + override def canAdaptAnnotations(tree: Tree, mode: Int, pt: Type): Boolean = { + if (!cpsEnabled) return false + vprintln("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) + + val annots1 = filterAttribs(tree.tpe,MarkerCPSTypes) + val annots2 = filterAttribs(pt,MarkerCPSTypes) + + if ((mode & global.analyzer.PATTERNmode) != 0) { + //println("can adapt pattern annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) + if (!annots1.isEmpty) { + return true + } + } + +/* + // not precise enough -- still relying on addAnnotations to remove things from ValDef symbols + if ((mode & global.analyzer.TYPEmode) != 0 && (mode & global.analyzer.BYVALmode) != 0) { + if (!annots1.isEmpty) { + return true + } + } +*/ + +/* + this interferes with overloading resolution + if ((mode & global.analyzer.BYVALmode) != 0 && tree.tpe <:< pt) { + vprintln("already compatible, can't adapt further") + return false + } +*/ + if ((mode & global.analyzer.EXPRmode) != 0) { + if ((annots1 corresponds annots2) { case (a1,a2) => a1.atp <:< a2.atp }) { + vprintln("already same, can't adapt further") + return false + } + + if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { + //println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) + val adapt = AnnotationInfo(MarkerCPSAdaptPlus.tpe, Nil, Nil) + if (!tree.tpe.annotations.contains(adapt)) { + // val base = tree.tpe <:< removeAllCPSAnnotations(pt) + // val known = global.analyzer.isFullyDefined(pt) + // println(same + "/" + base + "/" + known) + //val same = annots2 forall { case AnnotationInfo(atp: TypeRef, _, _) => atp.typeArgs(0) =:= atp.typeArgs(1) } + // TBD: use same or not? + //if (same) { + vprintln("yes we can!! (unit)") + return true + //} + } + } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { + if (!tree.tpe.hasAnnotation(MarkerCPSAdaptMinus)) { + vprintln("yes we can!! (byval)") + return true + } + } + } + false + } + + + override def adaptAnnotations(tree: Tree, mode: Int, pt: Type): Tree = { + if (!cpsEnabled) return tree + + vprintln("adapt annotations " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) + + val annots1 = filterAttribs(tree.tpe,MarkerCPSTypes) + val annots2 = filterAttribs(pt,MarkerCPSTypes) + + if ((mode & global.analyzer.PATTERNmode) != 0) { + if (!annots1.isEmpty) { + return tree.setType(removeAllCPSAnnotations(tree.tpe)) + } + } + +/* + // doesn't work correctly -- still relying on addAnnotations to remove things from ValDef symbols + if ((mode & global.analyzer.TYPEmode) != 0 && (mode & global.analyzer.BYVALmode) != 0) { + if (!annots1.isEmpty) { + println("removing annotation from " + tree + "/" + tree.tpe) + val s = tree.setType(removeAllCPSAnnotations(tree.tpe)) + println(s) + s + } + } +*/ + + if ((mode & global.analyzer.EXPRmode) != 0) { + if (annots1.isEmpty && !annots2.isEmpty && ((mode & global.analyzer.BYVALmode) == 0)) { // shiftUnit + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having any possible annotation + //println("adapt annotations " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) + + val adapt = AnnotationInfo(MarkerCPSAdaptPlus.tpe, Nil, Nil) + //val same = annots2 forall { case AnnotationInfo(atp: TypeRef, _, _) => atp.typeArgs(0) =:= atp.typeArgs(1) } + // TBD: use same or not? see infer0.scala/infer1.scala + + // CAVEAT: + // for monomorphic answer types we want to have @plus @cps (for better checking) + // for answer type modification we want to have only @plus (because actual answer type may differ from pt) + + //val known = global.analyzer.isFullyDefined(pt) + + if (/*same &&*/ !tree.tpe.annotations.contains(adapt)) { + //if (known) + return tree.setType(tree.tpe.withAnnotations(adapt::annots2)) // needed for #1807 + //else + // return tree.setType(tree.tpe.withAnnotations(adapt::Nil)) + } + tree + } else if (!annots1.isEmpty && ((mode & global.analyzer.BYVALmode) != 0)) { // dropping annotation + // add a marker annotation that will make tree.tpe behave as pt, subtyping wise + // tree will look like having no annotation + if (!tree.tpe.hasAnnotation(MarkerCPSAdaptMinus)) { + val adapt = AnnotationInfo(MarkerCPSAdaptMinus.tpe, Nil, Nil) + return tree.setType(tree.tpe.withAnnotations(adapt::Nil)) + } + } + } + tree + } + + + def updateAttributesFromChildren(tpe: Type, childAnnots: List[AnnotationInfo], byName: List[Tree]): Type = { + tpe match { + // Would need to push annots into each alternative of overloaded type + // But we can't, since alternatives aren't types but symbols, which we + // can't change (we'd be affecting symbols globally) + /* + case OverloadedType(pre, alts) => + OverloadedType(pre, alts.map((sym: Symbol) => updateAttributes(pre.memberType(sym), annots))) + */ + case _ => + assert(childAnnots forall (_.atp.typeSymbol == MarkerCPSTypes), childAnnots) + /* + [] + [] = [] + plus + [] = plus + cps + [] = cps + plus cps + [] = plus cps + minus cps + [] = minus cp + synth cps + [] = synth cps // <- synth on left - does it happen? + + [] + cps = cps + plus + cps = synth cps + cps + cps = cps! <- lin + plus cps + cps = synth cps! <- unify + minus cps + cps = minus cps! <- lin + synth cps + cps = synth cps! <- unify + */ + + val plus = tpe.hasAnnotation(MarkerCPSAdaptPlus) || (tpe.hasAnnotation(MarkerCPSTypes) && + byName.nonEmpty && byName.forall(_.tpe.hasAnnotation(MarkerCPSAdaptPlus))) + + // move @plus annotations outward from by-name children + if (childAnnots.isEmpty) { + if (plus) { // @plus or @plus @cps + for (t <- byName) { + //println("removeAnnotation " + t + " / " + t.tpe) + t.setType(removeAttribs(t.tpe, MarkerCPSAdaptPlus, MarkerCPSTypes)) + } + return tpe.withAnnotation(AnnotationInfo(MarkerCPSAdaptPlus.tpe, Nil, Nil)) + } else + return tpe + } + + val annots1 = filterAttribs(tpe, MarkerCPSTypes) + + if (annots1.isEmpty) { // nothing or @plus + val synth = MarkerCPSSynth.tpe + val annots2 = List(linearize(childAnnots)) + removeAttribs(tpe,MarkerCPSAdaptPlus).withAnnotations(AnnotationInfo(synth, Nil, Nil)::annots2) + } else { + val annot1 = single(annots1) + if (plus) { // @plus @cps + val synth = AnnotationInfo(MarkerCPSSynth.tpe, Nil, Nil) + val annot2 = linearize(childAnnots) + if (!(annot2.atp <:< annot1.atp)) + throw new TypeError(annot2 + " is not a subtype of " + annot1) + val res = removeAttribs(tpe, MarkerCPSAdaptPlus, MarkerCPSTypes).withAnnotations(List(synth, annot2)) + for (t <- byName) { + //println("removeAnnotation " + t + " / " + t.tpe) + t.setType(removeAttribs(t.tpe, MarkerCPSAdaptPlus, MarkerCPSTypes)) + } + res + } else if (tpe.hasAnnotation(MarkerCPSSynth)) { // @synth @cps + val annot2 = linearize(childAnnots) + if (!(annot2.atp <:< annot1.atp)) + throw new TypeError(annot2 + " is not a subtype of " + annot1) + removeAttribs(tpe, MarkerCPSTypes).withAnnotation(annot2) + } else { // @cps + removeAttribs(tpe, MarkerCPSTypes).withAnnotation(linearize(childAnnots:::annots1)) + } + } + } + } + + + + + + def transArgList(fun: Tree, args: List[Tree]): List[List[Tree]] = { + val formals = fun.tpe.paramTypes + val overshoot = args.length - formals.length + + for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield { + tp match { + case TypeRef(_, sym, List(elemtp)) if sym == ByNameParamClass => + Nil // TODO: check conformance?? + case _ => + List(a) + } + } + } + + + def transStms(stms: List[Tree]): List[Tree] = stms match { + case ValDef(mods, name, tpt, rhs)::xs => + rhs::transStms(xs) + case Assign(lhs, rhs)::xs => + rhs::transStms(xs) + case x::xs => + x::transStms(xs) + case Nil => + Nil + } + + def single(xs: List[AnnotationInfo]) = xs match { + case List(x) => x + case _ => + global.error("not a single cps annotation: " + xs)// FIXME: error message + xs(0) + } + + def transChildrenInOrder(tree: Tree, tpe: Type, childTrees: List[Tree], byName: List[Tree]) = { + val children = childTrees.flatMap { t => + if (t.tpe eq null) Nil else { + val types = filterAttribs(t.tpe, MarkerCPSTypes) + // TODO: check that it has been adapted and if so correctly + if (types.isEmpty) Nil else List(single(types)) + } + } + + val newtpe = updateAttributesFromChildren(tpe, children, byName) + + if (!newtpe.annotations.isEmpty) + vprintln("[checker] inferred " + tree + " / " + tpe + " ===> "+ newtpe) + + newtpe + } + + /** Modify the type that has thus far been inferred + * for a tree. All this should do is add annotations. */ + + override def addAnnotations(tree: Tree, tpe: Type): Type = { + if (!cpsEnabled) { + if (tpe.annotations.nonEmpty && tpe.hasAnnotation(MarkerCPSTypes)) + global.reporter.error(tree.pos, "this code must be compiled with the Scala continuations plugin enabled") + return tpe + } + +// if (tree.tpe.hasAnnotation(MarkerCPSAdaptPlus)) +// println("addAnnotation " + tree + "/" + tpe) + + tree match { + + case Apply(fun @ Select(qual, name), args) if (fun.tpe ne null) && !fun.tpe.isErroneous => + + // HACK: With overloaded methods, fun will never get annotated. This is because + // the 'overloaded' type gets annotated, but not the alternatives (among which + // fun's type is chosen) + + vprintln("[checker] checking select apply " + tree + "/" + tpe) + + transChildrenInOrder(tree, tpe, qual::(transArgList(fun, args).flatten), Nil) + + case TypeApply(fun @ Select(qual, name), args) if (fun.tpe ne null) && !fun.tpe.isErroneous => + vprintln("[checker] checking select apply " + tree + "/" + tpe) + + transChildrenInOrder(tree, tpe, List(qual, fun), Nil) + + case Apply(fun, args) if (fun.tpe ne null) && !fun.tpe.isErroneous => + + vprintln("[checker] checking unknown apply " + tree + "/" + tpe) + + transChildrenInOrder(tree, tpe, fun::(transArgList(fun, args).flatten), Nil) + + case TypeApply(fun, args) => + + vprintln("[checker] checking type apply " + tree + "/" + tpe) + + transChildrenInOrder(tree, tpe, List(fun), Nil) + + case Select(qual, name) => + + vprintln("[checker] checking select " + tree + "/" + tpe) + + // straightforward way is problematic (see select.scala and Test2.scala) + // transChildrenInOrder(tree, tpe, List(qual), Nil) + + // the problem is that qual may be of type OverloadedType (or MethodType) and + // we cannot safely annotate these. so we just ignore these cases and + // clean up later in the Apply/TypeApply trees. + + if (qual.tpe.hasAnnotation(MarkerCPSTypes)) { + // however there is one special case: + // if it's a method without parameters, just apply it. normally done in adapt, but + // we have to do it here so we don't lose the cps information (wouldn't trigger our + // adapt and there is no Apply/TypeApply created) + tpe match { + case PolyType(List(), restpe) => + //println("yep: " + restpe + "," + restpe.getClass) + transChildrenInOrder(tree, restpe, List(qual), Nil) + case _ : PolyType => tpe + case _ : MethodType => tpe + case _ : OverloadedType => tpe + case _ => + transChildrenInOrder(tree, tpe, List(qual), Nil) + } + } else + tpe + + case If(cond, thenp, elsep) => + transChildrenInOrder(tree, tpe, List(cond), List(thenp, elsep)) + + case Match(select, cases) => + // TODO: can there be cases that are not CaseDefs?? check collect vs map! + transChildrenInOrder(tree, tpe, List(select), cases:::(cases collect { case CaseDef(_, _, body) => body })) + + case Try(block, catches, finalizer) => + val tpe1 = transChildrenInOrder(tree, tpe, Nil, block::catches:::(catches collect { case CaseDef(_, _, body) => body })) + + val annots = filterAttribs(tpe1, MarkerCPSTypes) + if (annots.nonEmpty) { + val ann = single(annots) + val atp0::atp1::Nil = ann.atp.normalize.typeArgs + if (!(atp0 =:= atp1)) + throw new TypeError("only simple cps types allowed in try/catch blocks (found: " + tpe1 + ")") + if (!finalizer.isEmpty) // no finalizers allowed. see explanation in SelectiveCPSTransform + reporter.error(tree.pos, "try/catch blocks that use continuations cannot have finalizers") + } + tpe1 + + case Block(stms, expr) => + // if any stm has annotation, so does block + transChildrenInOrder(tree, tpe, transStms(stms), List(expr)) + + case ValDef(mods, name, tpt, rhs) => + vprintln("[checker] checking valdef " + name + "/"+tpe+"/"+tpt+"/"+tree.symbol.tpe) + // ValDef symbols must *not* have annotations! + if (hasAnswerTypeAnn(tree.symbol.info)) { // is it okay to modify sym here? + vprintln("removing annotation from sym " + tree.symbol + "/" + tree.symbol.tpe + "/" + tpt) + tpt.setType(removeAllCPSAnnotations(tpt.tpe)) + tree.symbol.setInfo(removeAllCPSAnnotations(tree.symbol.info)) + } + tpe + + case _ => + tpe + } + + + } + } +} diff --git a/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala new file mode 100644 index 0000000000..57cba6e829 --- /dev/null +++ b/src/continuations/plugin/scala/tools/selectivecps/CPSUtils.scala @@ -0,0 +1,131 @@ +// $Id$ + +package scala.tools.selectivecps + +import scala.tools.nsc.Global + +trait CPSUtils { + val global: Global + import global._ + import definitions._ + + var cpsEnabled = false + val verbose: Boolean = System.getProperty("cpsVerbose", "false") == "true" + @inline final def vprintln(x: =>Any): Unit = if (verbose) println(x) + + + lazy val MarkerCPSSym = definitions.getClass("scala.util.continuations.cpsSym") + lazy val MarkerCPSTypes = definitions.getClass("scala.util.continuations.cpsParam") + lazy val MarkerCPSSynth = definitions.getClass("scala.util.continuations.cpsSynth") + + lazy val MarkerCPSAdaptPlus = definitions.getClass("scala.util.continuations.cpsPlus") + lazy val MarkerCPSAdaptMinus = definitions.getClass("scala.util.continuations.cpsMinus") + + + lazy val Context = definitions.getClass("scala.util.continuations.ControlContext") + + lazy val ModCPS = definitions.getModule("scala.util.continuations") + lazy val MethShiftUnit = definitions.getMember(ModCPS, "shiftUnit") + lazy val MethShiftUnitR = definitions.getMember(ModCPS, "shiftUnitR") + lazy val MethShift = definitions.getMember(ModCPS, "shift") + lazy val MethShiftR = definitions.getMember(ModCPS, "shiftR") + lazy val MethReify = definitions.getMember(ModCPS, "reify") + lazy val MethReifyR = definitions.getMember(ModCPS, "reifyR") + + + lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth, + MarkerCPSAdaptPlus, MarkerCPSAdaptMinus) + + // annotation checker + + def filterAttribs(tpe:Type, cls:Symbol) = + tpe.annotations.filter(_.atp.typeSymbol == cls) + + def removeAttribs(tpe:Type, cls:Symbol*) = + tpe.withoutAnnotations.withAnnotations(tpe.annotations.filterNot(cls contains _.atp.typeSymbol)) + + def removeAllCPSAnnotations(tpe: Type) = removeAttribs(tpe, allCPSAnnotations:_*) + + def linearize(ann: List[AnnotationInfo]): AnnotationInfo = { + ann.reduceLeft { (a, b) => + val atp0::atp1::Nil = a.atp.normalize.typeArgs + val btp0::btp1::Nil = b.atp.normalize.typeArgs + val (u0,v0) = (atp0, atp1) + val (u1,v1) = (btp0, btp1) +/* + val (u0,v0) = (a.atp.typeArgs(0), a.atp.typeArgs(1)) + val (u1,v1) = (b.atp.typeArgs(0), b.atp.typeArgs(1)) + vprintln("check lin " + a + " andThen " + b) +*/ + vprintln("check lin " + a + " andThen " + b) + if (!(v1 <:< u0)) + throw new TypeError("illegal answer type modification: " + a + " andThen " + b) + // TODO: improve error message (but it is not very common) + AnnotationInfo(appliedType(MarkerCPSTypes.tpe, List(u1,v0)),Nil,Nil) + } + } + + // anf transform + + def getExternalAnswerTypeAnn(tp: Type) = { + tp.annotations.find(a => a.atp.typeSymbol == MarkerCPSTypes) match { + case Some(AnnotationInfo(atp, _, _)) => + val atp0::atp1::Nil = atp.normalize.typeArgs + Some((atp0, atp1)) + case None => + if (tp.hasAnnotation(MarkerCPSAdaptPlus)) + global.warning("trying to instantiate type " + tp + " to unknown cps type") + None + } + } + + def getAnswerTypeAnn(tp: Type) = { + tp.annotations.find(a => a.atp.typeSymbol == MarkerCPSTypes) match { + case Some(AnnotationInfo(atp, _, _)) => + if (!tp.hasAnnotation(MarkerCPSAdaptPlus)) {//&& !tp.hasAnnotation(MarkerCPSAdaptMinus)) + val atp0::atp1::Nil = atp.normalize.typeArgs + Some((atp0, atp1)) + } else + None + case None => None + } + } + + def hasAnswerTypeAnn(tp: Type) = { + tp.hasAnnotation(MarkerCPSTypes) && !tp.hasAnnotation(MarkerCPSAdaptPlus) /*&& + !tp.hasAnnotation(MarkerCPSAdaptMinus)*/ + } + + def hasSynthAnn(tp: Type) = { + tp.annotations.exists(a => a.atp.typeSymbol == MarkerCPSSynth) + } + + def updateSynthFlag(tree: Tree) = { // remove annotations if *we* added them (@synth present) + if (hasSynthAnn(tree.tpe)) { + log("removing annotation from " + tree) + tree.setType(removeAllCPSAnnotations(tree.tpe)) + } else + tree + } + + type CPSInfo = Option[(Type,Type)] + + def linearize(a: CPSInfo, b: CPSInfo)(implicit unit: CompilationUnit, pos: Position): CPSInfo = { + (a,b) match { + case (Some((u0,v0)), Some((u1,v1))) => + vprintln("check lin " + a + " andThen " + b) + if (!(v1 <:< u0)) { + unit.error(pos,"cannot change answer type in composition of cps expressions " + + "from " + u1 + " to " + v0 + " because " + v1 + " is not a subtype of " + u0 + ".") + throw new Exception("check lin " + a + " andThen " + b) + } + Some((u1,v0)) + case (Some(_), _) => a + case (_, Some(_)) => b + case _ => None + } + } + + // cps transform + +}
\ No newline at end of file diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala new file mode 100644 index 0000000000..0525e6fdbc --- /dev/null +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveANFTransform.scala @@ -0,0 +1,414 @@ +// $Id$ + +package scala.tools.selectivecps + +import scala.tools.nsc._ +import scala.tools.nsc.transform._ +import scala.tools.nsc.symtab._ +import scala.tools.nsc.plugins._ + +import scala.tools.nsc.ast._ + +/** + * In methods marked @cps, explicitly name results of calls to other @cps methods + */ +abstract class SelectiveANFTransform extends PluginComponent with Transform with + TypingTransformers with CPSUtils { + // inherits abstract value `global' and class `Phase' from Transform + + import global._ // the global environment + import definitions._ // standard classes and methods + import typer.atOwner // methods to type trees + + /** the following two members override abstract members in Transform */ + val phaseName: String = "selectiveanf" + + protected def newTransformer(unit: CompilationUnit): Transformer = + new ANFTransformer(unit) + + + class ANFTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { + + implicit val _unit = unit // allow code in CPSUtils.scala to report errors + var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) + + override def transform(tree: Tree): Tree = { + if (!cpsEnabled) return tree + + tree match { + + // Maybe we should further generalize the transform and move it over + // to the regular Transformer facility. But then, actual and required cps + // state would need more complicated (stateful!) tracking. + + // Making the default case use transExpr(tree, None, None) instead of + // calling super.transform() would be a start, but at the moment, + // this would cause infinite recursion. But we could remove the + // ValDef case here. + + case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs) => + log("transforming " + dd.symbol) + + atOwner(dd.symbol) { + val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe)) + + log("result "+rhs1) + log("result is of type "+rhs1.tpe) + + treeCopy.DefDef(dd, mods, name, transformTypeDefs(tparams), transformValDefss(vparamss), + transform(tpt), rhs1) + } + + case ff @ Function(vparams, body) => + log("transforming anon function " + ff.symbol) + + atOwner(ff.symbol) { + + //val body1 = transExpr(body, None, getExternalAnswerTypeAnn(body.tpe)) + + // need to special case partial functions: if expected type is @cps + // but all cases are pure, then we would transform + // { x => x match { case A => ... }} to + // { x => shiftUnit(x match { case A => ... })} + // which Uncurry cannot handle (see function6.scala) + + val ext = getExternalAnswerTypeAnn(body.tpe) + + val body1 = body match { + case Match(selector, cases) if (ext.isDefined && getAnswerTypeAnn(body.tpe).isEmpty) => + val cases1 = for { + cd @ CaseDef(pat, guard, caseBody) <- cases + val caseBody1 = transExpr(body, None, ext) + } yield { + treeCopy.CaseDef(cd, transform(pat), transform(guard), caseBody1) + } + treeCopy.Match(tree, transform(selector), cases1) + + case _ => + transExpr(body, None, ext) + } + + log("result "+body1) + log("result is of type "+body1.tpe) + + treeCopy.Function(ff, transformValDefs(vparams), body1) + } + + case vd @ ValDef(mods, name, tpt, rhs) => // object-level valdefs + log("transforming valdef " + vd.symbol) + + atOwner(vd.symbol) { + + assert(getExternalAnswerTypeAnn(tpt.tpe) == None) + + val rhs1 = transExpr(rhs, None, None) + + treeCopy.ValDef(vd, mods, name, transform(tpt), rhs1) + } + + case TypeTree() => + // circumvent cpsAllowed here + super.transform(tree) + + case Apply(_,_) => + // this allows reset { ... } in object constructors + // it's kind of a hack to put it here (see note above) + transExpr(tree, None, None) + + case _ => + + if (hasAnswerTypeAnn(tree.tpe)) { + if (!cpsAllowed) + unit.error(tree.pos, "cps code not allowed here / " + tree.getClass + " / " + tree) + + log(tree) + } + + cpsAllowed = false + super.transform(tree) + } + } + + + def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): Tree = { + transTailValue(tree, cpsA, cpsR) match { + case (Nil, b) => b + case (a, b) => + treeCopy.Block(tree, a,b) + } + } + + + def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo): (List[List[Tree]], List[Tree], CPSInfo) = { + val formals = fun.tpe.paramTypes + val overshoot = args.length - formals.length + + var spc: CPSInfo = cpsA + + val (stm,expr) = (for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield { + tp match { + case TypeRef(_, sym, List(elemtp)) if sym == ByNameParamClass => + (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))) + case _ => + val (valStm, valExpr, valSpc) = transInlineValue(a, spc) + spc = valSpc + (valStm, valExpr) + } + }).unzip + + (stm,expr,spc) + } + + + def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree, CPSInfo) = { + // return value: (stms, expr, spc), where spc is CPSInfo after stms but *before* expr + implicit val pos = tree.pos + tree match { + case Block(stms, expr) => + val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd +// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) + val (a, b) = transBlock(stms, expr, cpsA2, cpsR2) + + val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!! + + (Nil, tree1, cpsA) + + case If(cond, thenp, elsep) => + + val (condStats, condVal, spc) = transInlineValue(cond, cpsA) + + val (cpsA2, cpsR2) = (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) +// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) + val thenVal = transExpr(thenp, cpsA2, cpsR2) + val elseVal = transExpr(elsep, cpsA2, cpsR2) + + // check that then and else parts agree (not necessary any more, but left as sanity check) + if (cpsR.isDefined) { + if (elsep == EmptyTree) + unit.error(tree.pos, "always need else part in cps code") + } + if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) { + unit.error(tree.pos, "then and else parts must both be cps code or neither of them") + } + + (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc) + + case Match(selector, cases) => + + val (selStats, selVal, spc) = transInlineValue(selector, cpsA) + val (cpsA2, cpsR2) = (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) +// val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) + + val caseVals = for { + cd @ CaseDef(pat, guard, body) <- cases + val bodyVal = transExpr(body, cpsA2, cpsR2) + } yield { + treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) + } + + (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc) + + + case ldef @ LabelDef(name, params, rhs) => + if (hasAnswerTypeAnn(tree.tpe)) { + val sym = currentOwner.newMethod(tree.pos, name)//unit.fresh.newName(tree.pos, "myloopvar")) + .setInfo(ldef.symbol.info) + .setFlag(Flags.SYNTHETIC) + + new TreeSymSubstituter(List(ldef.symbol), List(sym)).traverse(rhs) + val rhsVal = transExpr(rhs, None, getAnswerTypeAnn(tree.tpe)) + + val stm1 = localTyper.typed(DefDef(sym, rhsVal)) + val expr = localTyper.typed(Apply(Ident(sym), List())) + + (List(stm1), expr, cpsA) + } else { + val rhsVal = transExpr(rhs, None, None) + (Nil, updateSynthFlag(treeCopy.LabelDef(tree, name, params, rhsVal)), cpsA) + } + + + case Try(block, catches, finalizer) => + val blockVal = transExpr(block, cpsA, cpsR) + + val catchVals = for { + cd @ CaseDef(pat, guard, body) <- catches + val bodyVal = transExpr(body, cpsA, cpsR) + } yield { + treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) + } + + val finallyVal = transExpr(finalizer, None, None) // for now, no cps in finally + + (Nil, updateSynthFlag(treeCopy.Try(tree, blockVal, catchVals, finallyVal)), cpsA) + + case Assign(lhs, rhs) => + // allow cps code in rhs only + val (stms, expr, spc) = transInlineValue(rhs, cpsA) + (stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc) + + case Return(expr0) => + val (stms, expr, spc) = transInlineValue(expr0, cpsA) + (stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc) + + case Throw(expr0) => + val (stms, expr, spc) = transInlineValue(expr0, cpsA) + (stms, updateSynthFlag(treeCopy.Throw(tree, expr)), spc) + + case Typed(expr0, tpt) => + // TODO: should x: A @cps[B,C] have a special meaning? + // type casts used in different ways (see match2.scala, #3199) + val (stms, expr, spc) = transInlineValue(expr0, cpsA) + val tpt1 = if (treeInfo.isWildcardStarArg(tree)) tpt else + treeCopy.TypeTree(tpt).setType(removeAllCPSAnnotations(tpt.tpe)) +// (stms, updateSynthFlag(treeCopy.Typed(tree, expr, tpt1)), spc) + (stms, treeCopy.Typed(tree, expr, tpt1).setType(removeAllCPSAnnotations(tree.tpe)), spc) + + case TypeApply(fun, args) => + val (stms, expr, spc) = transInlineValue(fun, cpsA) + (stms, updateSynthFlag(treeCopy.TypeApply(tree, expr, args)), spc) + + case Select(qual, name) => + val (stms, expr, spc) = transInlineValue(qual, cpsA) + (stms, updateSynthFlag(treeCopy.Select(tree, expr, name)), spc) + + case Apply(fun, args) => + val (funStm, funExpr, funSpc) = transInlineValue(fun, cpsA) + val (argStm, argExpr, argSpc) = transArgList(fun, args, funSpc) + + (funStm ::: (argStm.flatten), updateSynthFlag(treeCopy.Apply(tree, funExpr, argExpr)), + argSpc) + + case _ => + cpsAllowed = true + (Nil, transform(tree), cpsA) + } + } + + def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = { + + val (stms, expr, spc) = transValue(tree, cpsA, cpsR) + + val bot = linearize(spc, getAnswerTypeAnn(expr.tpe))(unit, tree.pos) + + val plainTpe = removeAllCPSAnnotations(expr.tpe) + + if (cpsR.isDefined && !bot.isDefined) { + + if (!expr.isEmpty && (expr.tpe.typeSymbol ne NothingClass)) { + // must convert! + log("cps type conversion (has: " + cpsA + "/" + spc + "/" + expr.tpe + ")") + log("cps type conversion (expected: " + cpsR.get + "): " + expr) + + if (!expr.tpe.hasAnnotation(MarkerCPSAdaptPlus)) + unit.warning(tree.pos, "expression " + tree + " is cps-transformed unexpectedly") + + try { + val Some((a, b)) = cpsR + + val res = localTyper.typed(atPos(tree.pos) { + Apply(TypeApply(gen.mkAttributedRef(MethShiftUnit), + List(TypeTree(plainTpe), TypeTree(a), TypeTree(b))), + List(expr)) + }) + return (stms, res) + + } catch { + case ex:TypeError => + unit.error(ex.pos, "cannot cps-transform expression " + tree + ": " + ex.msg) + } + } + + } else if (!cpsR.isDefined && bot.isDefined) { + // error! + log("cps type error: " + expr) + //println("cps type error: " + expr + "/" + expr.tpe + "/" + getAnswerTypeAnn(expr.tpe)) + + println(cpsR + "/" + spc + "/" + bot) + + unit.error(tree.pos, "found cps expression in non-cps position") + } else { + // all is well + + if (expr.tpe.hasAnnotation(MarkerCPSAdaptPlus)) { + unit.warning(tree.pos, "expression " + expr + " of type " + expr.tpe + " is not expected to have a cps type") + expr.setType(removeAllCPSAnnotations(expr.tpe)) + } + + // TODO: sanity check that types agree + } + + (stms, expr) + } + + def transInlineValue(tree: Tree, cpsA: CPSInfo): (List[Tree], Tree, CPSInfo) = { + + val (stms, expr, spc) = transValue(tree, cpsA, None) // never required to be cps + + getAnswerTypeAnn(expr.tpe) match { + case spcVal @ Some(_) => + + val valueTpe = removeAllCPSAnnotations(expr.tpe) + + val sym = currentOwner.newValue(tree.pos, unit.fresh.newName(tree.pos, "tmp")) + .setInfo(valueTpe) + .setFlag(Flags.SYNTHETIC) + .setAnnotations(List(AnnotationInfo(MarkerCPSSym.tpe, Nil, Nil))) + + (stms ::: List(ValDef(sym, expr) setType(NoType)), + Ident(sym) setType(valueTpe) setPos(tree.pos), linearize(spc, spcVal)(unit, tree.pos)) + + case _ => + (stms, expr, spc) + } + + } + + + + def transInlineStm(stm: Tree, cpsA: CPSInfo): (List[Tree], CPSInfo) = { + stm match { + + // TODO: what about DefDefs? + // TODO: relation to top-level val def? + // TODO: what about lazy vals? + + case tree @ ValDef(mods, name, tpt, rhs) => + val (stms, anfRhs, spc) = atOwner(tree.symbol) { transValue(rhs, cpsA, None) } + + val tv = new ChangeOwnerTraverser(tree.symbol, currentOwner) + stms.foreach(tv.traverse(_)) + + // TODO: symbol might already have annotation. Should check conformance + // TODO: better yet: do without annotations on symbols + + val spcVal = getAnswerTypeAnn(anfRhs.tpe) + if (spcVal.isDefined) { + tree.symbol.setAnnotations(List(AnnotationInfo(MarkerCPSSym.tpe, Nil, Nil))) + } + + (stms:::List(treeCopy.ValDef(tree, mods, name, tpt, anfRhs)), linearize(spc, spcVal)(unit, tree.pos)) + + case _ => + val (headStms, headExpr, headSpc) = transInlineValue(stm, cpsA) + val valSpc = getAnswerTypeAnn(headExpr.tpe) + (headStms:::List(headExpr), linearize(headSpc, valSpc)(unit, stm.pos)) + } + } + + def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo): (List[Tree], Tree) = { + stms match { + case Nil => + transTailValue(expr, cpsA, cpsR) + + case stm::rest => + var (rest2, expr2) = (rest, expr) + val (headStms, headSpc) = transInlineStm(stm, cpsA) + val (restStms, restExpr) = transBlock(rest2, expr2, headSpc, cpsR) + (headStms:::restStms, restExpr) + } + } + + + } +} diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSPlugin.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSPlugin.scala new file mode 100644 index 0000000000..a16e9b9a4c --- /dev/null +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSPlugin.scala @@ -0,0 +1,60 @@ +// $Id$ + +package scala.tools.selectivecps + +import scala.tools.nsc +import scala.tools.nsc.typechecker._ +import nsc.Global +import nsc.Phase +import nsc.plugins.Plugin +import nsc.plugins.PluginComponent + +class SelectiveCPSPlugin(val global: Global) extends Plugin { + import global._ + + val name = "continuations" + val description = "applies selective cps conversion" + + val anfPhase = new SelectiveANFTransform() { + val global = SelectiveCPSPlugin.this.global + val runsAfter = List("pickler") + } + + val cpsPhase = new SelectiveCPSTransform() { + val global = SelectiveCPSPlugin.this.global + val runsAfter = List("selectiveanf") + } + + + val components = List[PluginComponent](anfPhase, cpsPhase) + + val checker = new CPSAnnotationChecker { + val global: SelectiveCPSPlugin.this.global.type = SelectiveCPSPlugin.this.global + } + global.addAnnotationChecker(checker.checker) + + global.log("instantiated cps plugin: " + this) + + def setEnabled(flag: Boolean) = { + checker.cpsEnabled = flag + anfPhase.cpsEnabled = flag + cpsPhase.cpsEnabled = flag + } + + // TODO: require -enabled command-line flag + + override def processOptions(options: List[String], error: String => Unit) = { + var enabled = false + for (option <- options) { + if (option == "enable") { + enabled = true + } else { + error("Option not understood: "+option) + } + } + setEnabled(enabled) + } + + override val optionsHelp: Option[String] = + Some(" -P:continuations:enable Enable continuations") +} diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala new file mode 100644 index 0000000000..6da56f93d4 --- /dev/null +++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala @@ -0,0 +1,384 @@ +// $Id$ + +package scala.tools.selectivecps + +import scala.collection._ + +import scala.tools.nsc._ +import scala.tools.nsc.transform._ +import scala.tools.nsc.plugins._ + +import scala.tools.nsc.ast.TreeBrowsers +import scala.tools.nsc.ast._ + +/** + * In methods marked @cps, CPS-transform assignments introduced by ANF-transform phase. + */ +abstract class SelectiveCPSTransform extends PluginComponent with + InfoTransform with TypingTransformers with CPSUtils { + // inherits abstract value `global' and class `Phase' from Transform + + import global._ // the global environment + import definitions._ // standard classes and methods + import typer.atOwner // methods to type trees + + /** the following two members override abstract members in Transform */ + val phaseName: String = "selectivecps" + + protected def newTransformer(unit: CompilationUnit): Transformer = + new CPSTransformer(unit) + + /** This class does not change linearization */ + override def changesBaseClasses = false + + /** - return symbol's transformed type, + */ + def transformInfo(sym: Symbol, tp: Type): Type = { + if (!cpsEnabled) return tp + + val newtp = transformCPSType(tp) + + if (newtp != tp) + log("transformInfo changed type for " + sym + " to " + newtp); + + if (sym == MethReifyR) + log("transformInfo (not)changed type for " + sym + " to " + newtp); + + newtp + } + + def transformCPSType(tp: Type): Type = { // TODO: use a TypeMap? need to handle more cases? + tp match { + case PolyType(params,res) => PolyType(params, transformCPSType(res)) + case MethodType(params,res) => + MethodType(params, transformCPSType(res)) + case TypeRef(pre, sym, args) => TypeRef(pre, sym, args.map(transformCPSType(_))) + case _ => + getExternalAnswerTypeAnn(tp) match { + case Some((res, outer)) => + appliedType(Context.tpe, List(removeAllCPSAnnotations(tp), res, outer)) + case _ => + removeAllCPSAnnotations(tp) + } + } + } + + + class CPSTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { + + override def transform(tree: Tree): Tree = { + if (!cpsEnabled) return tree + postTransform(mainTransform(tree)) + } + + def postTransform(tree: Tree): Tree = { + tree.setType(transformCPSType(tree.tpe)) + } + + + def mainTransform(tree: Tree): Tree = { + tree match { + + // TODO: can we generalize this? + + case Apply(TypeApply(fun, targs), args) + if (fun.symbol == MethShift) => + log("found shift: " + tree) + atPos(tree.pos) { + val funR = gen.mkAttributedRef(MethShiftR) // TODO: correct? + //gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedIdent(ScalaPackage), + //ScalaPackage.tpe.member("util")), ScalaPackage.tpe.member("util").tpe.member("continuations")), MethShiftR) + //gen.mkAttributedRef(ModCPS.tpe, MethShiftR) // TODO: correct? + log(funR.tpe) + Apply( + TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))), + args.map(transform(_)) + ).setType(transformCPSType(tree.tpe)) + } + + case Apply(TypeApply(fun, targs), args) + if (fun.symbol == MethShiftUnit) => + log("found shiftUnit: " + tree) + atPos(tree.pos) { + val funR = gen.mkAttributedRef(MethShiftUnitR) // TODO: correct? + log(funR.tpe) + Apply( + TypeApply(funR, List(targs(0), targs(1))).setType(appliedType(funR.tpe, + List(targs(0).tpe, targs(1).tpe))), + args.map(transform(_)) + ).setType(appliedType(Context.tpe, List(targs(0).tpe,targs(1).tpe,targs(1).tpe))) + } + + case Apply(TypeApply(fun, targs), args) + if (fun.symbol == MethReify) => + log("found reify: " + tree) + atPos(tree.pos) { + val funR = gen.mkAttributedRef(MethReifyR) // TODO: correct? + log(funR.tpe) + Apply( + TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))), + args.map(transform(_)) + ).setType(transformCPSType(tree.tpe)) + } + + case Try(block, catches, finalizer) => + // currently duplicates the catch block into a partial function. + // this is kinda risky, but we don't expect there will be lots + // of try/catches inside catch blocks (exp. blowup unlikely). + + // CAVEAT: finalizers are surprisingly tricky! + // the problem is that they cannot easily be removed + // from the regular control path and hence will + // also be invoked after creating the Context object. + + /* + object Test { + def foo1 = { + throw new Exception("in sub") + shift((k:Int=>Int) => k(1)) + 10 + } + def foo2 = { + shift((k:Int=>Int) => k(2)) + 20 + } + def foo3 = { + shift((k:Int=>Int) => k(3)) + throw new Exception("in sub") + 30 + } + def foo4 = { + shift((k:Int=>Int) => 4) + throw new Exception("in sub") + 40 + } + def bar(x: Int) = try { + if (x == 1) + foo1 + else if (x == 2) + foo2 + else if (x == 3) + foo3 + else //if (x == 4) + foo4 + } catch { + case _ => + println("exception") + 0 + } finally { + println("done") + } + } + + reset(Test.bar(1)) // should print: exception,done,0 + reset(Test.bar(2)) // should print: done,20 <-- but prints: done,done,20 + reset(Test.bar(3)) // should print: exception,done,0 <-- but prints: done,exception,done,0 + reset(Test.bar(4)) // should print: 4 <-- but prints: done,4 + */ + + val block1 = transform(block) + val catches1 = transformCaseDefs(catches) + val finalizer1 = transform(finalizer) + + if (hasAnswerTypeAnn(tree.tpe)) { + //vprintln("CPS Transform: " + tree + "/" + tree.tpe + "/" + block1.tpe) + + val (stms, expr1) = block1 match { + case Block(stms, expr) => (stms, expr) + case expr => (Nil, expr) + } + + val targettp = transformCPSType(tree.tpe) + +// val expr2 = if (catches.nonEmpty) { + val pos = catches.head.pos + val argSym = currentOwner.newValueParameter(pos, "$ex").setInfo(ThrowableClass.tpe) + val rhs = Match(Ident(argSym), catches1) + val fun = Function(List(ValDef(argSym)), rhs) + val funSym = currentOwner.newValueParameter(pos, "$catches").setInfo(appliedType(PartialFunctionClass.tpe, List(ThrowableClass.tpe, targettp))) + val funDef = localTyper.typed(atPos(pos) { ValDef(funSym, fun) }) + val expr2 = localTyper.typed(atPos(pos) { Apply(Select(expr1, expr1.tpe.member("flatMapCatch")), List(Ident(funSym))) }) + + argSym.owner = fun.symbol + val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol) + chown.traverse(rhs) + + val exSym = currentOwner.newValueParameter(pos, "$ex").setInfo(ThrowableClass.tpe) + val catch2 = { localTyper.typedCases(tree, List( + CaseDef(Bind(exSym, Typed(Ident("_"), TypeTree(ThrowableClass.tpe))), + Apply(Select(Ident(funSym), "isDefinedAt"), List(Ident(exSym))), + Apply(Ident(funSym), List(Ident(exSym)))) + ), ThrowableClass.tpe, targettp) } + + //typedCases(tree, catches, ThrowableClass.tpe, pt) + + localTyper.typed(Block(List(funDef), treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catch2, finalizer1))) + + +/* + disabled for now - see notes above + + val expr3 = if (!finalizer.isEmpty) { + val pos = finalizer.pos + val finalizer2 = duplicateTree(finalizer1) + val fun = Function(List(), finalizer2) + val expr3 = localTyper.typed(atPos(pos) { Apply(Select(expr2, expr2.tpe.member("mapFinally")), List(fun)) }) + + val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol) + chown.traverse(finalizer2) + + expr3 + } else + expr2 +*/ + } else { + treeCopy.Try(tree, block1, catches1, finalizer1) + } + + case Block(stms, expr) => + + val (stms1, expr1) = transBlock(stms, expr) + treeCopy.Block(tree, stms1, expr1) + + case _ => + super.transform(tree) + } + } + + + + def transBlock(stms: List[Tree], expr: Tree): (List[Tree], Tree) = { + + stms match { + case Nil => + (Nil, transform(expr)) + + case stm::rest => + + stm match { + case vd @ ValDef(mods, name, tpt, rhs) + if (vd.symbol.hasAnnotation(MarkerCPSSym)) => + + log("found marked ValDef "+name+" of type " + vd.symbol.tpe) + + val tpe = vd.symbol.tpe + val rhs1 = atOwner(vd.symbol) { transform(rhs) } + + new ChangeOwnerTraverser(vd.symbol, currentOwner).traverse(rhs1) // TODO: don't traverse twice + + log("valdef symbol " + vd.symbol + " has type " + tpe) + log("right hand side " + rhs1 + " has type " + rhs1.tpe) + + log("currentOwner: " + currentOwner) + log("currentMethod: " + currentMethod) + + val (bodyStms, bodyExpr) = transBlock(rest, expr) + // FIXME: result will later be traversed again by TreeSymSubstituter and + // ChangeOwnerTraverser => exp. running time. + // Should be changed to fuse traversals into one. + + val specialCaseTrivial = bodyExpr match { + case Apply(fun, args) => + // for now, look for explicit tail calls only. + // are there other cases that could profit from specializing on + // trivial contexts as well? + (bodyExpr.tpe.typeSymbol == Context) && (currentMethod == fun.symbol) + case _ => false + } + + def applyTrivial(ctxValSym: Symbol, body: Tree) = { + + new TreeSymSubstituter(List(vd.symbol), List(ctxValSym)).traverse(body) + + val body2 = localTyper.typed(atPos(vd.symbol.pos) { body }) + + // in theory it would be nicer to look for an @cps annotation instead + // of testing for Context + if ((body2.tpe == null) || !(body2.tpe.typeSymbol == Context)) { + //println(body2 + "/" + body2.tpe) + unit.error(rhs.pos, "cannot compute type for CPS-transformed function result") + } + body2 + } + + def applyCombinatorFun(ctxR: Tree, body: Tree) = { + val arg = currentOwner.newValueParameter(ctxR.pos, name).setInfo(tpe) + new TreeSymSubstituter(List(vd.symbol), List(arg)).traverse(body) + val fun = localTyper.typed(atPos(vd.symbol.pos) { Function(List(ValDef(arg)), body) }) // types body as well + arg.owner = fun.symbol + new ChangeOwnerTraverser(currentOwner, fun.symbol).traverse(body) + + // see note about multiple traversals above + + log("fun.symbol: "+fun.symbol) + log("fun.symbol.owner: "+fun.symbol.owner) + log("arg.owner: "+arg.owner) + + log("fun.tpe:"+fun.tpe) + log("return type of fun:"+body.tpe) + + var methodName = "map" + + if (body.tpe != null) { + if (body.tpe.typeSymbol == Context) + methodName = "flatMap" + } + else + unit.error(rhs.pos, "cannot compute type for CPS-transformed function result") + + log("will use method:"+methodName) + + localTyper.typed(atPos(vd.symbol.pos) { + Apply(Select(ctxR, ctxR.tpe.member(methodName)), List(fun)) + }) + } + + def mkBlock(stms: List[Tree], expr: Tree) = if (stms.nonEmpty) Block(stms, expr) else expr + + try { + if (specialCaseTrivial) { + log("will optimize possible tail call: " + bodyExpr) + + // FIXME: flatMap impl has become more complicated due to + // exceptions. do we need to put a try/catch in the then part?? + + // val ctx = <rhs> + // if (ctx.isTrivial) + // val <lhs> = ctx.getTrivialValue; ... <--- TODO: try/catch ??? don't bother for the moment... + // else + // ctx.flatMap { <lhs> => ... } + val ctxSym = currentOwner.newValue(vd.symbol.name + "$shift").setInfo(rhs1.tpe) + val ctxDef = localTyper.typed(ValDef(ctxSym, rhs1)) + def ctxRef = localTyper.typed(Ident(ctxSym)) + val argSym = currentOwner.newValue(vd.symbol.name).setInfo(tpe) + val argDef = localTyper.typed(ValDef(argSym, Select(ctxRef, ctxRef.tpe.member("getTrivialValue")))) + val switchExpr = localTyper.typed(atPos(vd.symbol.pos) { + val body2 = duplicateTree(mkBlock(bodyStms, bodyExpr)) // dup before typing! + If(Select(ctxRef, ctxSym.tpe.member("isTrivial")), + applyTrivial(argSym, mkBlock(argDef::bodyStms, bodyExpr)), + applyCombinatorFun(ctxRef, body2)) + }) + (List(ctxDef), switchExpr) + } else { + // ctx.flatMap { <lhs> => ... } + // or + // ctx.map { <lhs> => ... } + (Nil, applyCombinatorFun(rhs1, mkBlock(bodyStms, bodyExpr))) + } + } catch { + case ex:TypeError => + unit.error(ex.pos, ex.msg) + (bodyStms, bodyExpr) + } + + case _ => + val stm1 = transform(stm) + val (a, b) = transBlock(rest, expr) + (stm1::a, b) + } + } + } + + + } +} diff --git a/src/continuations/plugin/scalac-plugin.xml b/src/continuations/plugin/scalac-plugin.xml new file mode 100644 index 0000000000..04d42655c5 --- /dev/null +++ b/src/continuations/plugin/scalac-plugin.xml @@ -0,0 +1,5 @@ +<!-- $Id$ --> +<plugin> + <name>continuations</name> + <classname>scala.tools.selectivecps.SelectiveCPSPlugin</classname> +</plugin> |