summaryrefslogblamecommitdiff
path: root/sources/scala/tools/nsc/transform/LambdaLift.scala
blob: 2cedb71a9054af0626d8990fb0af05534a3ac58a (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15














                                                 

                                      















































































































                                                                                              





                                                    




                                                        






                                                                                          



















































                                                                                               
                                                                                                                                         













                                                                                                                                             










                                                                           

                                         
                                                            

























                                                                                                
                                      

                                                         
                                    
                                           
                                                                           







































                                                                                   





                                                                                                  






























                                                                                                   
/* NSC -- new scala compiler
 * Copyright 2005 LAMP/EPFL
 * @author
 */
// $Id$
package scala.tools.nsc.transform;

import symtab._;
import Flags._;
import util.{ListBuffer, TreeSet};
import collection.mutable.HashMap;

abstract class LambdaLift extends InfoTransform {
  import global._;
  import definitions._;
  import typer.{typed, typedOperator};
  import posAssigner.atPos;

  /** Name of this phase, "lambdalift".
   *  This member overrides an abstract member in Transform.
   *  NOTE(review): the original comment spoke of "the following two
   *  members"; the second is presumably `newTransformer' -- confirm.
   */
  val phaseName: String = "lambdalift";

  /** Type map used by this phase; it is applied to symbol infos by
   *  `transformInfo' and to tree types by `LambdaLifter.transform'.
   *
   *   - A TypeRef without prefix (NoPrefix) to a class that is not a package
   *     class is rebuilt with an explicit prefix: the (recursively mapped)
   *     `this' type of the class enclosing the owner of the referenced
   *     symbol. Such references are asserted to have no type arguments.
   *   - For a ClassInfoType only the parents are mapped; the original type
   *     is returned unchanged when no parent changed.
   *   - Every other type is handled structurally by `mapOver'.
   */
  private val lifted = new TypeMap {
    def apply(tp: Type): Type = tp match {
      case TypeRef(pre, sym, args) =>
	if (pre == NoPrefix && sym.isClass && !sym.isPackageClass) {
          assert(args.isEmpty);
          typeRef(apply(sym.owner.enclClass.thisType), sym, args)
	} else mapOver(tp)
      case ClassInfoType(parents, decls, clazz) =>
	// mapConserve yields the very same list when nothing changed, which
	// is what the `eq' test below relies on
	val parents1 = List.mapConserve(parents)(this);
	if (parents1 eq parents) tp
	else ClassInfoType(parents1, decls, clazz)
      case _ =>
        mapOver(tp)
    }
  }

  /** The info of a symbol after this phase: `tp' passed through the
   *  `lifted' type map. The symbol `sym' itself is not consulted.
   */
  def transformInfo(sym: Symbol, tp: Type): Type = {
    val liftedInfo = lifted.apply(tp);
    liftedInfo
  }

  /** Create the tree transformer of this phase for compilation unit `unit'. */
  protected def newTransformer(unit: CompilationUnit): Transformer = {
    val lifter = new LambdaLifter(unit);
    lifter
  }

  class LambdaLifter(unit: CompilationUnit) extends explicitOuter.OuterPathTransformer {

    /** A map storing free variables of functions and classes.
     *  Keys are the method or class symbols passed as `owner' to `markFree';
     *  each value is the set of symbols recorded as free in that owner.
     */
    private val free = new HashMap[Symbol,SymSet];

    /** A map storing the free variable proxies of functions and classes.
     *  Filled by `computeFreeVars' with one proxy symbol per free variable
     *  of each key of `free'; consulted by `proxy' and `addFreeParams'.
     */
    private val proxies = new HashMap[Symbol, List[Symbol]];

    /** A hashtable storing calls between functions: caller -> callees.
     *  The traverser also records, for a primary constructor, the class
     *  that owns it as a "callee".
     */
    private val called = new HashMap[Symbol, SymSet];

    /** The set of symbols that need to be renamed.
     *  `computeFreeVars' gives every member a fresh name.
     */
    private val renamable = newSymSet;

    /** A flag to indicate whether new free variables have been found.
     *  Set by `markFree'; reset and tested by the fixpoint loop in
     *  `computeFreeVars'.
     */
    private var changedFreeVars: boolean = _;

    /** Buffers for lifted out classes and methods, keyed by the class that
     *  receives them. A buffer is created for every ClassDef seen by the
     *  traverser, filled by `liftDef' and read by `transformStats'.
     */
    private val liftedDefs = new HashMap[Symbol, ListBuffer[Tree]];

    /** The kind of symbol set used throughout this transformer. */
    private type SymSet = TreeSet[Symbol];

    /** A new, empty symbol set ordered by `Symbol.isLess'. */
    private def newSymSet = new TreeSet[Symbol]((x, y) => x.isLess(y));

    /** The symbol set stored under `sym' in map `f'. If there is none yet,
     *  a new empty set is created, stored under `sym' and returned.
     */
    private def symSet(f: HashMap[Symbol, SymSet], sym: Symbol): SymSet =
      f.get(sym) match {
        case None =>
          val created = newSymSet;
          f(sym) = created;
          created
        case Some(existing) =>
          existing
      }

    /** The next symbol outwards from `sym': its owner, except that for a
     *  constructor the owning class is skipped and that class's owner is
     *  returned instead.
     */
    private def outer(sym: Symbol): Symbol = {
      val direct = sym.owner;
      if (!sym.isConstructor) direct else direct.owner
    }

    /** The innermost method or class enclosing `sym' (possibly `sym'
     *  itself). A local dummy symbol is replaced by the primary constructor
     *  of its owner; the search moves outwards with `outer'.
     */
    private def enclMethOrClass(sym: Symbol): Symbol = {
      def skipLocalDummy(s: Symbol) =
        if (!s.isLocalDummy) s else s.owner.primaryConstructor;
      var current = skipLocalDummy(sym);
      while (!(current.isMethod || current.isClass))
        current = skipLocalDummy(outer(current));
      current
    }

    /** Mark symbol `sym' as being free in `owner', unless `sym'
     *  is defined in `owner' or there is a class between `owner's owner
     *  and the owner of `sym'.
     *  Return `true' if there is no class between `owner' and
     *  the owner of sym.
     *  pre: sym.isLocal, (owner.isMethod || owner.isClass)
     *
     *  In detail:
     *   - if `owner' is the method or class enclosing the definition of
     *     `sym', nothing is recorded and the result is `true';
     *   - if `owner' is a package class, or the recursive call for the next
     *     enclosing method or class returned `false', nothing is recorded
     *     here and the result is `false';
     *   - otherwise `sym' is added to `free(owner)' and to `renamable',
     *     `changedFreeVars' is set if the entry is new, and the result is
     *     `!owner.isClass'.
     *  A variable that becomes free for the first time is flagged CAPTURED
     *  and its info as of the next phase is changed to a reference-cell
     *  type (see below).
     */
    private def markFree(sym: Symbol, owner: Symbol): boolean = {
      if (settings.debug.value) log("mark " + sym + " of " + sym.owner + " free in " + owner);
      if (owner == enclMethOrClass(sym.owner)) true
      // the recursive call first marks `sym' free in all enclosing
      // methods/classes between `owner' and the definition of `sym'
      else if (owner.isPackageClass || !markFree(sym, enclMethOrClass(outer(owner)))) false
      else {
	val ss = symSet(free, owner);
	if (!(ss contains sym)) {
	  ss addEntry sym;
	  renamable addEntry sym;
	  // tells the fixpoint loop in computeFreeVars to run once more
	  changedFreeVars = true;
	  if (settings.debug.value) log("" + sym + " is free in " + owner);
	  if (sym.isVariable && !(sym hasFlag CAPTURED)) {
            sym setFlag CAPTURED;
            // class symbol of the variable's current type, read before
            // the info is replaced
            val symClass = sym.tpe.symbol;
            atPhase(phase.next) {
              // value classes get their dedicated ref class, everything
              // else the generic ObjectRefClass
              sym updateInfo (
                if (isValueClass(symClass)) refClass(symClass).tpe else ObjectRefClass.tpe)
            }
          }
        }
	!owner.isClass
      }
    }

    /** An iterator over the symbols recorded as free in `sym'; empty when
     *  nothing has been recorded for `sym'.
     */
    def freeVars(sym: Symbol): Iterator[Symbol] = {
      val recorded = free.get(sym);
      recorded match {
        case None => Iterator.empty
        case Some(vars) => vars.elements
      }
    }

    /** The traverse function.
     *  Walks a tree and collects the raw data that `computeFreeVars' works
     *  on:
     *   - ClassDef: a buffer for lifted definitions is created for the
     *     class; a local class is added to `renamable'.
     *   - DefDef: a local method is added to `renamable' and made
     *     PRIVATE | LOCAL | FINAL; for a non-local primary constructor the
     *     owning class is recorded in `called' as its "callee".
     *   - Ident of a local symbol: a non-method term is marked free in the
     *     enclosing method or class (see `markFree'); a method referenced
     *     from within a method is recorded in `called'.
     *   - Select of a constructor of a local class from within a method:
     *     recorded in `called'.
     *  Children are always traversed afterwards via `super.traverse'.
     */
    private val freeVarTraverser = new Traverser {
      override def traverse(tree: Tree): unit = {
       try { //debug
        val sym = tree.symbol;
        tree match {
	  case ClassDef(_, _, _, _, _) =>
            liftedDefs(tree.symbol) = new ListBuffer;
	    if (sym.isLocal) renamable addEntry sym;
	  case DefDef(_, _, _, _, _, _) =>
	    if (sym.isLocal) {
	      renamable addEntry sym;
	      sym setFlag (PRIVATE | LOCAL | FINAL)
	    } else if (sym.isPrimaryConstructor) {
	      symSet(called, sym) addEntry sym.owner
	    }
	  case Ident(name) =>
	    if (sym == NoSymbol) {
	      // the only identifier allowed to lack a symbol is the wildcard
	      assert(name == nme.WILDCARD)
            } else if (sym.isLocal) {
	      val owner = enclMethOrClass(currentOwner);
              if (sym.isTerm && !sym.isMethod) markFree(sym, owner)
              else if (owner.isMethod && sym.isMethod) symSet(called, owner) addEntry sym;
            }
	  case Select(_, _) =>
            if (sym.isConstructor && sym.owner.isLocal) {
	      val owner = enclMethOrClass(currentOwner);
              if (owner.isMethod) symSet(called, owner) addEntry sym;
            }
          case _ =>
        }
        super.traverse(tree)
       } catch {//debug
	 // debugging aid: print the offending tree, then rethrow unchanged
	 case ex: Throwable =>
	   System.out.println("exception when traversing " + tree);
	 throw ex
       }
      }
    }

    /** Compute free variables map `fvs'.
     *  Also assign unique names to all
     *  value/variable/let that are free in some function or class, and to
     *  all class/function symbols that are owned by some function.
     *
     *  Steps, in order:
     *   1. traverse the body of the compilation unit with
     *      `freeVarTraverser' (fills `free', `called', `renamable',
     *      `liftedDefs');
     *   2. propagate free variables along `called' until nothing changes:
     *      every free variable of a callee is marked free in its caller;
     *   3. give every symbol in `renamable' a fresh name;
     *   4. as of the next phase, create one proxy symbol per free variable
     *      of every owner in `free' and store the list in `proxies'.
     */
    private def computeFreeVars: unit = {
      freeVarTraverser.traverse(unit.body);

      do {
	changedFreeVars = false;
	for (val caller <- called.keys;
	     val callee <- called(caller).elements;
	     val fv <- freeVars(callee))
	  markFree(fv, caller);
      } while (changedFreeVars);

      for (val sym <- renamable.elements) {
        sym.name = unit.fresh.newName(sym.name.toString());
        if (settings.debug.value) log("renamed: " + sym.name);
      }

      atPhase(phase.next) {
        for (val owner <- free.keys) {
          if (settings.debug.value)
            log("free(" + owner + owner.locationString + ") = " + free(owner).elements.toList);
          proxies(owner) =
            for (val fv <- free(owner).elements.toList) yield {
              // the proxy carries the (already renamed) name and the info
              // of the free variable; in a class it is a private local
              // param accessor, in a method it is a parameter
              val proxy = owner.newValue(owner.pos, fv.name)
              setFlag (if (owner.isClass) PARAMACCESSOR | PRIVATE | LOCAL else PARAM)
              setFlag SYNTHETIC
              setInfo fv.info;
              // class proxies are also entered as members of the class
              if (owner.isClass) owner.info.decls enter proxy;
              proxy
            }
        }
      }
    }

    /** The symbol through which `sym' is to be accessed from
     *  `currentOwner'.
     *  If `sym' is defined in the same enclosing method or class as
     *  `currentOwner', that is `sym' itself. Otherwise the proxies of the
     *  methods and classes enclosing `currentOwner' are searched from the
     *  inside out for one with the same name as `sym'.
     *  NOTE(review): the inner match only has cases for zero or one proxy
     *  of a given name; this presumably relies on the fresh names assigned
     *  in `computeFreeVars' -- confirm.
     */
    private def proxy(sym: Symbol) = {
      def searchIn(owner: Symbol): Symbol = {
        if (settings.debug.value) log("searching for " + sym + "(" + sym.owner + ") in " + owner  + " " + enclMethOrClass(owner));//debug
        proxies.get(enclMethOrClass(owner)) match {
          case Some(ps) =>
            ps filter (p => p.name == sym.name) match {
              case List(p) => p
              // no proxy of that name here: continue one level further out
              case List() => searchIn(outer(owner))
            }
          // no proxies at all for this method/class: continue further out
          case None => searchIn(outer(owner))
        }
      }
      if (settings.debug.value) log("proxy " + sym + " in " + sym.owner + " from " + currentOwner + " " + enclMethOrClass(sym.owner));//debug
      if (enclMethOrClass(sym.owner) == enclMethOrClass(currentOwner)) sym
      else searchIn(currentOwner)
    }

    /** A tree `qual.sym' selecting the member `sym', typed with `sym.tpe'.
     *  The qualifier is `this' when `sym' lives in the class enclosing
     *  `currentOwner'. Otherwise LOCAL and PRIVATE are reset on `sym' and
     *  the qualifier is either a static path to the owning class or an
     *  outer path leading to it.
     */
    private def memberRef(sym: Symbol) = {
      val clazz = sym.owner.enclClass;
      val qual =
        if (clazz != currentOwner.enclClass) {
          // accessed from another class, so it can no longer be private
          sym resetFlag (LOCAL | PRIVATE);
          if (!clazz.isStaticOwner) outerPath(outerValue, clazz)
          else gen.mkQualifier(clazz.thisType)
        } else gen.This(clazz);
      Select(qual, sym) setType sym.tpe
    }

    /** A reference to the proxy of `sym' (see `proxy'): a plain identifier
     *  when the proxy is local, a member selection otherwise.
     */
    private def proxyRef(sym: Symbol) = {
      val target = proxy(sym);
      if (!target.isLocal) memberRef(target) else gen.Ident(target)
    }

    /** The argument list `args' extended by one proxy reference, positioned
     *  at `pos', for every free variable of `sym'. `args' itself is
     *  returned when `sym' has no free variables.
     */
    private def addFreeArgs(pos: int, sym: Symbol, args: List[Tree]) = {
      val fvs = freeVars(sym).toList;
      if (!fvs.isEmpty) args ::: (fvs map (fv => atPos(pos)(proxyRef(fv))))
      else args
    }

    /** Add definitions for the free variable proxies of `sym' to its
     *  defining tree `tree'.
     *  If `proxies' has no entry for `sym', `tree' is returned unchanged.
     *  Otherwise a ValDef is built for every proxy and
     *   - for a DefDef with exactly one parameter list, the ValDefs are
     *     appended to that list and the info of `sym' is updated to a
     *     method type with the proxy types appended to its parameter types
     *     (passed through `lifted');
     *   - for a ClassDef, the ValDefs are appended to the template body.
     *  There is no case for any other tree shape (e.g. a DefDef with
     *  several parameter lists).
     */
    private def addFreeParams(tree: Tree, sym: Symbol): Tree = proxies.get(sym) match {
      case Some(ps) =>
        // one untyped definition per proxy, positioned at the original tree
        val freeParams = ps map (p => ValDef(p) setPos tree.pos setType NoType);
        tree match {
          case DefDef(mods, name, tparams, List(vparams), tpt, rhs) =>
            sym.updateInfo(
	      lifted(MethodType(sym.info.paramTypes ::: (ps map (.tpe)), sym.info.resultType)));
            copy.DefDef(tree, mods, name, tparams, List(vparams ::: freeParams), tpt, rhs)
          case ClassDef(mods, name, tparams, tpt, impl @ Template(parents, body)) =>
            copy.ClassDef(tree, mods, name, tparams, tpt,
                          copy.Template(impl, parents, body ::: freeParams))
        }
      case None =>
        tree
    }

    /** Lift the definition `tree' out to class level and return the tree
     *  that takes its place, which is always EmptyTree.
     *  The symbol of `tree' is re-owned to the class enclosing its current
     *  owner (for a class symbol: to that class's `toInterface'), a method
     *  symbol is flagged LIFTED, the tree is appended to the `liftedDefs'
     *  buffer of the new owner and the symbol is entered into the new
     *  owner's declarations.
     *  The statements below depend on each other's effects on `sym.owner';
     *  their order matters.
     */
    private def liftDef(tree: Tree): Tree = {
      val sym = tree.symbol;
      sym.owner = sym.owner.enclClass;
      if (sym.isClass) sym.owner = sym.owner.toInterface;
      if (sym.isMethod) sym setFlag LIFTED;
      // `transformStats' later adds the buffered trees to the class body
      liftedDefs(sym.owner) += tree;
      sym.owner.info.decls enterUnique sym;
      if (settings.debug.value) log("lifted: " + sym + sym.locationString);
      EmptyTree
    }

    /** Rewrite `tree' after its children have been transformed.
     *
     *   - ClassDef, DefDef: definitions for the free variable proxies of
     *     the symbol are added by `addFreeParams'; a local definition is
     *     then moved out by `liftDef', which leaves EmptyTree in its place.
     *   - ValDef of a CAPTURED variable: the type tree is replaced by one
     *     for the variable's current type and the initializer is wrapped
     *     into a constructor call `new <sym.tpe>(rhs)'.
     *   - Return(Block(stats, value)) becomes Block(stats, Return(value)).
     *   - Any other Return whose symbol is not the enclosing method is
     *     reported as an (unimplemented) non-local return.
     *   - Apply: references to the proxies of the free variables of the
     *     applied symbol are appended to the arguments.
     *   - Assign to `qual.asInstanceOf[..]()': the cast is dropped.
     *   - Ident: a method becomes a member selection, a local term defined
     *     in another method or class becomes a proxy reference; a CAPTURED
     *     symbol is additionally dereferenced through its `elem' member,
     *     with a cast back to the original type if needed.
     *   - Everything else is returned unchanged.
     *
     *  @param tree  the tree to rewrite, children already transformed
     *  @return      the rewritten tree
     */
    private def postTransform(tree: Tree): Tree = {
      val sym = tree.symbol;
      tree match {
        case ClassDef(_, _, _, _, _) =>
          // Classes need their proxy definitions as well: computeFreeVars
          // enters proxy symbols into the class's declarations, and
          // addFreeParams has a ClassDef case that adds the matching
          // ValDefs to the template. Without this call that case was never
          // reached and the proxies had no definition in the class body.
          val tree1 = addFreeParams(tree, sym);
          if (sym.isLocal) liftDef(tree1) else tree1
        case DefDef(_, _, _, _, _, _) =>
          val tree1 = addFreeParams(tree, sym);
          if (sym.isLocal) liftDef(tree1) else tree1
        case ValDef(mods, name, tpt, rhs) =>
          if (sym hasFlag CAPTURED) {
            // sym.tpe is read here while running at the next phase (see
            // transformUnit), where markFree has replaced the info of
            // captured variables
            val tpt1 = TypeTree(sym.tpe) setPos tpt.pos;
            val rhs1 =
              atPos(rhs.pos) {
                typed {
                  Apply(Select(New(TypeTree(sym.tpe)), nme.CONSTRUCTOR), List(rhs))
                }
              }
            copy.ValDef(tree, mods, name, tpt1, rhs1)
          } else tree
        case Return(Block(stats, value)) =>
          // pull the statements out of the returned block
          Block(stats, copy.Return(tree, value)) setType tree.tpe setPos tree.pos;
        case Return(expr) =>
          if (sym != currentOwner.enclMethod) {
            System.out.println(sym);//debug
            System.out.println(currentOwner.enclMethod);//debug
            unit.error(tree.pos, "non-local return not yet implemented");
          }
          tree
        case Apply(fn, args) =>
          copy.Apply(tree, fn, addFreeArgs(tree.pos, sym, args));
        case Assign(Apply(TypeApply(sel @ Select(qual, _), _), List()), rhs) =>
          // eliminate casts introduced by selecting a captured variable field
          // on the lhs of an assignment.
          assert(sel.symbol == Object_asInstanceOf);
          copy.Assign(tree, qual, rhs)
        case Ident(name) =>
          val tree1 =
            if (sym != NoSymbol && sym.isTerm && !sym.isLabel)
              if (sym.isMethod)
                atPos(tree.pos)(memberRef(sym))
              else if (sym.isLocal && enclMethOrClass(sym.owner) != enclMethOrClass(currentOwner))
                atPos(tree.pos)(proxyRef(sym))
              else tree
            else tree;
          if (sym hasFlag CAPTURED)
            atPos(tree.pos) {
              // remember the type the identifier had before dereferencing
              val tp = tree.tpe;
              val elemTree = typed { Select(tree1 setType sym.tpe, nme.elem) }
              if (elemTree.tpe.symbol != tp.symbol) gen.cast(elemTree, tp) else elemTree
            }
          else tree1
        case _ =>
          tree
      }
    }

    /** Transform the children of `tree' first, give the result the lifted
     *  version of the type of `tree', then apply `postTransform' to it.
     */
    override def transform(tree: Tree): Tree = {
      val transformed = super.transform(tree);
      postTransform(transformed setType lifted(tree.tpe))
    }

    /** Transform the statements `stats'. When the current owner is a class
     *  other than a package class and definitions have been lifted into
     *  it, those definitions are appended to the transformed statements.
     */
    override def transformStats(stats: List[Tree], exprOwner: Symbol): List[Tree] = {
      val transformed = super.transformStats(stats, exprOwner);
      val inProperClass = currentOwner.isClass && !currentOwner.isPackageClass;
      if (inProperClass && liftedDefs(currentOwner).hasNext) {
        transformed ::: liftedDefs(currentOwner).toList
      } else {
        transformed
      }
    }

    /** Run this phase on a compilation unit: compute the free variables
     *  first, then perform the inherited tree transformation as of the
     *  next phase.
     */
    override def transformUnit(unit: CompilationUnit): unit = {
      computeFreeVars;
      atPhase(phase.next) {
        super.transformUnit(unit)
      }
    }
  }
}