summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorIulian Dragos <jaguarul@gmail.com>2005-06-28 08:52:00 +0000
committerIulian Dragos <jaguarul@gmail.com>2005-06-28 08:52:00 +0000
commitb4ba0b8045dc81357d4124b4913db0531665e54e (patch)
tree6628a595758c6548fdbf2e20eb51c6370be0f503
parentbca043774f5d799670a11c4843f6239de51fff3d (diff)
downloadscala-b4ba0b8045dc81357d4124b4913db0531665e54e.tar.gz
scala-b4ba0b8045dc81357d4124b4913db0531665e54e.tar.bz2
scala-b4ba0b8045dc81357d4124b4913db0531665e54e.zip
Added tail call elimination phase to nsc.
-rwxr-xr-xsources/scala/tools/nsc/Global.scala7
-rw-r--r--sources/scala/tools/nsc/transform/TailCalls.scala243
2 files changed, 249 insertions, 1 deletions
diff --git a/sources/scala/tools/nsc/Global.scala b/sources/scala/tools/nsc/Global.scala
index d686ca3493..1507be774f 100755
--- a/sources/scala/tools/nsc/Global.scala
+++ b/sources/scala/tools/nsc/Global.scala
@@ -186,6 +186,11 @@ class Global(val settings: Settings, val reporter: Reporter) extends SymbolTable
}
val samplePhase = new sampleTransform.Phase(transMatchPhase);
+ object tailCalls extends TailCalls {
+ val global: Global.this.type = Global.this;
+ }
+ val tailCallPhase = new tailCalls.Phase(samplePhase);
+
//val transMatchPhase = new transmatcher.TransMatchPhase(picklePhase);
/*
object icode extends ICode {
@@ -199,7 +204,7 @@ class Global(val settings: Settings, val reporter: Reporter) extends SymbolTable
}
*/
- val terminalPhase = new Phase(samplePhase) {
+ val terminalPhase = new Phase(tailCallPhase) {
def name = "terminal";
def run: unit = {}
}
diff --git a/sources/scala/tools/nsc/transform/TailCalls.scala b/sources/scala/tools/nsc/transform/TailCalls.scala
new file mode 100644
index 0000000000..eb42966312
--- /dev/null
+++ b/sources/scala/tools/nsc/transform/TailCalls.scala
@@ -0,0 +1,243 @@
+/* NSC -- new scala compiler
+ * Copyright 2005 LAMP/EPFL
+ * @author
+ */
+// $Id$
+package scala.tools.nsc.transform;
+
+/** Perform tail recursive call elimination.
+ */
+abstract class TailCalls extends Transform {
+ // inherits abstract value `global' and class `Phase' from Transform
+
+ import global._; // the global environment
+ import definitions._; // standard classes and methods
+ import typer.{typed, atOwner}; // methods to type trees
+ import posAssigner.atPos; // for filling in tree positions
+
+ protected val phaseName: String = "tailcalls";
+ protected def newTransformer(unit: CompilationUnit): Transformer = new TailCallElimination(unit);
+
+ /**
+ * A Tail Call Transformer
+ *
+ * @author Erik Stenman, Iulian Dragos
+ * @version 1.1
+ *
+ * What it does:
+ *
+ * Finds method calls in tail-position and replaces them with jumps.
+ * A call is in a tail-position if it is the last instruction to be
+ * executed in the body of a method. This is done by recursing over
+ * the trees that may contain calls in tail-position (trees that can't
+ * contain such calls are not transformed). However, they are not that
+ * many.
+ *
+ * Self-recursive calls in tail-position are replaced by jumps to a
+ * label at the beginning of the method. As the JVM provides no way to
+ * jump from a method to another one, non-recursive calls in
+ * tail-position are not optimized.
+ *
+ * A method call is self-recursive if it calls the current method on
+ * the current instance and the method is final (otherwise, it could
+ * be a call to an overridden method in a subclass). Furthermore, If
+ * the method has type parameters, the call must contain these
+ * parameters as type arguments.
+ *
+ * This phase has been moved before pattern matching to catch more
+ * of the common cases of tail recursive functions. This means that
+ * more cases should be taken into account (like nested function, and
+ * pattern cases).
+ *
+ * If a method contains self-recursive calls, a label is added to at
+ * the beginning of its body and the calls are replaced by jumps to
+ * that label.
+ *
+ * Assumes: Uncurry has been run already, and no multiple parameter
+ * lists exit.
+ */
+ class TailCallElimination(unit: CompilationUnit) extends Transformer {
+
+ class Context {
+ /** The current method */
+ var currentMethod: Symbol = NoSymbol;
+
+ /** The current tail-call label */
+ var label: Symbol = NoSymbol;
+
+ /** The expected type arguments of self-recursive calls */
+ var types: List[Type] = Nil;
+
+ /** Tells whether we are in a (possible) tail position */
+ var tailPos = false;
+
+ /** Is the label accessed? */
+ var accessed = false;
+
+ def this(that: Context) = {
+ this();
+ this.currentMethod = that.currentMethod;
+ this.label = that.label;
+ this.types = that.types;
+ this.tailPos = that.tailPos;
+ this.accessed = that.accessed;
+ }
+
+ /** Create a new method symbol for the current method and store it in
+ * the label field.
+ */
+ def makeLabel(): Unit = {
+ label = currentMethod.newLabel(currentMethod.pos, "_" + currentMethod.name);
+ }
+
+ override def toString(): String = {
+ "" + currentMethod.name + " types: " + types + " tailPos: " + tailPos +
+ " accessed: " + accessed;
+ }
+ }
+
+ private def mkContext(that: Context) = new Context(that);
+ private def mkContext(that: Context, tp: Boolean): Context = {
+ val t = mkContext(that);
+ t.tailPos = tp;
+ t
+ }
+
+ private var ctx: Context = new Context();
+
+ /** Rewrite this tree to contain no tail recursive calls */
+ def transform(tree: Tree, nctx: Context): Tree = {
+ val oldCtx = ctx;
+ ctx = nctx;
+ val t = transform(tree);
+ this.ctx = oldCtx;
+ t
+ }
+
+ override def transform(tree: Tree): Tree = {
+ tree match {
+
+ case DefDef(mods, name, tparams, vparams, tpt, rhs) =>
+ log("Entering DefDef: " + name);
+ val newCtx = mkContext(ctx);
+ newCtx.currentMethod = tree.symbol;
+ newCtx.makeLabel();
+ newCtx.label.setInfo(tree.symbol.info);
+ newCtx.tailPos = true;
+ log("Label type is: " + newCtx.label.info);
+
+ if (newCtx.currentMethod.isFinal) {
+ newCtx.types = Nil;
+
+ newCtx.currentMethod.tpe match {
+ case PolyType(tparams, result) =>
+ newCtx.types = tparams map (s => s.tpe);
+ case _ => ;
+ }
+
+ var newRHS = transform(rhs, newCtx);
+ if (newCtx.accessed) {
+ log("Rewrote def " + newCtx.currentMethod);
+ vparams.head map ((v) => log(v.symbol));
+ newRHS =
+ typed(atPos(tree.pos)(
+ LabelDef(newCtx.label, vparams.head map (.symbol), newRHS)));
+ copy.DefDef(tree, mods, name, tparams, vparams, tpt, newRHS);
+ } else
+ copy.DefDef(tree, mods, name, tparams, vparams, tpt, newRHS);
+ } else {
+ log("Non-final method: " + name);
+ DefDef(newCtx.currentMethod, (x) => transform(rhs, newCtx));
+ }
+
+ case EmptyTree => tree;
+
+ case PackageDef(name, stats) => super.transform(tree);
+ case ClassDef(mods, name, tparams, tpt, impl) => super.transform(tree);
+
+ case ValDef(mods, name, tpt, rhs) => tree;
+ case AbsTypeDef(mods, name, lo, hi) => tree; // (eliminated by erasure)
+ case AliasTypeDef(mods, name, tparams, rhs) => tree; // (eliminated by erasure)
+ case LabelDef(name, params, rhs) => super.transform(tree);
+
+ case Template(parents, body) => super.transform(tree);
+
+ case Block(stats, expr) =>
+ copy.Block(tree,
+ transformTrees(stats, mkContext(ctx, false)),
+ transform(expr));
+
+ case CaseDef(pat, guard, body) =>
+ copy.CaseDef(tree, pat, guard, transform(body));
+
+ case Sequence(_) | Alternative(_) |
+ Star(_) | Bind(_, _) =>
+ throw new RuntimeException("We should've never gotten inside a pattern");
+
+ case Function(vparams, body) =>
+ throw new RuntimeException("Anonymous function should not exist at this point");
+
+ case Assign(lhs, rhs) => super.transform(tree);
+ case If(cond, thenp, elsep) =>
+ log("entering If: " + ctx);
+ copy.If(tree, cond, transform(thenp), transform(elsep));
+
+ case Match(selector, cases) => super.transform(tree);
+ case Return(expr) => super.transform(tree);
+ case Try(block, catches, finalizer) => super.transform(tree);
+
+ case Throw(expr) => super.transform(tree);
+ case New(tpt) => super.transform(tree);
+ case Typed(expr, tpt) => super.transform(tree);
+
+ case TypeApply(fun, args) =>
+ throw new RuntimeException("Lonely TypeApply found -- we can only handle them inside Apply(TypeApply())");
+
+ case Apply(fun, args) =>
+ log("entering Apply: " + ctx);
+ if ((fun.symbol eq ctx.currentMethod) && ctx.tailPos)
+ rewriteTailCall(tree, transformTrees(args, mkContext(ctx, false)));
+ else
+ copy.Apply(tree, fun, transformTrees(args, mkContext(ctx, false)));
+
+ case Super(qual, mixin) =>
+ tree;
+ case This(qual) =>
+ tree;
+ case Select(qualifier, selector) =>
+ tree;
+ case Ident(name) =>
+ tree;
+ case Literal(value) =>
+ tree;
+ case TypeTree() => tree;
+
+// case Block(List(), expr) => // a simple optimization
+// expr
+// case Block(defs, sup @ Super(qual, mix)) => // A hypthothetic transformation, which replaces
+// // {super} by {super.sample}
+// copy.Block( // `copy' is the usual lazy tree copier
+// tree, defs,
+// typed( // `typed' assigns types to its tree argument
+// atPos(tree.pos)( // `atPos' fills in position of its tree argument
+// Select( // The `Select' factory method is defined in class `Trees'
+// sup,
+// currentOwner.newValue( // creates a new term symbol owned by `currentowner'
+// tree.pos,
+// newTermName("sample")))))) // The standard term name creator
+ case _ =>
+ tree
+ }
+ }
+
+ def transformTrees(trees: List[Tree], nctx: Context): List[Tree] =
+ trees map ((tree) => transform(tree, nctx));
+
+ private def rewriteTailCall(tree: Tree, args: List[Tree]): Tree = {
+ log("Rewriting..");
+ ctx.accessed = true;
+ typed(atPos(tree.pos)(
+ Apply(Ident(ctx.label), args)));
+ }
+ }
+}