summaryrefslogblamecommitdiff
path: root/sources/scalac/transformer/TailCallPhase.java
blob: 4798e22c1a5ace4f2133f1322ef0eaf4516b9d17 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13












                                                                          
                              

                                 


                            
 











                                                                      

                                                                      











                                                                      


                                                                    



                                                                    

                                          
































                                                                                         










                                                                              


                                                            

     





                                                                               
 

                                            
 















                                                                           





                                                 











                                                                                     

                                                                 
                                                         

                                      
                                                                             
                                         


                                                                                     
                     






                                                                     
                 
                             

                            
                                                 


                                                                    
                                                                    
 
 



                                                       
                                                            
 





                                                                               





                                                                                                    
                                                                        










                                                                         
                                               



















                                                                                         
 




                                             
 









                                    



                                          








                                                                   

                                           

                                      










                                                                                                













                                                                               







                                                                              
 
/*     ____ ____  ____ ____  ______                                     *\
**    / __// __ \/ __// __ \/ ____/    SOcos COmpiles Scala             **
**  __\_ \/ /_/ / /__/ /_/ /\_ \       (c) 2002, LAMP/EPFL              **
** /_____/\____/\___/\____/____/                                        **
\*                                                                      */

// $Id$

package scalac.transformer;

import scalac.Global;
import scalac.Phase;
import scalac.PhaseDescriptor;
import scalac.CompilationUnit;
import scalac.ast.Tree;
import scalac.ast.GenTransformer;
import scalac.symtab.Symbol;
import scalac.symtab.Type;
import scalac.util.Debug;

/**
 * A Tail Call Transformer
 *
 * @author     Erik Stenman
 * @version    1.0
 *
 * What it does:
 *
 * Finds method calls in tail-position and replaces them with jumps.
 * A call is in a tail-position if it is the last instruction to be
 * executed in the body of a method.  This is done by recursing over
 * the trees that may contain calls in tail-position (trees that can't
 * contain such calls are not transformed). However, they are not that
 * many.
 *
 * Self-recursive calls in tail-position are replaced by jumps to a
 * label at the beginning of the method. As the JVM provides no way to
 * jump from a method to another one, non-recursive calls in
 * tail-position are not optimized.
 *
 * A method call is self-recursive if it calls the current method on
 * the current instance and the method is final (otherwise, it could
 * be a call to an overridden method in a subclass). Furthermore, If
 * the method has type parameters, the call must contain these
 * parameters as type arguments.
 *
 * Nested functions can be tail recursive (if this phase runs before
 * lambda lift) and they are transformed as well.
 *
 * If a method contains self-recursive calls, a label is added to at
 * the beginning of its body and the calls are replaced by jumps to
 * that label.
 */
public class TailCallPhase extends Phase {

    private static class Context {
        /** The current method */
        public Symbol method = Symbol.NONE;

        /** The current tail-call label */
        public Symbol label = Symbol.NONE;

        /** The expected type arguments of self-recursive calls */
        public Type[] types;

        /** Tells whether we are in a tail position. */
        public boolean tailPosition;

        public Context() {
            this.tailPosition = false;
        }

        /**
         * The Context of this transformation. It contains the current enclosing method,
         * the label and the type parameters of the enclosing method, together with a
         * flag which says whether our position in the tree allows tail calls. This is
         * modified only in <code>Block</code>, where stats cannot possibly contain tail
         * calls, but the value can.
         */
        public Context(Symbol method, Symbol label, Type[] types, boolean tailPosition) {
            this.method = method;
            this.label  = label;
            this.types  = types;
            this.tailPosition = tailPosition;
        }
    }


    //########################################################################
    // Public Constructors

    /** Initializes this instance. */
    public TailCallPhase(Global global, PhaseDescriptor descriptor) {
        super(global, descriptor);
    }

   //########################################################################
    // Public Methods

    /** Applies this phase to the given compilation unit. */
    public void apply(CompilationUnit unit) {
        treeTransformer.apply(unit);
    }

   //########################################################################
    // Private Classes

    /** The tree transformer */
    private final GenTransformer treeTransformer = new GenTransformer(global) {


        /** The context of this call */
        private Context ctx = new Context();

        /** Transform the given tree, which is (or not) in tail position */
        public Tree transform(Tree tree, boolean tailPos) {
            boolean oldTP = ctx.tailPosition;
            ctx.tailPosition = tailPos;
            tree = transform(tree);
            ctx.tailPosition = oldTP;
            return tree;
        }

        public Tree[] transform(Tree[] trees, boolean tailPos) {
            boolean oldTP = ctx.tailPosition;
            ctx.tailPosition = tailPos;
            trees = transform(trees);
            ctx.tailPosition = oldTP;
            return trees;
        }

        /** Transforms the given tree. */
        public Tree transform(Tree tree) {
            switch (tree) {

            case DefDef(_, _, _, _, _, Tree rhs):
                Context oldCtx = ctx;

                ctx = new Context();

                //assert method == null: Debug.show(method) + " -- " + tree;
                ctx.method = tree.symbol();
                ctx.tailPosition = true;

                if (ctx.method.isMethodFinal() || ctx.method.owner().isMethod()) {
                    ctx.label = ctx.method.newLabel(ctx.method.pos, ctx.method.name);
                    ctx.types = Type.EMPTY_ARRAY;
                    Type type = ctx.method.type();
                    switch (type) {
                    case PolyType(Symbol[] tparams, Type result):
                        ctx.types = Symbol.type(tparams);
                        type = result;
                    }
                    ctx.label.setInfo(type.cloneType(ctx.method, ctx.label));
                    rhs = transform(rhs);
                    if (ctx.label.isAccessed()) {
                        global.log("Rewriting def " + ctx.method.simpleName());
                        rhs = gen.LabelDef(ctx.label, ctx.method.valueParams(), rhs);
                    }
                    tree = gen.DefDef(ctx.method, rhs);
                } else {
                    assert !ctx.method.isMethodFinal()
                        : "Final method: " + ctx.method.simpleName();
                    // non-final method
                    ctx.tailPosition = false;
                    tree = gen.DefDef(tree.symbol(), transform(rhs));
                }
                ctx = oldCtx;
                return tree;

            case Block(Tree[] stats, Tree value):
                boolean oldPosition = ctx.tailPosition;
                ctx.tailPosition = false;  stats = transform(stats);
                ctx.tailPosition = oldPosition;
                return gen.Block(tree.pos, stats, transform(value));


            case If(Tree cond, Tree thenp, Tree elsep):
                Type type = tree.type();
                thenp = transform(thenp);
                elsep = transform(elsep);
                return gen.If(tree.pos, cond, thenp, elsep);

            case Switch(Tree test, int[] tags, Tree[] bodies, Tree otherwise):
                Type type = tree.type();
                bodies = transform(bodies);
                otherwise = transform(otherwise);
                return gen.Switch(tree.pos, test, tags, bodies,otherwise,type);

            // handle pattern matches explicitly
	    case Apply(Select(_, scalac.util.Names._match), Tree[] args):
		Tree newTree = global.make.Apply(tree.pos, ((Tree.Apply)tree).fun, transform(args));
		newTree.setType(tree.getType());
		return newTree;

            case Apply(TypeApply(Tree fun, Tree[] targs), Tree[] vargs):
		if (ctx.method != Symbol.NONE && ctx.tailPosition) {
		    assert targs != null : "Null type arguments " + tree;
		    assert ctx.types != null : "Null types " + tree;

		    if (!Type.isSameAs(Tree.typeOf(targs), ctx.types) |
                        !ctx.tailPosition)
                        return tree;
		    return transform(tree, fun, transform(vargs, false));
		} else
		    return tree;

            case Apply(Tree fun, Tree[] vargs):
                if (ctx.tailPosition)
                    return transform(tree, fun, transform(vargs, false));
                else
                    return gen.mkApply_V(fun, transform(vargs, false));

	    case Visitor(Tree.CaseDef[] cases):
		Tree newTree = global.make.Visitor(tree.pos, super.transform(cases));
		newTree.setType(tree.getType());
		return newTree;

	    case CaseDef(Tree pattern, Tree guard, Tree body):
		return gen.CaseDef(pattern, guard, transform(body));

	    case Typed(Tree expr, Tree type):
		return gen.Typed(transform(expr), type);

            case ClassDef(_, _, _, _, _, Tree.Template impl):
                Symbol impl_symbol = getSymbolFor(impl);
                Tree[] body = transform(impl.body);
                return gen.ClassDef(getSymbolFor(tree), impl.parents, impl_symbol, body);

            case PackageDef(_, _):
            case LabelDef(_, _, _):
            case Return(_):
                return super.transform(tree);


            case Empty:
            case ValDef(_, _, _, _):
            case Assign(_, _):
            case New(_):
            case Super(_, _):
            case This(_):
            case Select(_, _):
            case Ident(_):
            case Literal(_):
            case TypeTerm():
	    case AbsTypeDef(_, _, _, _):
	    case AliasTypeDef(_, _, _, _):
	    case Import(_, _):
	    case Function(_, _):
                return tree;

            default:
                throw Debug.abort("illegal case", tree);
            }
        }

        /** Transforms the given function call. */
        private Tree transform(Tree tree, Tree fun, Tree[] vargs) {
            if (fun.symbol() != ctx.method)
		return tree;
            switch (fun) {
            case Select(Tree qual, _):
                if (!isReferenceToThis(qual, ctx.method.owner()))
		    return tree;
                global.log("Applying tail call recursion elimination for " +
                           ctx.method.enclClass().simpleName() + "." + ctx.method.simpleName());
                return gen.Apply(tree.pos, gen.Ident(qual.pos, ctx.label), vargs);

            case Ident(_):
                global.log("Applying tail call recursion elimination for function " +
                           ctx.method.enclClass().simpleName() + "." + ctx.method.simpleName());
                return gen.Apply(tree.pos, gen.Ident(fun.pos, ctx.label), vargs);

            default:
                throw Debug.abort("illegal case", fun);
            }
        }

        /**
         * Returns true if the tree represents the current instance of
         * given class.
         */
        private boolean isReferenceToThis(Tree tree, Symbol clasz) {
            switch (tree) {
            case This(_):
                assert tree.symbol() == clasz: tree +" -- "+ Debug.show(clasz);
                return true;
            default:
                return false;
            }
        }

    };

    //########################################################################
}