/* NSC -- new scala compiler
* Copyright 2005-2007 LAMP/EPFL
* @author Iulian Dragos
*/
// $Id$
package scala.tools.nsc.transform
import scala.tools.nsc.symtab.Flags
/** Perform tail recursive call elimination.
*
* @author Iulian Dragos
* @version 1.0
*/
abstract class TailCalls extends Transform
/* with JavaLogging() */ {
// inherits abstract value `global' and class `Phase' from Transform
import global._ // the global environment
import definitions._ // standard classes and methods
import typer.{typed, atOwner} // methods to type trees
import posAssigner.atPos // for filling in tree positions
val phaseName: String = "tailcalls"
def newTransformer(unit: CompilationUnit): Transformer =
new TailCallElimination(unit)
/** Create a new phase which applies transformer */
override def newPhase(prev: scala.tools.nsc.Phase): StdPhase = new Phase(prev)
/** The phase defined by this transform */
class Phase(prev: scala.tools.nsc.Phase) extends StdPhase(prev) {
def apply(unit: global.CompilationUnit) {
if (!(settings.debuginfo.value == "notailcalls")) {
newTransformer(unit).transformUnit(unit);
}
}
}
/**
* A Tail Call Transformer
*
* @author Erik Stenman, Iulian Dragos
* @version 1.1
*
* What it does:
* <p>
* Finds method calls in tail-position and replaces them with jumps.
* A call is in a tail-position if it is the last instruction to be
* executed in the body of a method. This is done by recursing over
* the trees that may contain calls in tail-position (trees that can't
* contain such calls are not transformed). However, they are not that
* many.
* </p>
* <p>
* Self-recursive calls in tail-position are replaced by jumps to a
* label at the beginning of the method. As the JVM provides no way to
* jump from a method to another one, non-recursive calls in
* tail-position are not optimized.
* </p>
* <p>
* A method call is self-recursive if it calls the current method and
* the method is final (otherwise, it could
* be a call to an overridden method in a subclass). Furthermore, If
* the method has type parameters, the call must contain these
* parameters as type arguments. Recursive calls on a different instance
* are optimized. Since 'this' is not a local variable, a dummy local val
* is added and used as a label parameter. The backend knows to load
* the corresponding argument in the 'this' (local at index 0). This dummy local
* is never used and should be cleand up by dead code elmination (when enabled).
* </p>
* <p>
* This phase has been moved before pattern matching to catch more
* of the common cases of tail recursive functions. This means that
* more cases should be taken into account (like nested function, and
* pattern cases).
* </p>
* <p>
* If a method contains self-recursive calls, a label is added to at
* the beginning of its body and the calls are replaced by jumps to
* that label.
* </p>
* <p>
* Assumes: <code>Uncurry</code> has been run already, and no multiple
* parameter lists exit.
* </p>
*/
class TailCallElimination(unit: CompilationUnit) extends Transformer {
class Context {
/** The current method */
var currentMethod: Symbol = NoSymbol
/** The current tail-call label */
var label: Symbol = NoSymbol
/** The expected type arguments of self-recursive calls */
var tparams: List[Symbol] = Nil
/** Tells whether we are in a (possible) tail position */
var tailPos = false
/** Is the label accessed? */
var accessed = false
def this(that: Context) = {
this()
this.currentMethod = that.currentMethod
this.label = that.label
this.tparams = that.tparams
this.tailPos = that.tailPos
this.accessed = that.accessed
}
/** Create a new method symbol for the current method and store it in
* the label field.
*/
def makeLabel(): Unit = {
label = currentMethod.newLabel(currentMethod.pos, "_" + currentMethod.name)
accessed = false
}
override def toString(): String = (
"" + currentMethod.name + " tparams: " + tparams + " tailPos: " + tailPos +
" accessed: " + accessed + "\nLabel: " + label + "\nLabel type: " + label.info
)
}
private def mkContext(that: Context) = new Context(that)
private def mkContext(that: Context, tp: Boolean): Context = {
val t = mkContext(that)
t.tailPos = tp
t
}
private var ctx: Context = new Context()
/** Rewrite this tree to contain no tail recursive calls */
def transform(tree: Tree, nctx: Context): Tree = {
val oldCtx = ctx
ctx = nctx
val t = transform(tree)
this.ctx = oldCtx
t
}
override def transform(tree: Tree): Tree = {
tree match {
case DefDef(mods, name, tparams, vparams, tpt, rhs) =>
log("Entering DefDef: " + name)
val newCtx = mkContext(ctx)
newCtx.currentMethod = tree.symbol
newCtx.makeLabel()
newCtx.label.setInfo(MethodType(currentClass.tpe :: tree.symbol.tpe.paramTypes, tree.symbol.tpe.finalResultType))
newCtx.tailPos = true
val t1 = if (newCtx.currentMethod.isFinal ||
newCtx.currentMethod.enclClass.hasFlag(Flags.MODULE)) {
newCtx.tparams = Nil
log(" Considering " + name + " for tailcalls")
tree.symbol.tpe match {
case PolyType(tpes, restpe) =>
newCtx.tparams = tparams map (_.symbol)
newCtx.label.setInfo(
newCtx.label.tpe.substSym(tpes, tparams map (_.symbol)))
case _ => ()
}
//println("label.tpe: " + newCtx.label.tpe)
var newRHS = transform(rhs, newCtx);
if (newCtx.accessed) {
log("Rewrote def " + newCtx.currentMethod)
val newThis = newCtx.currentMethod.newValue(tree.pos, nme.THIS)
.setInfo(currentClass.tpe)
.setFlag(Flags.SYNTHETIC)
newRHS =
typed(atPos(tree.pos)(Block(List(
ValDef(newThis, This(currentClass))),
LabelDef(newCtx.label,
newThis :: (List.flatten(vparams) map (_.symbol)),
newRHS))));
copy.DefDef(tree, mods, name, tparams, vparams, tpt, newRHS);
} else
copy.DefDef(tree, mods, name, tparams, vparams, tpt, newRHS);
} else {
copy.DefDef(tree, mods, name, tparams, vparams, tpt, transform(rhs, newCtx))
}
log("Leaving DefDef: " + name)
t1
case EmptyTree => tree
case PackageDef(name, stats) =>
super.transform(tree)
case ClassDef(_, name, _, _) =>
log("Entering class " + name)
val res = super.transform(tree)
log("Leaving class " + name)
res
case ValDef(mods, name, tpt, rhs) => super.transform(tree)
case LabelDef(name, params, rhs) => super.transform(tree)
case Template(parents, self, body) =>
super.transform(tree)
case Block(stats, expr) =>
copy.Block(tree,
transformTrees(stats, mkContext(ctx, false)),
transform(expr))
case CaseDef(pat, guard, body) =>
copy.CaseDef(tree, pat, guard, transform(body))
case Sequence(_) | Alternative(_) |
Star(_) | Bind(_, _) =>
throw new RuntimeException("We should've never gotten inside a pattern")
case Function(vparams, body) =>
tree
//throw new RuntimeException("Anonymous function should not exist at this point. at: " + unit.position(tree.pos));
case Assign(lhs, rhs) =>
super.transform(tree)
case If(cond, thenp, elsep) =>
copy.If(tree, cond, transform(thenp), transform(elsep))
case Match(selector, cases) => //super.transform(tree);
copy.Match(tree, transform(selector, mkContext(ctx, false)), transformTrees(cases).asInstanceOf[List[CaseDef]])
case Return(expr) => super.transform(tree)
case Try(block, catches, finalizer) =>
// no calls inside try are in tail position, but keep recursing for more nested functions
copy.Try(tree, transform(block, mkContext(ctx, false)),
transformTrees(catches, mkContext(ctx, false)).asInstanceOf[List[CaseDef]],
transform(finalizer, mkContext(ctx, false)))
case Throw(expr) => super.transform(tree)
case New(tpt) => super.transform(tree)
case Typed(expr, tpt) => super.transform(tree)
case Apply(tapply @ TypeApply(fun, targs), vargs) =>
lazy val defaultTree = copy.Apply(tree, tapply, transformTrees(vargs, mkContext(ctx, false)))
if ( ctx.currentMethod.isFinal &&
ctx.tailPos &&
isSameTypes(ctx.tparams, targs map (_.tpe.typeSymbol)) &&
isRecursiveCall(fun)) {
fun match {
case Select(receiver, _) => if (!forMSIL) rewriteTailCall(fun, receiver :: transformTrees(vargs, mkContext(ctx, false))) else defaultTree
case _ => rewriteTailCall(fun, This(currentClass) :: transformTrees(vargs, mkContext(ctx, false)))
}
} else
defaultTree
case TypeApply(fun, args) =>
super.transform(tree)
case Apply(fun, args) if (fun.symbol == definitions.Boolean_or ||
fun.symbol == definitions.Boolean_and) =>
copy.Apply(tree, fun, transformTrees(args))
case Apply(fun, args) =>
lazy val defaultTree = copy.Apply(tree, fun, transformTrees(args, mkContext(ctx, false)))
if (ctx.currentMethod.isFinal &&
ctx.tailPos &&
isRecursiveCall(fun)) {
fun match {
case Select(receiver, _) => if (!forMSIL) rewriteTailCall(fun, receiver :: transformTrees(args, mkContext(ctx, false))) else defaultTree
case _ => rewriteTailCall(fun, This(currentClass) :: transformTrees(args, mkContext(ctx, false)))
}
} else
defaultTree
case Super(qual, mix) =>
tree
case This(qual) =>
tree
case Select(qualifier, selector) =>
tree
case Ident(name) =>
tree
case Literal(value) =>
tree
case TypeTree() =>
tree
case _ =>
tree
}
}
def transformTrees(trees: List[Tree], nctx: Context): List[Tree] =
trees map ((tree) => transform(tree, nctx))
private def rewriteTailCall(fun: Tree, args: List[Tree]): Tree = {
log("Rewriting tail recursive method call at: " +
(fun.pos))
ctx.accessed = true
//println("fun: " + fun + " args: " + args)
typed(atPos(fun.pos)(
Apply(Ident(ctx.label), args)))
}
private def isSameTypes(ts1: List[Symbol], ts2: List[Symbol]): Boolean = {
def isSameType(t1: Symbol, t2: Symbol) = {
t1 == t2
}
List.forall2(ts1, ts2)(isSameType)
}
/** Returns <code>true</code> if the fun tree refers to the same method as
* the one saved in <code>ctx</code>.
*
* @param fun the expression that is applied
* @return <code>true</code> if the tree symbol refers to the innermost
* enclosing method
*/
private def isRecursiveCall(fun: Tree): Boolean =
(fun.symbol eq ctx.currentMethod)
}
}