diff options
author | Paul Phillips <paulp@improving.org> | 2010-09-19 06:41:47 +0000 |
---|---|---|
committer | Paul Phillips <paulp@improving.org> | 2010-09-19 06:41:47 +0000 |
commit | fd2bfa28b0aaa2d9f7fa3a0dd702fcb0c93b80a4 (patch) | |
tree | e97d46684fb511e8ab68e05fa8fec26610a8c70b | |
parent | 28c1aa3c204b0c1081d040a66a628d8f21306b10 (diff) | |
download | scala-fd2bfa28b0aaa2d9f7fa3a0dd702fcb0c93b80a4.tar.gz scala-fd2bfa28b0aaa2d9f7fa3a0dd702fcb0c93b80a4.tar.bz2 scala-fd2bfa28b0aaa2d9f7fa3a0dd702fcb0c93b80a4.zip |
Mostly a setup commit.
impressively tedious it is to work directly with the AST, so I picked up
TreeDSL again and fleshed it out some more. And then I did a once over
on SyntheticMethods beating out bits of duplication. No review.
6 files changed, 155 insertions, 72 deletions
diff --git a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala index a24c8c01d3..547476a6db 100644 --- a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala +++ b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala @@ -8,6 +8,7 @@ package scala.tools.nsc package ast import PartialFunction._ +import symtab.Flags /** A DSL for generating scala code. The goal is that the * code generating code should look a lot like the code it @@ -147,15 +148,98 @@ trait TreeDSL { def ==>(body: Tree): CaseDef = CaseDef(pat, guard, body) } - abstract class ValOrDefStart(sym: Symbol) { - def ===(body: Tree): ValOrDefDef + /** VODD, if it's not obvious, means ValOrDefDef. This is the + * common code between a tree based on a pre-existing symbol and + * one being built from scratch. + */ + trait VODDStart { + def name: Name + def defaultMods: Modifiers + def defaultTpt: Tree + def defaultPos: Position + + type ResultTreeType <: ValOrDefDef + def mkTree(rhs: Tree): ResultTreeType + def ===(rhs: Tree): ResultTreeType + + private var _mods: Modifiers = null + private var _tpt: Tree = null + private var _pos: Position = null + + def withType(tp: Type): this.type = { + _tpt = TypeTree(tp) + this + } + def withFlags(flags: Long*): this.type = { + if (_mods == null) + _mods = defaultMods + + _mods = flags.foldLeft(_mods)(_ | _) + this + } + def withPos(pos: Position): this.type = { + _pos = pos + this + } + + final def mods = if (_mods == null) defaultMods else _mods + final def tpt = if (_tpt == null) defaultTpt else _tpt + final def pos = if (_pos == null) defaultPos else _pos + } + trait SymVODDStart extends VODDStart { + def sym: Symbol + def symType: Type + + def name = sym.name + def defaultMods = Modifiers(sym.flags) + def defaultTpt = TypeTree(symType) setPos sym.pos.focus + def defaultPos = sym.pos + + final def ===(rhs: Tree): ResultTreeType = + atPos(sym.pos)(mkTree(rhs) setSymbol sym) + } + trait ValCreator { + self: VODDStart => + + type ResultTreeType = ValDef + def mkTree(rhs: Tree): ValDef = ValDef(mods, name, tpt, rhs) + } + trait DefCreator { + self: VODDStart => + + def tparams: List[TypeDef] + def vparamss: List[List[ValDef]] + + type ResultTreeType = DefDef + def mkTree(rhs: Tree): DefDef = DefDef(mods, name, tparams, vparamss, tpt, rhs) } - class DefStart(sym: Symbol) extends ValOrDefStart(sym) { - def ===(body: Tree) = DefDef(sym, body) + + class DefSymStart(val sym: Symbol) extends SymVODDStart with DefCreator { + def symType = sym.tpe.finalResultType + def tparams = sym.typeParams map TypeDef + def vparamss = sym.paramss map (xs => xs map ValDef) + } + class ValSymStart(val sym: Symbol) extends SymVODDStart with ValCreator { + def symType = sym.tpe + } + + trait TreeVODDStart extends VODDStart { + def defaultMods = NoMods + def defaultTpt = TypeTree() + def defaultPos = NoPosition + + final def ===(rhs: Tree): ResultTreeType = + if (pos == NoPosition) mkTree(rhs) + else atPos(pos)(mkTree(rhs)) + } + + class ValTreeStart(val name: Name) extends TreeVODDStart with ValCreator { } - class ValStart(sym: Symbol) extends ValOrDefStart(sym) { - def ===(body: Tree) = ValDef(sym, body) + class DefTreeStart(val name: Name) extends TreeVODDStart with DefCreator { + def tparams: List[TypeDef] = Nil + def vparamss: List[List[ValDef]] = List(Nil) } + class IfStart(cond: Tree, thenp: Tree) { def THEN(x: Tree) = new IfStart(cond, x) def ELSE(elsep: Tree) = If(cond, thenp, elsep) @@ -170,10 +254,6 @@ trait TreeDSL { def CASE(pat: Tree): CaseStart = new CaseStart(pat, EmptyTree) def DEFAULT: CaseStart = new CaseStart(WILD(), EmptyTree) - class NameMethods(target: Name) { - def BIND(body: Tree) = Bind(target, body) - } - class SymbolMethods(target: Symbol) { def BIND(body: Tree) = Bind(target, body) @@ -204,8 +284,13 @@ trait TreeDSL { if (args.isEmpty) New(TypeTree(sym.tpe)) else New(TypeTree(sym.tpe), List(args.toList)) - def VAL(sym: Symbol) = new ValStart(sym) - def DEF(sym: Symbol) = new DefStart(sym) + def VAR(name: Name, tp: Type): ValTreeStart = VAR(name) withType tp + def VAR(name: Name): ValTreeStart = new ValTreeStart(name) withFlags Flags.MUTABLE + def VAL(name: Name, tp: Type): ValTreeStart = VAL(name) withType tp + def VAL(name: Name): ValTreeStart = new ValTreeStart(name) + def VAL(sym: Symbol): ValSymStart = new ValSymStart(sym) + + def DEF(sym: Symbol): DefSymStart = new DefSymStart(sym) def AND(guards: Tree*) = if (guards.isEmpty) EmptyTree else guards reduceLeft gen.mkAnd @@ -248,12 +333,6 @@ trait TreeDSL { /** Implicits - some of these should probably disappear **/ implicit def mkTreeMethods(target: Tree): TreeMethods = new TreeMethods(target) implicit def mkTreeMethodsFromSymbol(target: Symbol): TreeMethods = new TreeMethods(Ident(target)) - implicit def mkTreeMethodsFromName(target: Name): TreeMethods = new TreeMethods(Ident(target)) - implicit def mkTreeMethodsFromString(target: String): TreeMethods = new TreeMethods(Ident(target)) - - implicit def mkNameMethodsFromName(target: Name): NameMethods = new NameMethods(target) - implicit def mkNameMethodsFromString(target: String): NameMethods = new NameMethods(target) - implicit def mkSymbolMethodsFromSymbol(target: Symbol): SymbolMethods = new SymbolMethods(target) /** (foo DOT bar) might be simply a Select, but more likely it is to be immediately diff --git a/src/compiler/scala/tools/nsc/backend/jvm/GenJVM.scala b/src/compiler/scala/tools/nsc/backend/jvm/GenJVM.scala index c29b532cf0..2d2146d113 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/GenJVM.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/GenJVM.scala @@ -89,8 +89,7 @@ abstract class GenJVM extends SubComponent { val MethodHandleType = new JObjectType("java.dyn.MethodHandle") // Scala attributes - val SerializableAttr = definitions.SerializableAttr - val SerialVersionUID = definitions.getClass("scala.SerialVersionUID") + import definitions.{ SerializableAttr, SerialVersionUIDAttr } val CloneableAttr = definitions.getClass("scala.cloneable") val TransientAtt = definitions.getClass("scala.transient") val VolatileAttr = definitions.getClass("scala.volatile") @@ -213,7 +212,7 @@ abstract class GenJVM extends SubComponent { parents = parents ::: List(definitions.SerializableClass.tpe) case AnnotationInfo(tp, _, _) if tp.typeSymbol == CloneableAttr => parents = parents ::: List(CloneableClass.tpe) - case AnnotationInfo(tp, Literal(const) :: _, _) if tp.typeSymbol == SerialVersionUID => + case AnnotationInfo(tp, Literal(const) :: _, _) if tp.typeSymbol == SerialVersionUIDAttr => serialVUID = Some(const.longValue) case AnnotationInfo(tp, _, _) if tp.typeSymbol == RemoteAttr => parents = parents ::: List(RemoteInterface.tpe) diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala index 16563bb38c..1d11e31a19 100644 --- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala +++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala @@ -449,6 +449,7 @@ trait Definitions extends reflect.generic.StandardDefinitions { // special attributes lazy val SerializableAttr: Symbol = getClass("scala.serializable") + lazy val SerialVersionUIDAttr: Symbol = getClass("scala.SerialVersionUID") lazy val DeprecatedAttr: Symbol = getClass("scala.deprecated") lazy val DeprecatedNameAttr: Symbol = getClass("scala.deprecatedName") lazy val MigrationAnnotationClass: Symbol = getClass("scala.annotation.migration") diff --git a/src/compiler/scala/tools/nsc/symtab/Symbols.scala b/src/compiler/scala/tools/nsc/symtab/Symbols.scala index ab79f735b4..44bfbc5d27 100644 --- a/src/compiler/scala/tools/nsc/symtab/Symbols.scala +++ b/src/compiler/scala/tools/nsc/symtab/Symbols.scala @@ -119,6 +119,9 @@ trait Symbols extends reflect.generic.Symbols { self: SymbolTable => annots1 } + def setSerializable(): Unit = + addAnnotation(AnnotationInfo(SerializableAttr.tpe, Nil, Nil)) + def setAnnotations(annots: List[AnnotationInfoBase]): this.type = { this.rawannots = annots this @@ -455,6 +458,7 @@ trait Symbols extends reflect.generic.Symbols { self: SymbolTable => } } + def isSerializable = hasAnnotation(SerializableAttr) def isDeprecated = hasAnnotation(DeprecatedAttr) def deprecationMessage = getAnnotation(DeprecatedAttr) flatMap { _.stringArg(0) } // !!! when annotation arguments are not literal strings, but any sort of diff --git a/src/compiler/scala/tools/nsc/transform/CleanUp.scala b/src/compiler/scala/tools/nsc/transform/CleanUp.scala index d613124013..387e7bc6db 100644 --- a/src/compiler/scala/tools/nsc/transform/CleanUp.scala +++ b/src/compiler/scala/tools/nsc/transform/CleanUp.scala @@ -31,12 +31,8 @@ abstract class CleanUp extends Transform with ast.TreeDSL { private var localTyper: analyzer.Typer = null - private lazy val serializableAnnotation = - AnnotationInfo(SerializableAttr.tpe, Nil, Nil) - private lazy val serialVersionUIDAnnotation = { - val attr = definitions.getClass("scala.SerialVersionUID") - AnnotationInfo(attr.tpe, List(Literal(Constant(0))), List()) - } + private lazy val serialVersionUIDAnnotation = + AnnotationInfo(SerialVersionUIDAttr.tpe, List(Literal(Constant(0))), List()) private object MethodDispatchType extends scala.Enumeration { val NO_CACHE, MONO_CACHE, POLY_CACHE = Value @@ -574,8 +570,8 @@ abstract class CleanUp extends Transform with ast.TreeDSL { if (settings.target.value == "jvm-1.5") { val sym = cdef.symbol // is this an anonymous function class? - if (sym.isAnonymousFunction && !sym.hasAnnotation(SerializableAttr)) { - sym addAnnotation serializableAnnotation + if (sym.isAnonymousFunction && !sym.isSerializable) { + sym.setSerializable() sym addAnnotation serialVersionUIDAnnotation } } diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index 60374b3a55..52d79542c2 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -48,11 +48,6 @@ trait SyntheticMethods extends ast.TreeDSL { val localContext = if (reporter.hasErrors) context makeSilent false else context val localTyper = newTyper(localContext) - def hasImplementation(name: Name): Boolean = { - val sym = clazz.info member name // member and not nonPrivateMember: bug #1385 - sym.isTerm && !(sym hasFlag DEFERRED) - } - def hasOverridingImplementation(meth: Symbol): Boolean = { val sym = clazz.info nonPrivateMember meth.name sym.alternatives exists { sym => @@ -64,16 +59,30 @@ trait SyntheticMethods extends ast.TreeDSL { def syntheticMethod(name: Name, flags: Int, tpeCons: Symbol => Type) = newSyntheticMethod(name, flags | OVERRIDE, tpeCons) + /** Note: at this writing this is the only place the SYNTHETICMETH is ever created. + * The flag is commented "synthetic method, but without SYNTHETIC flag", an explanation + * for which I now attempt to reverse engineer the motivation. + * + * In general case classes/objects are not given synthetic equals methods if some + * non-AnyRef implementation is inherited. However if you let a case object inherit + * an implementation from a case class, it creates an asymmetric equals with all the + * associated badness: see ticket #883. So if it sees such a thing (which is marked + * SYNTHETICMETH) it re-overrides it with reference equality. + * + * In other words it only exists to support (deprecated) case class inheritance. + */ def newSyntheticMethod(name: Name, flags: Int, tpeCons: Symbol => Type) = { val method = clazz.newMethod(clazz.pos.focus, name) setFlag (flags | SYNTHETICMETH) method setInfo tpeCons(method) - clazz.info.decls.enter(method).asInstanceOf[TermSymbol] + clazz.info.decls.enter(method) } def makeNoArgConstructor(res: Type) = (sym: Symbol) => MethodType(Nil, res) def makeTypeConstructor(args: List[Type], res: Type) = (sym: Symbol) => MethodType(sym newSyntheticValueParams args, res) + def makeEqualityMethod(name: Name) = + syntheticMethod(name, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe)) import CODE._ @@ -118,9 +127,7 @@ trait SyntheticMethods extends ast.TreeDSL { def forwardingMethod(name: Name, targetName: Name): Tree = { val target = getMember(ScalaRunTimeModule, targetName) val paramtypes = target.tpe.paramTypes drop 1 - val method = syntheticMethod( - name, 0, makeTypeConstructor(paramtypes, target.tpe.resultType) - ) + val method = syntheticMethod(name, 0, makeTypeConstructor(paramtypes, target.tpe.resultType)) typer typed { DEF(method) === { @@ -132,16 +139,12 @@ trait SyntheticMethods extends ast.TreeDSL { def hashCodeTarget: Name = nme.hashCode_ // if (settings.Yjenkins.value) "hashCodeJenkins" else nme.hashCode_ - def equalsSym = syntheticMethod( - nme.equals_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe) - ) - /** The equality method for case modules: * def equals(that: Any) = this eq that */ def equalsModuleMethod: Tree = localTyper typed { - val method = equalsSym - val that = method ARG 0 + val method = makeEqualityMethod(nme.equals_) + val that = method ARG 0 localTyper typed { DEF(method) === { @@ -155,7 +158,7 @@ trait SyntheticMethods extends ast.TreeDSL { * so as not to interfere. */ def canEqualMethod: Tree = { - val method = syntheticMethod(nme.canEqual_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe)) + val method = makeEqualityMethod(nme.canEqual_) val that = method ARG 0 typer typed (DEF(method) === (that IS_OBJ clazz.tpe)) @@ -173,8 +176,8 @@ trait SyntheticMethods extends ast.TreeDSL { * }) */ def equalsClassMethod: Tree = { - val method = equalsSym - val that = method ARG 0 + val method = makeEqualityMethod(nme.equals_) + val that = method ARG 0 val constrParamTypes = clazz.primaryConstructor.tpe.paramTypes // returns (Apply, Bind) @@ -184,9 +187,9 @@ trait SyntheticMethods extends ast.TreeDSL { val binding = if (isRepeated) Star(WILD()) else WILD() val eqMethod: Tree = if (isRepeated) gen.mkRuntimeCall(nme.sameElements, List(Ident(varName), Ident(acc))) - else (varName DOT nme.EQ)(Ident(acc)) + else (Ident(varName) DOT nme.EQ)(Ident(acc)) - (eqMethod, varName BIND binding) + (eqMethod, Bind(varName, binding)) } // Creates list of parameters and a guard for each @@ -201,7 +204,7 @@ trait SyntheticMethods extends ast.TreeDSL { } // Pattern is classname applied to parameters, and guards are all logical and-ed - val (guard, pat) = (AND(guards: _*), clazz.name.toTermName APPLY params) + val (guard, pat) = (AND(guards: _*), Ident(clazz.name.toTermName) APPLY params) localTyper typed { DEF(method) === { @@ -213,15 +216,6 @@ trait SyntheticMethods extends ast.TreeDSL { } } - def hasSerializableAnnotation(clazz: Symbol) = - clazz hasAnnotation SerializableAttr - - def readResolveMethod: Tree = { - // !!! the synthetic method "readResolve" should be private, but then it is renamed !!! - val method = newSyntheticMethod(nme.readResolve, PROTECTED, makeNoArgConstructor(ObjectClass.tpe)) - typer typed (DEF(method) === REF(clazz.sourceModule)) - } - def newAccessorMethod(tree: Tree): Tree = tree match { case DefDef(_, _, _, _, _, rhs) => var newAcc = tree.symbol.cloneSymbol @@ -234,20 +228,21 @@ trait SyntheticMethods extends ast.TreeDSL { result } + // A buffer collecting additional methods for the template body val ts = new ListBuffer[Tree] if (!phase.erasedTypes) try { - if (clazz hasFlag Flags.CASE) { - val isTop = !(clazz.ancestors exists (_ hasFlag Flags.CASE)) - // case classes are implicitly declared serializable - clazz addAnnotation AnnotationInfo(SerializableAttr.tpe, Nil, Nil) + if (clazz.isCase) { + val isTop = clazz.ancestors forall (x => !x.isCase) + // case classes are automatically marked serializable + clazz.setSerializable() if (isTop) { // If this case class has fields with less than public visibility, their getter at this // point also has those permissions. In that case we create a new, public accessor method // with a new name and remove the CASEACCESSOR flag from the existing getter. This complicates // the retrieval of the case field accessors (see def caseFieldAccessors in Symbols.) - def needsService(s: Symbol) = s.isMethod && (s hasFlag CASEACCESSOR) && !s.isPublic + def needsService(s: Symbol) = s.isMethod && s.isCaseAccessor && !s.isPublic for (stat <- templ.body ; if stat.isDef && needsService(stat.symbol)) { ts += newAccessorMethod(stat) stat.symbol resetFlag CASEACCESSOR @@ -291,10 +286,15 @@ trait SyntheticMethods extends ast.TreeDSL { } if (clazz.isModuleClass) { - if (!hasSerializableAnnotation(clazz)) { + if (!clazz.isSerializable) { val comp = companionClassOf(clazz, context) - if (comp.hasFlag(Flags.CASE) || hasSerializableAnnotation(comp)) - clazz addAnnotation AnnotationInfo(SerializableAttr.tpe, Nil, Nil) + if (comp.isCase || comp.isSerializable) + clazz.setSerializable() + } + + def hasReadResolve = { + val sym = clazz.info member nme.readResolve // any member, including private + sym.isTerm && !sym.isDeferred } /** If you serialize a singleton and then deserialize it twice, @@ -302,16 +302,20 @@ trait SyntheticMethods extends ast.TreeDSL { * the readResolve() method (see http://www.javaworld.com/javaworld/ * jw-04-2003/jw-0425-designpatterns_p.html) */ - if (hasSerializableAnnotation(clazz) && !hasImplementation(nme.readResolve)) - ts += readResolveMethod + if (clazz.isSerializable && !hasReadResolve) { + // PP: To this day I really can't figure out what this next comment is getting at: + // the !!! normally means there is something broken, but if so, what is it? + // + // !!! the synthetic method "readResolve" should be private, but then it is renamed !!! + val method = newSyntheticMethod(nme.readResolve, PROTECTED, makeNoArgConstructor(ObjectClass.tpe)) + ts += typer typed (DEF(method) === REF(clazz.sourceModule)) + } } } catch { case ex: TypeError => if (!reporter.hasErrors) throw ex } - val synthetics = ts.toList - treeCopy.Template( - templ, templ.parents, templ.self, if (synthetics.isEmpty) templ.body else templ.body ::: synthetics) + treeCopy.Template(templ, templ.parents, templ.self, templ.body ++ ts.toList) } } |