From 2b15e8ce934c9629e18b5fa4a4bb39bba750575c Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Fri, 29 Jul 2011 18:26:00 +0000 Subject: First steps towards lifting --- .../scala/reflect/internal/Definitions.scala | 20 +- src/compiler/scala/reflect/internal/TreeGen.scala | 4 +- src/compiler/scala/reflect/runtime/Mirror.scala | 4 + .../scala/tools/nsc/transform/LambdaLift.scala | 6 +- .../scala/tools/nsc/transform/LiftCode.scala | 204 ++++++++++++++++----- .../scala/tools/nsc/transform/UnCurry.scala | 6 +- src/compiler/scala/tools/nsc/util/trace.scala | 13 +- src/library/scala/reflect/package.scala | 4 +- src/library/scala/runtime/ObjectRef.java | 6 +- src/library/scala/runtime/VolatileObjectRef.java | 6 +- 10 files changed, 202 insertions(+), 71 deletions(-) diff --git a/src/compiler/scala/reflect/internal/Definitions.scala b/src/compiler/scala/reflect/internal/Definitions.scala index 857d7470d4..d076e80836 100644 --- a/src/compiler/scala/reflect/internal/Definitions.scala +++ b/src/compiler/scala/reflect/internal/Definitions.scala @@ -218,14 +218,14 @@ trait Definitions extends reflect.api.StandardDefinitions { lazy val BridgeClass = getClass("scala.annotation.bridge") // fundamental reference classes - lazy val ScalaObjectClass = getClass("scala.ScalaObject") - lazy val PartialFunctionClass = getClass("scala.PartialFunction") - lazy val SymbolClass = getClass("scala.Symbol") - lazy val StringClass = getClass(sn.String) - lazy val StringModule = StringClass.linkedClassOfClass - lazy val ClassClass = getClass(sn.Class) - def Class_getMethod = getMember(ClassClass, nme.getMethod_) - lazy val DynamicClass = getClass("scala.Dynamic") + lazy val ScalaObjectClass = getClass("scala.ScalaObject") + lazy val PartialFunctionClass = getClass("scala.PartialFunction") + lazy val SymbolClass = getClass("scala.Symbol") + lazy val StringClass = getClass(sn.String) + lazy val StringModule = StringClass.linkedClassOfClass + lazy val ClassClass = getClass(sn.Class) + def Class_getMethod = getMember(ClassClass, nme.getMethod_) + lazy val DynamicClass = getClass("scala.Dynamic") // fundamental modules lazy val SysPackage = getPackageObject("scala.sys") @@ -248,6 +248,10 @@ trait Definitions extends reflect.api.StandardDefinitions { def arrayCloneMethod = getMember(ScalaRunTimeModule, "array_clone") def ensureAccessibleMethod = getMember(ScalaRunTimeModule, "ensureAccessible") def scalaRuntimeSameElements = getMember(ScalaRunTimeModule, nme.sameElements) + lazy val ReflectRuntimeMirror = getModule("scala.reflect.runtime.Mirror") + def freeValueMethod = getMember(ReflectRuntimeMirror, "freeValue") + lazy val ReflectPackage = getPackageObject("scala.reflect") + def Reflect_mirror = getMember(ReflectPackage, "mirror") // classes with special meanings lazy val StringAddClass = getClass("scala.runtime.StringAdd") diff --git a/src/compiler/scala/reflect/internal/TreeGen.scala b/src/compiler/scala/reflect/internal/TreeGen.scala index f184bcee51..a5149898b2 100644 --- a/src/compiler/scala/reflect/internal/TreeGen.scala +++ b/src/compiler/scala/reflect/internal/TreeGen.scala @@ -255,9 +255,7 @@ abstract class TreeGen { case IntClass => Literal(0) case LongClass => Literal(0L) case CharClass => Literal(0.toChar) - case _ => - if (NullClass.tpe <:< tp) Literal(null: Any) - else abort("Cannot determine zero for " + tp) + case _ => Literal(Constant(null)) } tree setType tp } diff --git a/src/compiler/scala/reflect/runtime/Mirror.scala b/src/compiler/scala/reflect/runtime/Mirror.scala index 3d6d330cfe..d5d3762fab 100644 --- a/src/compiler/scala/reflect/runtime/Mirror.scala +++ b/src/compiler/scala/reflect/runtime/Mirror.scala @@ -33,6 +33,10 @@ class Mirror extends Universe with api.Mirror { methodToJava(meth).invoke(receiver, args.asInstanceOf[Seq[AnyRef]]: _*) } + def freeValue(x: Any) = FreeValue(x) + + case class FreeValue(any: Any) + } object Mirror extends Mirror diff --git a/src/compiler/scala/tools/nsc/transform/LambdaLift.scala b/src/compiler/scala/tools/nsc/transform/LambdaLift.scala index e3cd9eb351..d5d7bdd1e9 100644 --- a/src/compiler/scala/tools/nsc/transform/LambdaLift.scala +++ b/src/compiler/scala/tools/nsc/transform/LambdaLift.scala @@ -116,7 +116,11 @@ abstract class LambdaLift extends InfoTransform { } changedFreeVars = true debuglog("" + sym + " is free in " + enclosure); - if ((sym.isVariable || (sym.isValue && sym.isLazy)) && !sym.hasFlag(CAPTURED)) { + if (sym.isVariable && !sym.hasFlag(CAPTURED)) { + // todo: We should merge this with the lifting done in liftCode. + // We do have to lift twice: in liftCode, because Code[T] needs to see the lifted version + // and here again because lazy bitmaps are introduced later and get lifted here. + // But we should factor out the code and run it twice. sym setFlag CAPTURED val symClass = sym.tpe.typeSymbol atPhase(phase.next) { diff --git a/src/compiler/scala/tools/nsc/transform/LiftCode.scala b/src/compiler/scala/tools/nsc/transform/LiftCode.scala index b52419f7ca..0ab7a0cf4c 100644 --- a/src/compiler/scala/tools/nsc/transform/LiftCode.scala +++ b/src/compiler/scala/tools/nsc/transform/LiftCode.scala @@ -14,11 +14,13 @@ import scala.tools.nsc.util.FreshNameCreator /** Translate expressions of the form reflect.Code.lift(exp) * to the lifted "reflect trees" representation of exp. + * Also: mutable variables that are accessed from a local function are wrapped in refs. * * @author Gilles Dubochet - * @version 1.0 + * @author Martin Odersky + * @version 2.10 */ -abstract class LiftCode extends Transform with Reifiers { +abstract class LiftCode extends Transform with TypingTransformers { import global._ // the global environment import definitions._ // standard classes and methods @@ -30,22 +32,54 @@ abstract class LiftCode extends Transform with Reifiers { val phaseName: String = "liftcode" def newTransformer(unit: CompilationUnit): Transformer = - new AddRefFields(unit) + new Lifter(unit) - class AddRefFields(unit: CompilationUnit) extends Transformer { - override def transform(tree: Tree): Tree = tree match { - case Apply(lift, List(tree)) - if lift.symbol == Code_lift => - typed(atPos(tree.pos)(codify(tree))) - case _ => - super.transform(tree) + class Lifter(unit: CompilationUnit) extends TypingTransformer(unit) { + override def transformUnit(unit: CompilationUnit) { + freeMutableVars.clear() + freeLocalsTraverser(unit.body) + atPhase(phase.next) { + super.transformUnit(unit) + } + } + + override def transform(tree: Tree): Tree = { + val sym = tree.symbol + tree match { + case Apply(lift, List(tree)) if sym == Code_lift => + transform(localTyper.typedPos(tree.pos)(codify(tree))) + case ValDef(mods, name, tpt, rhs) if (freeMutableVars(sym)) => + val tpt1 = TypeTree(sym.tpe) setPos tpt.pos + /* Creating a constructor argument if one isn't present. */ + val constructorArg = rhs match { + case EmptyTree => gen.mkZero(atPhase(phase.prev)(sym.tpe)) + case _ => transform(rhs) + } + val rhs1 = typer.typedPos(rhs.pos) { + util.errtrace("lifted rhs for "+tree+" in "+unit) ( + Apply(Select(New(TypeTree(sym.tpe)), nme.CONSTRUCTOR), List(constructorArg))) + } + sym resetFlag MUTABLE + sym removeAnnotation VolatileAttr + treeCopy.ValDef(tree, mods &~ MUTABLE, name, tpt1, rhs1) + case Ident(name) if freeMutableVars(sym) => + localTyper.typedPos(tree.pos) { + util.errtrace("lifting ")(Select(tree setType sym.tpe, nme.elem)) + } + case _ => + super.transform(tree) + } } } + case class FreeValue(tree: Tree) extends Tree - type InjectEnvironment = immutable.ListMap[reflect.Symbol, Name] + class Reifier(owner: Symbol) {#:: + import reflect.runtime.{Mirror => rm} - class Injector(env: InjectEnvironment, fresh: FreshNameCreator) { + private val boundVars: mutable.Set[Symbol] = mutable.Set() + private val freeTrees: mutable.Set[Tree] = mutable.Set() + private val mirrorPrefix = gen.mkAttributedRef(ReflectRuntimeMirror) // todo replace className by caseName in CaseClass once we have switched to nsc. def className(value: AnyRef): String = value match { @@ -60,58 +94,140 @@ abstract class LiftCode extends Transform with Reifiers { def objectName(value: Any): String = value match { case Nil => "scala.collection.immutable.Nil" - case reflect.NoSymbol => "scala.reflect.NoSymbol" - case reflect.RootSymbol => "scala.reflect.RootSymbol" - case reflect.NoPrefix => "scala.reflect.NoPrefix" - case reflect.NoType => "scala.reflect.NoType" + case reflect.NoSymbol => "scala.reflect.runtime.Mirror.NoSymbol" + case reflect.RootSymbol => "scala.reflect.runtime.Mirror.definitions.RootSymbol" + case reflect.NoPrefix => "scala.reflect.runtime.Mirror.NoPrefix" + case reflect.NoType => "scala.reflect.runtime.Mirror.NoType" case _ => "" } - def inject(value: Any): Tree = { - def treatProduct(c: Product) = { + def reify(value: Any): rm.Tree = { + def treatProduct(c: Product): rm.Tree = { val name = objectName(c) - if (name.length() != 0) - gen.mkAttributedRef(definitions.getModule(name)) + if (name.length != 0) + rm.gen.mkAttributedRef(rm.definitions.getModule(name)) else { val name = className(c) - if (name.length() == 0) abort("don't know how to inject " + value) - val injectedArgs = new ListBuffer[Tree] + if (name.length == 0) abort("don't know how to inject " + value) + val injectedArgs = new ListBuffer[rm.Tree] for (i <- 0 until c.productArity) - injectedArgs += inject(c.productElement(i)) - New(Ident(definitions.getClass(name)), List(injectedArgs.toList)) + injectedArgs += reify(c.productElement(i)) + rm.New(rm.gen.mkAttributedRef(rm.definitions.getClass(name)), List(injectedArgs.toList)) } } + + def makeFree(tree: Tree): rm.Tree = { + freeTrees += tree + reify(Apply(gen.mkAttributedRef(definitions.freeValueMethod), List(tree))) + } + value match { - case FreeValue(tree) => - New(Ident(definitions.getClass("scala.reflect.Literal")), List(List(tree))) - case () => Literal(Constant(())) - case x: String => Literal(Constant(x)) - case x: Boolean => Literal(Constant(x)) - case x: Byte => Literal(Constant(x)) - case x: Short => Literal(Constant(x)) - case x: Char => Literal(Constant(x)) - case x: Int => Literal(Constant(x)) - case x: Long => Literal(Constant(x)) - case x: Float => Literal(Constant(x)) - case x: Double => Literal(Constant(x)) - case c: Product => treatProduct(c) - case null => - gen.mkAttributedRef(definitions.getModule("scala.reflect.NoType")) + case tree: Tree if freeTrees contains tree => + tree + case tree: DefTree => + boundVars += tree.symbol + reify1(tree) + case tree @ This(_) if !(boundVars contains tree.symbol) => + makeFree(tree) + case tree @ Ident(_) if !(boundVars contains tree.symbol) => + makeFree(tree) case _ => - abort("don't know how to inject " + value) + reify1(value) } } - } // Injector + def reify1(value: Any): rm.Tree = value match { + case () => rm.Literal(rm.Constant(())) + case x: String => rm.Literal(rm.Constant(x)) + case x: Boolean => rm.Literal(rm.Constant(x)) + case x: Byte => rm.Literal(rm.Constant(x)) + case x: Short => rm.Literal(rm.Constant(x)) + case x: Char => rm.Literal(rm.Constant(x)) + case x: Int => rm.Literal(rm.Constant(x)) + case x: Long => rm.Literal(rm.Constant(x)) + case x: Float => rm.Literal(rm.Constant(x)) + case x: Double => rm.Literal(rm.Constant(x)) + case c: Product => treatProduct(c) + case _ => + abort("don't know how to inject " + value) + } + } // Injector - def inject(code: reflect.Tree): Tree = - new Injector(immutable.ListMap.empty, new FreshNameCreator.Default).inject(code) + def reify(tree: Tree): Tree = + new Reifier().reify(tree) def codify (tree: Tree): Tree = + + Block( + ValDef( New(TypeTree(appliedType(definitions.CodeClass.typeConstructor, List(tree.tpe))), - List(List(inject(reify(tree))))) + List(List(reify(tree)))) + + /** Set of mutable local variables that are free in some inner method. */ + private val freeMutableVars: mutable.Set[Symbol] = new mutable.HashSet + + /** PP: There is apparently some degree of overlap between the CAPTURED + * flag and the role being filled here. I think this is how this was able + * to go for so long looking only at DefDef and Ident nodes, as bugs + * would only emerge under more complicated conditions such as #3855. + * I'll try to figure it all out, but if someone who already knows the + * whole story wants to fill it in, that too would be great. + */ + private val freeLocalsTraverser = new Traverser { + var currentMethod: Symbol = NoSymbol + var maybeEscaping = false + + def withEscaping(body: => Unit) { + val saved = maybeEscaping + maybeEscaping = true + try body + finally maybeEscaping = saved + } + + override def traverse(tree: Tree) = tree match { + case DefDef(_, _, _, _, _, _) => + val lastMethod = currentMethod + currentMethod = tree.symbol + try super.traverse(tree) + finally currentMethod = lastMethod + /** A method call with a by-name parameter represents escape. */ + case Apply(fn, args) if fn.symbol.paramss.nonEmpty => + traverse(fn) + (fn.symbol.paramss.head, args).zipped foreach { (param, arg) => + if (param.tpe != null && isByNameParamType(param.tpe)) + withEscaping(traverse(arg)) + else + traverse(arg) + } + /** The rhs of a closure represents escape. */ + case Function(vparams, body) => + vparams foreach traverse + withEscaping(traverse(body)) + + /** The appearance of an ident outside the method where it was defined or + * anytime maybeEscaping is true implies escape. + */ + case Ident(_) => + val sym = tree.symbol + if (sym.isVariable && sym.owner.isMethod && (maybeEscaping || sym.owner != currentMethod)) { + freeMutableVars += sym + val symTpe = sym.tpe + val symClass = symTpe.typeSymbol + atPhase(phase.next) { + def refType(valueRef: Map[Symbol, Symbol], objectRefClass: Symbol) = + if (isValueClass(symClass)) valueRef(symClass).tpe + else appliedType(objectRefClass.typeConstructor, List(symTpe)) + sym updateInfo ( + if (sym.hasAnnotation(VolatileAttr)) refType(volatileRefClass, VolatileObjectRefClass) + else refType(refClass, ObjectRefClass)) + } + } + case _ => + super.traverse(tree) + } + } } // case EmptyTree => diff --git a/src/compiler/scala/tools/nsc/transform/UnCurry.scala b/src/compiler/scala/tools/nsc/transform/UnCurry.scala index 8ac69f15e1..a636c75281 100644 --- a/src/compiler/scala/tools/nsc/transform/UnCurry.scala +++ b/src/compiler/scala/tools/nsc/transform/UnCurry.scala @@ -770,8 +770,8 @@ abstract class UnCurry extends InfoTransform case DefDef(_, _, _, _, _, _) => val lastMethod = currentMethod currentMethod = tree.symbol - super.traverse(tree) - currentMethod = lastMethod + try super.traverse(tree) + finally currentMethod = lastMethod /** A method call with a by-name parameter represents escape. */ case Apply(fn, args) if fn.symbol.paramss.nonEmpty => traverse(fn) @@ -792,7 +792,7 @@ abstract class UnCurry extends InfoTransform case Ident(_) => val sym = tree.symbol if (sym.isVariable && sym.owner.isMethod && (maybeEscaping || sym.owner != currentMethod)) - freeMutableVars += sym + assert(false, "Failure to lift "+sym+sym.locationString); freeMutableVars += sym case _ => super.traverse(tree) } diff --git a/src/compiler/scala/tools/nsc/util/trace.scala b/src/compiler/scala/tools/nsc/util/trace.scala index 97b3123372..6a4d20e7f1 100644 --- a/src/compiler/scala/tools/nsc/util/trace.scala +++ b/src/compiler/scala/tools/nsc/util/trace.scala @@ -1,13 +1,18 @@ package scala.tools.nsc package util -object trace { - def apply[T](msg: String)(value: T): T = { - println(msg+value) +import java.io.PrintStream + +object trace extends SimpleTracer(System.out) +object errtrace extends SimpleTracer(System.err) + +class SimpleTracer(out: PrintStream) { + def apply[T](msg: String)(value: T): T = { + out.println(msg+value) value } def withFun[T, U](msg: String)(value: T)(fun: T => U): T = { - println(msg+fun(value)) + out.println(msg+fun(value)) value } } diff --git a/src/library/scala/reflect/package.scala b/src/library/scala/reflect/package.scala index 3fe25316ca..eac4adcffe 100644 --- a/src/library/scala/reflect/package.scala +++ b/src/library/scala/reflect/package.scala @@ -4,7 +4,7 @@ package object reflect { val mirror: api.Mirror = try { // we use (Java) reflection here so that we can keep reflect.runtime and reflect.internals in a seperate jar - (java.lang.Class forName "scala.reflect.runtime.Mirror$" getField "$MODULE" get null).asInstanceOf[api.Mirror] + (java.lang.Class forName "scala.reflect.runtime.Mirror$" getField "MODULE$" get null).asInstanceOf[api.Mirror] } catch { case ex: NoClassDefFoundError => throw new UnsupportedOperationException("Scala reflection not available on this platform") @@ -15,4 +15,4 @@ package object reflect { type Tree = mirror.Tree */ -} \ No newline at end of file +} diff --git a/src/library/scala/runtime/ObjectRef.java b/src/library/scala/runtime/ObjectRef.java index a1dd3d78d9..15f2f493c7 100644 --- a/src/library/scala/runtime/ObjectRef.java +++ b/src/library/scala/runtime/ObjectRef.java @@ -11,10 +11,10 @@ package scala.runtime; -public class ObjectRef implements java.io.Serializable { +public class ObjectRef implements java.io.Serializable { private static final long serialVersionUID = -9055728157600312291L; - public Object elem; - public ObjectRef(Object elem) { this.elem = elem; } + public T elem; + public ObjectRef(T elem) { this.elem = elem; } public String toString() { return String.valueOf(elem); } } diff --git a/src/library/scala/runtime/VolatileObjectRef.java b/src/library/scala/runtime/VolatileObjectRef.java index 73facba978..7c393b405a 100755 --- a/src/library/scala/runtime/VolatileObjectRef.java +++ b/src/library/scala/runtime/VolatileObjectRef.java @@ -11,10 +11,10 @@ package scala.runtime; -public class VolatileObjectRef implements java.io.Serializable { +public class VolatileObjectRef implements java.io.Serializable { private static final long serialVersionUID = -9055728157600312291L; - volatile public Object elem; - public VolatileObjectRef(Object elem) { this.elem = elem; } + volatile public T elem; + public VolatileObjectRef(T elem) { this.elem = elem; } public String toString() { return String.valueOf(elem); } } -- cgit v1.2.3