summaryrefslogblamecommitdiff
path: root/sources/scala/tools/nsc/transform/TailCalls.scala
blob: eb4296631251a3414245edc1a18152851fcde27f (plain) (tree)


















































































































































































































































                                                                                                                    
/* NSC -- new scala compiler
 * Copyright 2005 LAMP/EPFL
 * @author
 */
// $Id$
package scala.tools.nsc.transform;

/** Perform tail recursive call elimination.
 */
abstract class TailCalls extends Transform {
  // inherits abstract value `global' and class `Phase' from Transform

  import global._;                  // the global environment
  import definitions._;             // standard classes and methods
  import typer.{typed, atOwner};    // methods to type trees
  import posAssigner.atPos;         // for filling in tree positions

  protected val phaseName: String = "tailcalls";
  protected def newTransformer(unit: CompilationUnit): Transformer = new TailCallElimination(unit);

  /**
   * A Tail Call Transformer
   *
   * @author     Erik Stenman, Iulian Dragos
   * @version    1.1
   *
   * What it does:
   *
   * Finds method calls in tail-position and replaces them with jumps.
   * A call is in a tail-position if it is the last instruction to be
   * executed in the body of a method.  This is done by recursing over
   * the trees that may contain calls in tail-position (trees that can't
   * contain such calls are not transformed). However, they are not that
   * many.
   *
   * Self-recursive calls in tail-position are replaced by jumps to a
   * label at the beginning of the method. As the JVM provides no way to
   * jump from a method to another one, non-recursive calls in
   * tail-position are not optimized.
   *
   * A method call is self-recursive if it calls the current method on
   * the current instance and the method is final (otherwise, it could
   * be a call to an overridden method in a subclass). Furthermore, if
   * the method has type parameters, the call must contain these
   * parameters as type arguments.
   *
   * The following are never treated as tail positions: the selector of
   * a match, the right-hand side of an assignment, the operand of a
   * throw, and anything inside a try (a jump from there would leave the
   * protected region). The body of a class nested in a method is not
   * part of that method's body.
   *
   * This phase has been moved before pattern matching to catch more
   * of the common cases of tail recursive functions. This means that
   * more cases should be taken into account (like nested functions and
   * pattern cases).
   *
   * If a method contains self-recursive calls, a label is added at
   * the beginning of its body and the calls are replaced by jumps to
   * that label.
   *
   * Assumes: Uncurry has been run already, and no multiple parameter
   *          lists exist.
   */
  class TailCallElimination(unit: CompilationUnit) extends Transformer {

    /** The state threaded through the traversal of one method body. */
    class Context {
      /** The current method */
      var currentMethod: Symbol = NoSymbol;

      /** The current tail-call label */
      var label: Symbol = NoSymbol;

      /** The expected type arguments of self-recursive calls */
      var types: List[Type] = Nil;

      /** Tells whether we are in a (possible) tail position */
      var tailPos = false;

      /** Is the label accessed? */
      var accessed = false;

      /** Copy constructor: a fresh context with all fields taken from `that'. */
      def this(that: Context) = {
        this();
        this.currentMethod = that.currentMethod;
        this.label         = that.label;
        this.types         = that.types;
        this.tailPos       = that.tailPos;
        this.accessed      = that.accessed;
      }

      /** Create a new label symbol for the current method and store it in
        * the label field.
        */
      def makeLabel(): Unit = {
        label = currentMethod.newLabel(currentMethod.pos, "_" + currentMethod.name);
      }

      override def toString(): String = {
        "" + currentMethod.name + " types: " + types + " tailPos: " + tailPos +
        " accessed: " + accessed;
      }
    }

    private def mkContext(that: Context) = new Context(that);
    private def mkContext(that: Context, tp: Boolean): Context = {
      val t = mkContext(that);
      t.tailPos = tp;
      t
    }

    private var ctx: Context = new Context();

    /** Rewrite this tree to contain no tail recursive calls */
    def transform(tree: Tree, nctx: Context): Tree = {
      val oldCtx = ctx;
      ctx = nctx;
      val t = transform(tree);
      this.ctx = oldCtx;
      t
    }

    /** Apply the inherited, structural transformation to the children of
     *  `tree' under context `nctx', restoring the current context afterwards.
     *  Unlike `transform(tree, nctx)' this does not dispatch on `tree'
     *  itself, so it can be used from within the case that handles `tree'.
     */
    private def transformDefault(tree: Tree, nctx: Context): Tree = {
      val oldCtx = ctx;
      ctx = nctx;
      val t = super.transform(tree);
      ctx = oldCtx;
      t
    }

    override def transform(tree: Tree): Tree = {
      tree match {

        case DefDef(mods, name, tparams, vparams, tpt, rhs) =>
          log("Entering DefDef: " + name);
          val newCtx = mkContext(ctx);
          newCtx.currentMethod = tree.symbol;
          newCtx.makeLabel();
          newCtx.label.setInfo(tree.symbol.info);
          newCtx.tailPos = true;
          // The label is fresh: nothing jumps to it yet. The flag copied from
          // the enclosing context refers to the enclosing method's label.
          newCtx.accessed = false;
          log("Label type is: " + newCtx.label.info);

          if (newCtx.currentMethod.isFinal) {
            newCtx.types = Nil;

            newCtx.currentMethod.tpe match {
              case PolyType(tps, result) =>
                newCtx.types = tps map (s => s.tpe);
              case _ => ;
            }

            var newRHS = transform(rhs, newCtx);
            if (newCtx.accessed) {
              log("Rewrote def " + newCtx.currentMethod);
              vparams.head map ((v) => log(v.symbol));
              // wrap the body in the label that the rewritten calls jump to
              newRHS =
                  typed(atPos(tree.pos)(
                    LabelDef(newCtx.label, vparams.head map (.symbol), newRHS)));
            }
            copy.DefDef(tree, mods, name, tparams, vparams, tpt, newRHS);
          } else {
            log("Non-final method: " + name);
            // A non-final method may be overridden, so calls to it must stay
            // calls and no label is emitted for it. Its body is still
            // traversed (outside of tail position) so that final methods
            // nested in it are optimized.
            copy.DefDef(tree, mods, name, tparams, vparams, tpt,
                        transform(rhs, mkContext(ctx, false)));
          }

        case EmptyTree => tree;

        case PackageDef(name, stats) => super.transform(tree);
        case ClassDef(mods, name, tparams, tpt, impl) =>
          // a class body is never part of an enclosing method's body
          transformDefault(tree, new Context());

        case ValDef(mods, name, tpt, rhs) => tree;
        case AbsTypeDef(mods, name, lo, hi) => tree; //  (eliminated by erasure)
        case AliasTypeDef(mods, name, tparams, rhs) => tree; // (eliminated by erasure)
        case LabelDef(name, params, rhs) => super.transform(tree);

        case Template(parents, body) => super.transform(tree);

        case Block(stats, expr) =>
          // only the last expression of a block can be in tail position
          copy.Block(tree,
                     transformTrees(stats, mkContext(ctx, false)),
                     transform(expr));

        case CaseDef(pat, guard, body) =>
          copy.CaseDef(tree, pat, guard, transform(body));

        case Sequence(_) | Alternative(_) |
             Star(_)     | Bind(_, _) =>
               throw new RuntimeException("We should've never gotten inside a pattern");

        case Function(vparams, body) =>
          throw new RuntimeException("Anonymous function should not exist at this point");

        case Assign(lhs, rhs) =>
          // the assignment still has to happen after the right-hand side
          transformDefault(tree, mkContext(ctx, false));
        case If(cond, thenp, elsep) =>
          log("entering If: " + ctx);
          copy.If(tree, cond, transform(thenp), transform(elsep));

        case Match(selector, cases) =>
          // the selector is evaluated before a case is chosen; the case
          // bodies inherit the tail position of the match
          copy.Match(tree,
                     transform(selector, mkContext(ctx, false)),
                     cases map (c => transform(c).asInstanceOf[CaseDef]));
        case Return(expr) =>           super.transform(tree);
        case Try(block, catches, finalizer) =>
          // a jump out of the try would bypass its handlers and finalizer
          transformDefault(tree, mkContext(ctx, false));

        case Throw(expr) =>
          // the value still has to be thrown
          transformDefault(tree, mkContext(ctx, false));
        case New(tpt) => super.transform(tree);
        case Typed(expr, tpt) => super.transform(tree);

        case TypeApply(fun, args) =>
          throw new RuntimeException("Lonely TypeApply found -- we can only handle them inside Apply(TypeApply())");

        case Apply(fun, args) =>
          log("entering Apply: " + ctx);
          if (ctx.tailPos && isSelfRecursiveCall(fun))
            rewriteTailCall(tree, transformTrees(args, mkContext(ctx, false)));
          else
            copy.Apply(tree, fun, transformTrees(args, mkContext(ctx, false)));

        case Super(qual, mixin) =>
          tree;
        case This(qual) =>
          tree;
        case Select(qualifier, selector) =>
          tree;
        case Ident(name) =>
          tree;
        case Literal(value) =>
          tree;
        case TypeTree() =>  tree;

        case _ =>
          tree
      }
    }

    def transformTrees(trees: List[Tree], nctx: Context): List[Tree] =
      trees map ((tree) => transform(tree, nctx));

    /** Is `fun', the function part of an application, a self-recursive
     *  reference? That is: the current (final) method, called on `this'
     *  (or, for a local method, through a plain identifier), with the
     *  method's own type parameters as type arguments if it has any.
     */
    private def isSelfRecursiveCall(fun: Tree): Boolean = fun match {
      case TypeApply(f, targs) =>
        isCurrentMethodOnThis(f) &&
        ((targs map (t => t.tpe.symbol)) == (ctx.types map (t => t.symbol)))
      case _ =>
        isCurrentMethodOnThis(fun) && ctx.types.isEmpty
    }

    /** Does `fun' refer to the current final method on the current instance?
     *  Calls on any other receiver (e.g. `other.f(x)') must stay calls: a
     *  jump would keep running on `this'.
     */
    private def isCurrentMethodOnThis(fun: Tree): Boolean =
      ctx.currentMethod.isFinal &&
      (fun.symbol eq ctx.currentMethod) &&
      (fun match {
        case Select(qual @ This(_), _) => qual.symbol eq ctx.currentMethod.owner
        case Ident(_) => true
        case _ => false
      });

    /** Replace the self-recursive call `tree' by a jump to the current
     *  label with the (already transformed) arguments `args', and record
     *  that the label is used.
     */
    private def rewriteTailCall(tree: Tree, args: List[Tree]): Tree = {
      log("Rewriting..");
      ctx.accessed = true;
      typed(atPos(tree.pos)(
        Apply(Ident(ctx.label), args)));
    }
  }
}