/* NSC -- new scala compiler
* Copyright 2005 LAMP/EPFL
* @author
*/
// $Id$
package scala.tools.nsc.transform;
import symtab._;
import Flags._;
import util.{ListBuffer, TreeSet};
import collection.mutable.HashMap;
abstract class LambdaLift extends InfoTransform {
import global._;
import definitions._;
import typer.{typed, typedOperator};
import posAssigner.atPos;
/** the following two members override abstract members in Transform */
val phaseName: String = "lambdalift";
private val lifted = new TypeMap {
def apply(tp: Type): Type = tp match {
case TypeRef(pre, sym, args) =>
if (pre == NoPrefix && sym.isClass && !sym.isPackageClass) {
assert(args.isEmpty);
typeRef(apply(sym.owner.enclClass.thisType), sym, args)
} else mapOver(tp)
case ClassInfoType(parents, decls, clazz) =>
val parents1 = List.mapConserve(parents)(this);
if (parents1 eq parents) tp
else ClassInfoType(parents1, decls, clazz)
case _ =>
mapOver(tp)
}
}
def transformInfo(sym: Symbol, tp: Type): Type = lifted(tp);
protected def newTransformer(unit: CompilationUnit): Transformer =
new LambdaLifter(unit);
class LambdaLifter(unit: CompilationUnit) extends explicitOuter.OuterPathTransformer {
/** A map storing free variables of functions and classes */
private val free = new HashMap[Symbol,SymSet];
/** A map storing the free variable proxies of functions and classes */
private val proxies = new HashMap[Symbol, List[Symbol]];
/** A hashtable storing calls between functions */
private val called = new HashMap[Symbol, SymSet];
/** The set of symbols that need to be renamed. */
private val renamable = newSymSet;
/** A flag to indicate whether new free variables have been found */
private var changedFreeVars: boolean = _;
/** Buffers for lifted out classes and methods */
private val liftedDefs = new HashMap[Symbol, ListBuffer[Tree]];
private type SymSet = TreeSet[Symbol];
private def newSymSet = new TreeSet[Symbol]((x, y) => x.isLess(y));
private def symSet(f: HashMap[Symbol, SymSet], sym: Symbol): SymSet = f.get(sym) match {
case Some(ss) => ss
case None => val ss = newSymSet; f(sym) = ss; ss
}
private def outer(sym: Symbol): Symbol =
if (sym.isConstructor) sym.owner.owner else sym.owner;
private def enclMethOrClass(sym: Symbol): Symbol = {
def localToConstr(sym: Symbol) =
if (sym.isLocalDummy) sym.owner.primaryConstructor else sym;
var encl = localToConstr(sym);
while (!encl.isMethod && !encl.isClass)
encl = localToConstr(outer(encl));
encl
}
/** Mark symbol `sym' as being free in `owner', unless `sym'
* is defined in `owner' or there is a class between `owner's owner
* and the owner of `sym'.
* Return `true' if there is no class between `owner' and
* the owner of sym.
* pre: sym.isLocal, (owner.isMethod || owner.isClass)
*/
private def markFree(sym: Symbol, owner: Symbol): boolean = {
if (settings.debug.value) log("mark " + sym + " of " + sym.owner + " free in " + owner);
if (owner == enclMethOrClass(sym.owner)) true
else if (owner.isPackageClass || !markFree(sym, enclMethOrClass(outer(owner)))) false
else {
val ss = symSet(free, owner);
if (!(ss contains sym)) {
ss addEntry sym;
renamable addEntry sym;
changedFreeVars = true;
if (settings.debug.value) log("" + sym + " is free in " + owner);
if (sym.isVariable && !(sym hasFlag CAPTURED)) {
sym setFlag CAPTURED;
val symClass = sym.tpe.symbol;
atPhase(phase.next) {
sym updateInfo (
if (isValueClass(symClass)) refClass(symClass).tpe else ObjectRefClass.tpe)
}
}
}
!owner.isClass
}
}
def freeVars(sym: Symbol): Iterator[Symbol] = free.get(sym) match {
case Some(ss) => ss.elements
case None => Iterator.empty
}
/** The traverse function */
private val freeVarTraverser = new Traverser {
override def traverse(tree: Tree): unit = {
try { //debug
val sym = tree.symbol;
tree match {
case ClassDef(_, _, _, _, _) =>
liftedDefs(tree.symbol) = new ListBuffer;
if (sym.isLocal) renamable addEntry sym;
case DefDef(_, _, _, _, _, _) =>
if (sym.isLocal) {
renamable addEntry sym;
sym setFlag (PRIVATE | LOCAL | FINAL)
} else if (sym.isPrimaryConstructor) {
symSet(called, sym) addEntry sym.owner
}
case Ident(name) =>
if (sym == NoSymbol) {
assert(name == nme.WILDCARD)
} else if (sym.isLocal) {
val owner = enclMethOrClass(currentOwner);
if (sym.isTerm && !sym.isMethod) markFree(sym, owner)
else if (owner.isMethod && sym.isMethod) symSet(called, owner) addEntry sym;
}
case Select(_, _) =>
if (sym.isConstructor && sym.owner.isLocal) {
val owner = enclMethOrClass(currentOwner);
if (owner.isMethod) symSet(called, owner) addEntry sym;
}
case _ =>
}
super.traverse(tree)
} catch {//debug
case ex: Throwable =>
System.out.println("exception when traversing " + tree);
throw ex
}
}
}
/** Compute free variables map `fvs'.
* Also assign unique names to all
* value/variable/let that are free in some function or class, and to
* all class/function symbols that are owned by some function.
*/
private def computeFreeVars: unit = {
freeVarTraverser.traverse(unit.body);
do {
changedFreeVars = false;
for (val caller <- called.keys;
val callee <- called(caller).elements;
val fv <- freeVars(callee))
markFree(fv, caller);
} while (changedFreeVars);
for (val sym <- renamable.elements) {
sym.name = unit.fresh.newName(sym.name.toString());
if (settings.debug.value) log("renamed: " + sym.name);
}
atPhase(phase.next) {
for (val owner <- free.keys) {
if (settings.debug.value)
log("free(" + owner + owner.locationString + ") = " + free(owner).elements.toList);
proxies(owner) =
for (val fv <- free(owner).elements.toList) yield {
val proxy = owner.newValue(owner.pos, fv.name)
setFlag (if (owner.isClass) PARAMACCESSOR | PRIVATE | LOCAL else PARAM)
setFlag SYNTHETIC
setInfo fv.info;
if (owner.isClass) owner.info.decls enter proxy;
proxy
}
}
}
}
private def proxy(sym: Symbol) = {
def searchIn(owner: Symbol): Symbol = {
if (settings.debug.value) log("searching for " + sym + "(" + sym.owner + ") in " + owner + " " + enclMethOrClass(owner));//debug
proxies.get(enclMethOrClass(owner)) match {
case Some(ps) =>
ps filter (p => p.name == sym.name) match {
case List(p) => p
case List() => searchIn(outer(owner))
}
case None => searchIn(outer(owner))
}
}
if (settings.debug.value) log("proxy " + sym + " in " + sym.owner + " from " + currentOwner + " " + enclMethOrClass(sym.owner));//debug
if (enclMethOrClass(sym.owner) == enclMethOrClass(currentOwner)) sym
else searchIn(currentOwner)
}
private def memberRef(sym: Symbol) = {
val clazz = sym.owner.enclClass;
val qual = if (clazz == currentOwner.enclClass) gen.This(clazz)
else {
sym resetFlag(LOCAL | PRIVATE);
if (clazz.isStaticOwner) gen.mkQualifier(clazz.thisType)
else outerPath(outerValue, clazz)
}
Select(qual, sym) setType sym.tpe
}
private def proxyRef(sym: Symbol) = {
val psym = proxy(sym);
if (psym.isLocal) gen.Ident(psym) else memberRef(psym)
}
private def addFreeArgs(pos: int, sym: Symbol, args: List[Tree]) = {
def freeArg(fv: Symbol) = atPos(pos)(proxyRef(fv));
val fvs = freeVars(sym).toList;
if (fvs.isEmpty) args else args ::: (fvs map freeArg)
}
private def addFreeParams(tree: Tree, sym: Symbol): Tree = proxies.get(sym) match {
case Some(ps) =>
val freeParams = ps map (p => ValDef(p) setPos tree.pos setType NoType);
tree match {
case DefDef(mods, name, tparams, List(vparams), tpt, rhs) =>
sym.updateInfo(
lifted(MethodType(sym.info.paramTypes ::: (ps map (.tpe)), sym.info.resultType)));
copy.DefDef(tree, mods, name, tparams, List(vparams ::: freeParams), tpt, rhs)
case ClassDef(mods, name, tparams, tpt, impl @ Template(parents, body)) =>
copy.ClassDef(tree, mods, name, tparams, tpt,
copy.Template(impl, parents, body ::: freeParams))
}
case None =>
tree
}
private def liftDef(tree: Tree): Tree = {
val sym = tree.symbol;
sym.owner = sym.owner.enclClass;
if (sym.isClass) sym.owner = sym.owner.toInterface;
if (sym.isMethod) sym setFlag LIFTED;
liftedDefs(sym.owner) += tree;
sym.owner.info.decls enterUnique sym;
if (settings.debug.value) log("lifted: " + sym + sym.locationString);
EmptyTree
}
private def postTransform(tree: Tree): Tree = {
val sym = tree.symbol;
tree match {
case ClassDef(_, _, _, _, _) =>
if (sym.isLocal) liftDef(tree) else tree
case DefDef(_, _, _, _, _, _) =>
val tree1 = addFreeParams(tree, sym);
if (sym.isLocal) liftDef(tree1) else tree1
case ValDef(mods, name, tpt, rhs) =>
if (sym hasFlag CAPTURED) {
val tpt1 = TypeTree(sym.tpe) setPos tpt.pos;
val rhs1 =
atPos(rhs.pos) {
typed {
Apply(Select(New(TypeTree(sym.tpe)), nme.CONSTRUCTOR), List(rhs))
}
}
copy.ValDef(tree, mods, name, tpt1, rhs1)
} else tree
case Return(Block(stats, value)) =>
Block(stats, copy.Return(tree, value)) setType tree.tpe setPos tree.pos;
case Return(expr) =>
if (sym != currentOwner.enclMethod) {
System.out.println(sym);//debug
System.out.println(currentOwner.enclMethod);//debug
unit.error(tree.pos, "non-local return not yet implemented");
}
tree
case Apply(fn, args) =>
copy.Apply(tree, fn, addFreeArgs(tree.pos, sym, args));
case Assign(Apply(TypeApply(sel @ Select(qual, _), _), List()), rhs) =>
// eliminate casts introduced by selecting a captured variable field
// on the lhs of an assignment.
assert(sel.symbol == Object_asInstanceOf);
copy.Assign(tree, qual, rhs)
case Ident(name) =>
val tree1 =
if (sym != NoSymbol && sym.isTerm && !sym.isLabel)
if (sym.isMethod)
atPos(tree.pos)(memberRef(sym))
else if (sym.isLocal && enclMethOrClass(sym.owner) != enclMethOrClass(currentOwner))
atPos(tree.pos)(proxyRef(sym))
else tree
else tree;
if (sym hasFlag CAPTURED)
atPos(tree.pos) {
val tp = tree.tpe;
val elemTree = typed { Select(tree1 setType sym.tpe, nme.elem) }
if (elemTree.tpe.symbol != tp.symbol) gen.cast(elemTree, tp) else elemTree
}
else tree1
case _ =>
tree
}
}
override def transform(tree: Tree): Tree =
postTransform(super.transform(tree) setType lifted(tree.tpe));
/** Transform statements and add lifted definitions to them. */
override def transformStats(stats: List[Tree], exprOwner: Symbol): List[Tree] = {
val stats1 = super.transformStats(stats, exprOwner);
if (currentOwner.isClass && !currentOwner.isPackageClass && liftedDefs(currentOwner).hasNext)
stats1 ::: liftedDefs(currentOwner).toList
else
stats1
}
override def transformUnit(unit: CompilationUnit): unit = {
computeFreeVars;
atPhase(phase.next)(super.transformUnit(unit))
}
}
}