summaryrefslogtreecommitdiff
path: root/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala')
-rw-r--r--src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala384
1 files changed, 384 insertions, 0 deletions
diff --git a/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
new file mode 100644
index 0000000000..6da56f93d4
--- /dev/null
+++ b/src/continuations/plugin/scala/tools/selectivecps/SelectiveCPSTransform.scala
@@ -0,0 +1,384 @@
+// $Id$
+
+package scala.tools.selectivecps
+
+import scala.collection._
+
+import scala.tools.nsc._
+import scala.tools.nsc.transform._
+import scala.tools.nsc.plugins._
+
+import scala.tools.nsc.ast.TreeBrowsers
+import scala.tools.nsc.ast._
+
+/**
+ * In methods marked @cps, CPS-transform assignments introduced by ANF-transform phase.
+ */
+abstract class SelectiveCPSTransform extends PluginComponent with
+ InfoTransform with TypingTransformers with CPSUtils {
+ // inherits abstract value `global' and class `Phase' from Transform
+
+ import global._ // the global environment
+ import definitions._ // standard classes and methods
+ import typer.atOwner // methods to type trees
+
+ /** the following two members override abstract members in Transform */
+ val phaseName: String = "selectivecps"
+
+ protected def newTransformer(unit: CompilationUnit): Transformer =
+ new CPSTransformer(unit)
+
+ /** This class does not change linearization */
+ override def changesBaseClasses = false
+
+ /** - return symbol's transformed type,
+ */
+ def transformInfo(sym: Symbol, tp: Type): Type = {
+ if (!cpsEnabled) return tp
+
+ val newtp = transformCPSType(tp)
+
+ if (newtp != tp)
+ log("transformInfo changed type for " + sym + " to " + newtp);
+
+ if (sym == MethReifyR)
+ log("transformInfo (not)changed type for " + sym + " to " + newtp);
+
+ newtp
+ }
+
+ def transformCPSType(tp: Type): Type = { // TODO: use a TypeMap? need to handle more cases?
+ tp match {
+ case PolyType(params,res) => PolyType(params, transformCPSType(res))
+ case MethodType(params,res) =>
+ MethodType(params, transformCPSType(res))
+ case TypeRef(pre, sym, args) => TypeRef(pre, sym, args.map(transformCPSType(_)))
+ case _ =>
+ getExternalAnswerTypeAnn(tp) match {
+ case Some((res, outer)) =>
+ appliedType(Context.tpe, List(removeAllCPSAnnotations(tp), res, outer))
+ case _ =>
+ removeAllCPSAnnotations(tp)
+ }
+ }
+ }
+
+
+ class CPSTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
+
+ override def transform(tree: Tree): Tree = {
+ if (!cpsEnabled) return tree
+ postTransform(mainTransform(tree))
+ }
+
+ def postTransform(tree: Tree): Tree = {
+ tree.setType(transformCPSType(tree.tpe))
+ }
+
+
+ def mainTransform(tree: Tree): Tree = {
+ tree match {
+
+ // TODO: can we generalize this?
+
+ case Apply(TypeApply(fun, targs), args)
+ if (fun.symbol == MethShift) =>
+ log("found shift: " + tree)
+ atPos(tree.pos) {
+ val funR = gen.mkAttributedRef(MethShiftR) // TODO: correct?
+ //gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedIdent(ScalaPackage),
+ //ScalaPackage.tpe.member("util")), ScalaPackage.tpe.member("util").tpe.member("continuations")), MethShiftR)
+ //gen.mkAttributedRef(ModCPS.tpe, MethShiftR) // TODO: correct?
+ log(funR.tpe)
+ Apply(
+ TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
+ args.map(transform(_))
+ ).setType(transformCPSType(tree.tpe))
+ }
+
+ case Apply(TypeApply(fun, targs), args)
+ if (fun.symbol == MethShiftUnit) =>
+ log("found shiftUnit: " + tree)
+ atPos(tree.pos) {
+ val funR = gen.mkAttributedRef(MethShiftUnitR) // TODO: correct?
+ log(funR.tpe)
+ Apply(
+ TypeApply(funR, List(targs(0), targs(1))).setType(appliedType(funR.tpe,
+ List(targs(0).tpe, targs(1).tpe))),
+ args.map(transform(_))
+ ).setType(appliedType(Context.tpe, List(targs(0).tpe,targs(1).tpe,targs(1).tpe)))
+ }
+
+ case Apply(TypeApply(fun, targs), args)
+ if (fun.symbol == MethReify) =>
+ log("found reify: " + tree)
+ atPos(tree.pos) {
+ val funR = gen.mkAttributedRef(MethReifyR) // TODO: correct?
+ log(funR.tpe)
+ Apply(
+ TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
+ args.map(transform(_))
+ ).setType(transformCPSType(tree.tpe))
+ }
+
+ case Try(block, catches, finalizer) =>
+ // currently duplicates the catch block into a partial function.
+ // this is kinda risky, but we don't expect there will be lots
+ // of try/catches inside catch blocks (exp. blowup unlikely).
+
+ // CAVEAT: finalizers are surprisingly tricky!
+ // the problem is that they cannot easily be removed
+ // from the regular control path and hence will
+ // also be invoked after creating the Context object.
+
+ /*
+ object Test {
+ def foo1 = {
+ throw new Exception("in sub")
+ shift((k:Int=>Int) => k(1))
+ 10
+ }
+ def foo2 = {
+ shift((k:Int=>Int) => k(2))
+ 20
+ }
+ def foo3 = {
+ shift((k:Int=>Int) => k(3))
+ throw new Exception("in sub")
+ 30
+ }
+ def foo4 = {
+ shift((k:Int=>Int) => 4)
+ throw new Exception("in sub")
+ 40
+ }
+ def bar(x: Int) = try {
+ if (x == 1)
+ foo1
+ else if (x == 2)
+ foo2
+ else if (x == 3)
+ foo3
+ else //if (x == 4)
+ foo4
+ } catch {
+ case _ =>
+ println("exception")
+ 0
+ } finally {
+ println("done")
+ }
+ }
+
+ reset(Test.bar(1)) // should print: exception,done,0
+ reset(Test.bar(2)) // should print: done,20 <-- but prints: done,done,20
+ reset(Test.bar(3)) // should print: exception,done,0 <-- but prints: done,exception,done,0
+ reset(Test.bar(4)) // should print: 4 <-- but prints: done,4
+ */
+
+ val block1 = transform(block)
+ val catches1 = transformCaseDefs(catches)
+ val finalizer1 = transform(finalizer)
+
+ if (hasAnswerTypeAnn(tree.tpe)) {
+ //vprintln("CPS Transform: " + tree + "/" + tree.tpe + "/" + block1.tpe)
+
+ val (stms, expr1) = block1 match {
+ case Block(stms, expr) => (stms, expr)
+ case expr => (Nil, expr)
+ }
+
+ val targettp = transformCPSType(tree.tpe)
+
+// val expr2 = if (catches.nonEmpty) {
+ val pos = catches.head.pos
+ val argSym = currentOwner.newValueParameter(pos, "$ex").setInfo(ThrowableClass.tpe)
+ val rhs = Match(Ident(argSym), catches1)
+ val fun = Function(List(ValDef(argSym)), rhs)
+ val funSym = currentOwner.newValueParameter(pos, "$catches").setInfo(appliedType(PartialFunctionClass.tpe, List(ThrowableClass.tpe, targettp)))
+ val funDef = localTyper.typed(atPos(pos) { ValDef(funSym, fun) })
+ val expr2 = localTyper.typed(atPos(pos) { Apply(Select(expr1, expr1.tpe.member("flatMapCatch")), List(Ident(funSym))) })
+
+ argSym.owner = fun.symbol
+ val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol)
+ chown.traverse(rhs)
+
+ val exSym = currentOwner.newValueParameter(pos, "$ex").setInfo(ThrowableClass.tpe)
+ val catch2 = { localTyper.typedCases(tree, List(
+ CaseDef(Bind(exSym, Typed(Ident("_"), TypeTree(ThrowableClass.tpe))),
+ Apply(Select(Ident(funSym), "isDefinedAt"), List(Ident(exSym))),
+ Apply(Ident(funSym), List(Ident(exSym))))
+ ), ThrowableClass.tpe, targettp) }
+
+ //typedCases(tree, catches, ThrowableClass.tpe, pt)
+
+ localTyper.typed(Block(List(funDef), treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catch2, finalizer1)))
+
+
+/*
+ disabled for now - see notes above
+
+ val expr3 = if (!finalizer.isEmpty) {
+ val pos = finalizer.pos
+ val finalizer2 = duplicateTree(finalizer1)
+ val fun = Function(List(), finalizer2)
+ val expr3 = localTyper.typed(atPos(pos) { Apply(Select(expr2, expr2.tpe.member("mapFinally")), List(fun)) })
+
+ val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol)
+ chown.traverse(finalizer2)
+
+ expr3
+ } else
+ expr2
+*/
+ } else {
+ treeCopy.Try(tree, block1, catches1, finalizer1)
+ }
+
+ case Block(stms, expr) =>
+
+ val (stms1, expr1) = transBlock(stms, expr)
+ treeCopy.Block(tree, stms1, expr1)
+
+ case _ =>
+ super.transform(tree)
+ }
+ }
+
+
+
+ def transBlock(stms: List[Tree], expr: Tree): (List[Tree], Tree) = {
+
+ stms match {
+ case Nil =>
+ (Nil, transform(expr))
+
+ case stm::rest =>
+
+ stm match {
+ case vd @ ValDef(mods, name, tpt, rhs)
+ if (vd.symbol.hasAnnotation(MarkerCPSSym)) =>
+
+ log("found marked ValDef "+name+" of type " + vd.symbol.tpe)
+
+ val tpe = vd.symbol.tpe
+ val rhs1 = atOwner(vd.symbol) { transform(rhs) }
+
+ new ChangeOwnerTraverser(vd.symbol, currentOwner).traverse(rhs1) // TODO: don't traverse twice
+
+ log("valdef symbol " + vd.symbol + " has type " + tpe)
+ log("right hand side " + rhs1 + " has type " + rhs1.tpe)
+
+ log("currentOwner: " + currentOwner)
+ log("currentMethod: " + currentMethod)
+
+ val (bodyStms, bodyExpr) = transBlock(rest, expr)
+ // FIXME: result will later be traversed again by TreeSymSubstituter and
+ // ChangeOwnerTraverser => exp. running time.
+ // Should be changed to fuse traversals into one.
+
+ val specialCaseTrivial = bodyExpr match {
+ case Apply(fun, args) =>
+ // for now, look for explicit tail calls only.
+ // are there other cases that could profit from specializing on
+ // trivial contexts as well?
+ (bodyExpr.tpe.typeSymbol == Context) && (currentMethod == fun.symbol)
+ case _ => false
+ }
+
+ def applyTrivial(ctxValSym: Symbol, body: Tree) = {
+
+ new TreeSymSubstituter(List(vd.symbol), List(ctxValSym)).traverse(body)
+
+ val body2 = localTyper.typed(atPos(vd.symbol.pos) { body })
+
+ // in theory it would be nicer to look for an @cps annotation instead
+ // of testing for Context
+ if ((body2.tpe == null) || !(body2.tpe.typeSymbol == Context)) {
+ //println(body2 + "/" + body2.tpe)
+ unit.error(rhs.pos, "cannot compute type for CPS-transformed function result")
+ }
+ body2
+ }
+
+ def applyCombinatorFun(ctxR: Tree, body: Tree) = {
+ val arg = currentOwner.newValueParameter(ctxR.pos, name).setInfo(tpe)
+ new TreeSymSubstituter(List(vd.symbol), List(arg)).traverse(body)
+ val fun = localTyper.typed(atPos(vd.symbol.pos) { Function(List(ValDef(arg)), body) }) // types body as well
+ arg.owner = fun.symbol
+ new ChangeOwnerTraverser(currentOwner, fun.symbol).traverse(body)
+
+ // see note about multiple traversals above
+
+ log("fun.symbol: "+fun.symbol)
+ log("fun.symbol.owner: "+fun.symbol.owner)
+ log("arg.owner: "+arg.owner)
+
+ log("fun.tpe:"+fun.tpe)
+ log("return type of fun:"+body.tpe)
+
+ var methodName = "map"
+
+ if (body.tpe != null) {
+ if (body.tpe.typeSymbol == Context)
+ methodName = "flatMap"
+ }
+ else
+ unit.error(rhs.pos, "cannot compute type for CPS-transformed function result")
+
+ log("will use method:"+methodName)
+
+ localTyper.typed(atPos(vd.symbol.pos) {
+ Apply(Select(ctxR, ctxR.tpe.member(methodName)), List(fun))
+ })
+ }
+
+ def mkBlock(stms: List[Tree], expr: Tree) = if (stms.nonEmpty) Block(stms, expr) else expr
+
+ try {
+ if (specialCaseTrivial) {
+ log("will optimize possible tail call: " + bodyExpr)
+
+ // FIXME: flatMap impl has become more complicated due to
+ // exceptions. do we need to put a try/catch in the then part??
+
+ // val ctx = <rhs>
+ // if (ctx.isTrivial)
+ // val <lhs> = ctx.getTrivialValue; ... <--- TODO: try/catch ??? don't bother for the moment...
+ // else
+ // ctx.flatMap { <lhs> => ... }
+ val ctxSym = currentOwner.newValue(vd.symbol.name + "$shift").setInfo(rhs1.tpe)
+ val ctxDef = localTyper.typed(ValDef(ctxSym, rhs1))
+ def ctxRef = localTyper.typed(Ident(ctxSym))
+ val argSym = currentOwner.newValue(vd.symbol.name).setInfo(tpe)
+ val argDef = localTyper.typed(ValDef(argSym, Select(ctxRef, ctxRef.tpe.member("getTrivialValue"))))
+ val switchExpr = localTyper.typed(atPos(vd.symbol.pos) {
+ val body2 = duplicateTree(mkBlock(bodyStms, bodyExpr)) // dup before typing!
+ If(Select(ctxRef, ctxSym.tpe.member("isTrivial")),
+ applyTrivial(argSym, mkBlock(argDef::bodyStms, bodyExpr)),
+ applyCombinatorFun(ctxRef, body2))
+ })
+ (List(ctxDef), switchExpr)
+ } else {
+ // ctx.flatMap { <lhs> => ... }
+ // or
+ // ctx.map { <lhs> => ... }
+ (Nil, applyCombinatorFun(rhs1, mkBlock(bodyStms, bodyExpr)))
+ }
+ } catch {
+ case ex:TypeError =>
+ unit.error(ex.pos, ex.msg)
+ (bodyStms, bodyExpr)
+ }
+
+ case _ =>
+ val stm1 = transform(stm)
+ val (a, b) = transBlock(rest, expr)
+ (stm1::a, b)
+ }
+ }
+ }
+
+
+ }
+}