From d0452d00c91350dfcb829ad890181d826c9ab9c8 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Thu, 25 Jun 2009 22:41:15 +0000 Subject: Initial checkin of a code generation DSL which ... Initial checkin of a code generation DSL which I felt like I had to undertake before I could begin the pattern matcher in good conscience. I'm pretty excited about the impact this could have on the understandability of the codebase - for me anyway. // actual code in this checkin DEF(method) === { (This(clazz) EQREF that) OR (that MATCH( (CASE(pat) IF guard) ==> TRUE , DEFAULT ==> FALSE )) } --- src/compiler/scala/tools/nsc/ast/TreeDSL.scala | 101 +++++++ src/compiler/scala/tools/nsc/symtab/Types.scala | 4 + .../tools/nsc/typechecker/SyntheticMethods.scala | 297 +++++++++++---------- 3 files changed, 256 insertions(+), 146 deletions(-) create mode 100644 src/compiler/scala/tools/nsc/ast/TreeDSL.scala diff --git a/src/compiler/scala/tools/nsc/ast/TreeDSL.scala b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala new file mode 100644 index 0000000000..98a6919f27 --- /dev/null +++ b/src/compiler/scala/tools/nsc/ast/TreeDSL.scala @@ -0,0 +1,101 @@ +/* NSC -- new Scala compiler + * Copyright 2005-2009 LAMP/EPFL + * + * @author Paul Phillips + */ + +package scala.tools.nsc.ast + +/** A DSL for generating scala code. The goal is that the + * code generating code should look a lot like the code it + * generates. + */ + +trait TreeDSL { + val global: Global + + import global._ + import definitions._ + + object CODE { + def LIT(x: Any) = Literal(Constant(x)) + def TRUE = LIT(true) + def FALSE = LIT(false) + def WILD = Ident(nme.WILDCARD) + + class TreeMethods(target: Tree) { + private def binop(lhs: Tree, op: Name, rhs: Tree) = Apply(Select(lhs, op), List(rhs)) + private def toAnyRef(x: Tree) = x setType AnyRefClass.tpe + + case class ExpectApply(target: Tree) { + def apply(args: Tree*) = Apply(target, args.toList) + } + + /** logical/comparison ops **/ + def OR(other: Tree) = + if (target == EmptyTree) other + else if (other == EmptyTree) target + else gen.mkOr(target, other) + + def AND(other: Tree) = + if (target == EmptyTree) other + else if (other == EmptyTree) target + else gen.mkAnd(target, other) + + def EQREF(other: Tree) = binop(target, nme.eq, toAnyRef(other)) + def EQEQ(other: Tree) = binop(target, nme.EQ, other) + + /** Apply, Select, Match **/ + def APPLY(params: List[Tree]) = Apply(target, params) + def MATCH(cases: CaseDef*) = Match(target, cases.toList) + def DOT(member: Name) = ExpectApply(Select(target, member)) + def DOT(sym: Symbol) = ExpectApply(Select(target, sym)) + + /** Casting */ + def AS(tpe: Type) = TypeApply(Select(target, Any_asInstanceOf), List(TypeTree(tpe))) + def TOSTRING() = Select(target, nme.toString_) + } + + class CaseStart(pat: Tree, guard: Tree) { + def IF(g: Tree): CaseStart = new CaseStart(pat, g) + def ==>(body: Tree): CaseDef = CaseDef(pat, guard, body) // DSL for => + } + class DefStart(sym: Symbol) { + def ===(body: Tree) = DefDef(sym, body) // DSL for = + } + + def CASE(pat: Tree): CaseStart = new CaseStart(pat, EmptyTree) + def DEFAULT: CaseStart = new CaseStart(WILD, EmptyTree) + + class NameMethods(target: Name) { + def BIND(body: Tree) = Bind(target, body) + } + + class SymbolMethods(target: Symbol) { + def BIND(body: Tree) = Bind(target, body) + + // name of nth indexed argument to a method (first parameter list), defaults to 1st + def ARG(idx: Int = 0) = Ident(target.paramss.head(idx)) + def ARGS = target.paramss.head + def ARGNAMES = ARGS map Ident + } + + /** Top level accessible. */ + def THROW(sym: Symbol, msg: Tree) = Throw(New(TypeTree(sym.tpe), List(List(msg.TOSTRING)))) + def DEF(sym: Symbol) = new DefStart(sym) + def AND(guards: Tree*) = + if (guards.isEmpty) EmptyTree + else guards reduceLeft gen.mkAnd + + /** Implicits - some of these should probably disappear **/ + implicit def mkTreeMethods(target: Tree): TreeMethods = new TreeMethods(target) + implicit def mkTreeMethodsFromSymbol(target: Symbol): TreeMethods = new TreeMethods(Ident(target)) + implicit def mkTreeMethodsFromName(target: Name): TreeMethods = new TreeMethods(Ident(target)) + implicit def mkTreeMethodsFromString(target: String): TreeMethods = new TreeMethods(Ident(target)) + + implicit def mkNameMethodsFromName(target: Name): NameMethods = new NameMethods(target) + implicit def mkNameMethodsFromString(target: String): NameMethods = new NameMethods(target) + + implicit def mkSymbolMethodsFromSymbol(target: Symbol): SymbolMethods = new SymbolMethods(target) + } +} \ No newline at end of file diff --git a/src/compiler/scala/tools/nsc/symtab/Types.scala b/src/compiler/scala/tools/nsc/symtab/Types.scala index 1c3f9692e2..1b3b93eb2c 100644 --- a/src/compiler/scala/tools/nsc/symtab/Types.scala +++ b/src/compiler/scala/tools/nsc/symtab/Types.scala @@ -637,6 +637,10 @@ trait Types { */ def isComplete: Boolean = true + /** Is this type a varargs parameter? + */ + def isVarargs: Boolean = typeSymbol == RepeatedParamClass + /** If this is a lazy type, assign a new type to `sym'. */ def complete(sym: Symbol) {} diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index d3170cfa88..b57f4b48f9 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -6,6 +6,7 @@ package scala.tools.nsc.typechecker +import symtab.Flags import symtab.Flags._ import scala.collection.mutable.ListBuffer @@ -26,10 +27,12 @@ import scala.collection.mutable.ListBuffer * * */ -trait SyntheticMethods { self: Analyzer => +trait SyntheticMethods extends ast.TreeDSL { + self: Analyzer => + import global._ // the global environment import definitions._ // standard classes and methods - //import global.typer.{typed} // methods to type trees + // @S: type hack: by default, we are used from global.analyzer context // so this cast won't fail. If we aren't in global.analyzer, we have // to override this method anyways. @@ -41,18 +44,18 @@ trait SyntheticMethods { self: Analyzer => * @param unit ... * @return ... */ - def addSyntheticMethods(templ: Template, clazz: Symbol, context: Context): Template = try { + def addSyntheticMethods(templ: Template, clazz: Symbol, context: Context): Template = { - val localContext = if (reporter.hasErrors) context.makeSilent(false) else context - val localTyper = newTyper(localContext) + val localContext = if (reporter.hasErrors) context makeSilent false else context + val localTyper = newTyper(localContext) def hasImplementation(name: Name): Boolean = { - val sym = clazz.info.member(name) // member and not nonPrivateMember: bug #1385 + val sym = clazz.info member name // member and not nonPrivateMember: bug #1385 sym.isTerm && !(sym hasFlag DEFERRED) } def hasOverridingImplementation(meth: Symbol): Boolean = { - val sym = clazz.info.nonPrivateMember(meth.name) + val sym = clazz.info nonPrivateMember meth.name sym.alternatives exists { sym => sym != meth && !(sym hasFlag DEFERRED) && !(sym hasFlag (SYNTHETIC | SYNTHETICMETH)) && (clazz.thisType.memberType(sym) matches clazz.thisType.memberType(meth)) @@ -63,79 +66,79 @@ trait SyntheticMethods { self: Analyzer => newSyntheticMethod(name, flags | OVERRIDE, tpeCons) def newSyntheticMethod(name: Name, flags: Int, tpeCons: Symbol => Type) = { - var method = clazz.newMethod(clazz.pos, name) - .setFlag(flags | SYNTHETICMETH) - method.setInfo(tpeCons(method)) - method = clazz.info.decls.enter(method).asInstanceOf[TermSymbol] - method + val method = clazz.newMethod(clazz.pos, name) setFlag (flags | SYNTHETICMETH) + method setInfo tpeCons(method) + clazz.info.decls.enter(method).asInstanceOf[TermSymbol] } - /* - def productSelectorMethod(n: int, accessor: Symbol): Tree = { - val method = syntheticMethod(newTermName("_"+n), FINAL, accessor.tpe) - typed(DefDef(method, vparamss => gen.mkAttributedRef(accessor))) - } - */ - def productPrefixMethod: Tree = { - val method = syntheticMethod(nme.productPrefix, 0, sym => PolyType(List(), StringClass.tpe)) - typer.typed(DefDef(method, Literal(Constant(clazz.name.decode)))) + def makeNoArgConstructor(res: Type) = + (sym: Symbol) => MethodType(Nil, res) + def makeTypeConstructor(args: List[Type], res: Type) = + (sym: Symbol) => MethodType(sym newSyntheticValueParams args, res) + + import CODE._ + + def productPrefixMethod: Tree = typer.typed { + val method = syntheticMethod(nme.productPrefix, 0, sym => PolyType(Nil, StringClass.tpe)) + DEF(method) === LIT(clazz.name.decode) } - def productArityMethod(nargs:Int ): Tree = { - val method = syntheticMethod(nme.productArity, 0, sym => PolyType(List(), IntClass.tpe)) - typer.typed(DefDef(method, Literal(Constant(nargs)))) + def productArityMethod(nargs: Int): Tree = { + val method = syntheticMethod(nme.productArity, 0, sym => PolyType(Nil, IntClass.tpe)) + typer typed { DEF(method) === LIT(nargs) } } def productElementMethod(accs: List[Symbol]): Tree = { - //val retTpe = lub(accs map (_.tpe.resultType)) - val method = syntheticMethod(nme.productElement, 0, - sym => MethodType(sym.newSyntheticValueParams(List(IntClass.tpe)), AnyClass.tpe/*retTpe*/)) - typer.typed(DefDef(method, Match(Ident(method.paramss.head.head), { - (for ((sym,i) <- accs.zipWithIndex) yield { - CaseDef(Literal(Constant(i)),EmptyTree, Ident(sym)) - }):::List(CaseDef(Ident(nme.WILDCARD), EmptyTree, - Throw(New(TypeTree(IndexOutOfBoundsExceptionClass.tpe), List(List( - Select(Ident(method.paramss.head.head), nme.toString_) - )))))) - }))) + val symToTpe = makeTypeConstructor(List(IntClass.tpe), AnyClass.tpe) + val method = syntheticMethod(nme.productElement, 0, symToTpe) + val arg = method ARG 0 + val default = List( DEFAULT ==> THROW(IndexOutOfBoundsExceptionClass, arg) ) + val cases = + for ((sym, i) <- accs.zipWithIndex) yield + CASE(LIT(i)) ==> Ident(sym) + + typer typed { + DEF(method) === { + arg MATCH { cases ::: default : _* } + } + } } def moduleToStringMethod: Tree = { - val method = syntheticMethod(nme.toString_, FINAL, sym => MethodType(List(), StringClass.tpe)) - typer.typed(DefDef(method, Literal(Constant(clazz.name.decode)))) + val method = syntheticMethod(nme.toString_, FINAL, makeNoArgConstructor(StringClass.tpe)) + typer typed { DEF(method) === LIT(clazz.name.decode) } } def forwardingMethod(name: Name): Tree = { - val target = getMember(ScalaRunTimeModule, "_" + name) - val paramtypes = - if (target.tpe.paramTypes.isEmpty) List() - else target.tpe.paramTypes.tail - val method = syntheticMethod( - name, 0, sym => MethodType(sym.newSyntheticValueParams(paramtypes), target.tpe.resultType)) - typer.typed(DefDef(method, - Apply(gen.mkAttributedRef(target), This(clazz) :: (method.paramss.head map Ident)))) + val target = getMember(ScalaRunTimeModule, "_" + name) + val paramtypes = target.tpe.paramTypes drop 1 + val method = syntheticMethod( + name, 0, makeTypeConstructor(paramtypes, target.tpe.resultType) + ) + + typer.typed { + DEF(method) === { + Apply(gen.mkAttributedRef(target), This(clazz) :: (method ARGNAMES)) + } + } } - def equalsSym = - syntheticMethod(nme.equals_, 0, - sym => MethodType(sym.newSyntheticValueParams(List(AnyClass.tpe)), BooleanClass.tpe)) + def equalsSym = syntheticMethod( + nme.equals_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe) + ) /** The equality method for case modules: * def equals(that: Any) = this eq that */ - def equalsModuleMethod: Tree = { + def equalsModuleMethod: Tree = localTyper typed { val method = equalsSym - val methodDef = - DefDef(method, - Apply( - Select(This(clazz), Object_eq), - List( - TypeApply( - Select( - Ident(method.paramss.head.head), - Any_asInstanceOf), - List(TypeTree(AnyRefClass.tpe)))))) - localTyper.typed(methodDef) + val that = method ARG 0 + + localTyper typed { + DEF(method) === { + (This(clazz) DOT Object_eq)(that AS AnyRefClass.tpe) + } + } } /** The equality method for case classes. The argument is an Any, @@ -151,120 +154,122 @@ trait SyntheticMethods { self: Analyzer => */ def equalsClassMethod: Tree = { val method = equalsSym - val methodDef = - DefDef( - method, { - val that = Ident(method.paramss.head.head) - val constrParamTypes = clazz.primaryConstructor.tpe.paramTypes - val hasVarArgs = !constrParamTypes.isEmpty && constrParamTypes.last.typeSymbol == RepeatedParamClass - val (pat, guard) = { - val guards = new ListBuffer[Tree] - val params = for ((acc, cpt) <- clazz.caseFieldAccessors zip constrParamTypes) yield { - val name = context.unit.fresh.newName(clazz.pos, acc.name+"$") - val isVarArg = cpt.typeSymbol == RepeatedParamClass - guards += Apply( - Select( - Ident(name), - if (isVarArg) nme.sameElements else nme.EQ), - List(Ident(acc))) - Bind(name, - if (isVarArg) Star(Ident(nme.WILDCARD)) - else Ident(nme.WILDCARD)) - } - ( Apply(Ident(clazz.name.toTermName), params), - if (guards.isEmpty) EmptyTree - else guards reduceLeft { (g1: Tree, g2: Tree) => - Apply(Select(g1, nme.AMPAMP), List(g2)) - } - ) - } - val eq_ = Apply(Select(This(clazz), nme.eq), List(that setType AnyRefClass.tpe)) - val match_ = Match(that, List( - CaseDef(pat, guard, Literal(Constant(true))), - CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))))) + val that = method ARG 0 + val constrParamTypes = clazz.primaryConstructor.tpe.paramTypes - gen.mkOr(eq_, match_) - } - ) - localTyper.typed(methodDef) + // returns (Apply, Bind) + def makeTrees(acc: Symbol, cpt: Type): (Tree, Bind) = { + val varName = context.unit.fresh.newName(clazz.pos, acc.name + "$") + val (eqMethod, binding) = + if (cpt.isVarargs) (nme.sameElements, Star(WILD)) + else (nme.EQ , WILD ) + + ((varName DOT eqMethod)(Ident(acc)), varName BIND binding) + } + + // Creates list of parameters and a guard for each + val (guards, params) = List.map2(clazz.caseFieldAccessors, constrParamTypes)(makeTrees) unzip + + // Pattern is classname applied to parameters, and guards are all logical and-ed + val (guard, pat) = (AND(guards : _*), clazz.name.toTermName APPLY params) + + localTyper typed { + DEF(method) === { + (This(clazz) EQREF that) OR (that MATCH( + (CASE(pat) IF guard) ==> TRUE , + DEFAULT ==> FALSE + )) + } + } } - def hasSerializableAnnotation(clazz: Symbol): Boolean = - clazz.hasAnnotation(definitions.SerializableAttr) + def hasSerializableAnnotation(clazz: Symbol) = + clazz hasAnnotation SerializableAttr + def isPublic(sym: Symbol) = + !sym.hasFlag(PRIVATE | PROTECTED) && sym.privateWithin == NoSymbol def readResolveMethod: Tree = { - // !!! the synthetic method "readResolve" should be private, - // but then it is renamed !!! - val method = newSyntheticMethod(nme.readResolve, PROTECTED, - sym => MethodType(List(), ObjectClass.tpe)) - typer.typed(DefDef(method, gen.mkAttributedRef(clazz.sourceModule))) + // !!! the synthetic method "readResolve" should be private, but then it is renamed !!! + val method = newSyntheticMethod(nme.readResolve, PROTECTED, makeNoArgConstructor(ObjectClass.tpe)) + typer typed { + DEF(method) === gen.mkAttributedRef(clazz.sourceModule) + } } def newAccessorMethod(tree: Tree): Tree = tree match { case DefDef(_, _, _, _, _, rhs) => var newAcc = tree.symbol.cloneSymbol newAcc.name = context.unit.fresh.newName(tree.symbol.pos, tree.symbol.name + "$") - newAcc.setFlag(SYNTHETIC).resetFlag(ACCESSOR | PARAMACCESSOR | PRIVATE) + newAcc setFlag SYNTHETIC resetFlag (ACCESSOR | PARAMACCESSOR | PRIVATE) newAcc = newAcc.owner.info.decls enter newAcc - val result = typer.typed(DefDef(newAcc, rhs.duplicate)) + val result = typer typed { DEF(newAcc) === rhs.duplicate } log("new accessor method " + result) result } val ts = new ListBuffer[Tree] - def isPublic(sym: Symbol) = - !sym.hasFlag(PRIVATE | PROTECTED) && sym.privateWithin == NoSymbol + if (!phase.erasedTypes) try { + if (clazz hasFlag Flags.CASE) { + val isTop = !(clazz.info.baseClasses.tail exists (_ hasFlag Flags.CASE)) + // case classes are implicitly declared serializable + clazz addAnnotation AnnotationInfo(SerializableAttr.tpe, Nil, Nil) - if (!phase.erasedTypes) { - try { - if (clazz hasFlag CASE) { - val isTop = !(clazz.info.baseClasses.tail exists (_ hasFlag CASE)) - // case classes are implicitly declared serializable - clazz.addAnnotation(AnnotationInfo(SerializableAttr.tpe, List(), List())) - - if (isTop) { - for (stat <- templ.body) { - if (stat.isDef && stat.symbol.isMethod && stat.symbol.hasFlag(CASEACCESSOR) && !isPublic(stat.symbol)) { - ts += newAccessorMethod(stat) - stat.symbol.resetFlag(CASEACCESSOR) - } + if (isTop) { + for (stat <- templ.body) { + if (stat.isDef && stat.symbol.isMethod && stat.symbol.hasFlag(CASEACCESSOR) && !isPublic(stat.symbol)) { + ts += newAccessorMethod(stat) + stat.symbol resetFlag CASEACCESSOR } } - if (clazz.isModuleClass) { - if (!hasOverridingImplementation(Object_toString)) ts += moduleToStringMethod - // if there's a synthetic method in a parent case class, override its equality - // with eq (see #883) - val otherEquals = clazz.info.nonPrivateMember(Object_equals.name) - if (otherEquals.owner != clazz && (otherEquals hasFlag SYNTHETICMETH)) ts += equalsModuleMethod - } else { - if (!hasOverridingImplementation(Object_hashCode)) ts += forwardingMethod(nme.hashCode_) - if (!hasOverridingImplementation(Object_toString)) ts += forwardingMethod(nme.toString_) - if (!hasOverridingImplementation(Object_equals)) ts += equalsClassMethod - } + } - if (!hasOverridingImplementation(Product_productPrefix)) ts += productPrefixMethod + // methods for case classes only + def classMethods = Map( + Object_hashCode -> (() => forwardingMethod(nme.hashCode_)), + Object_toString -> (() => forwardingMethod(nme.toString_)), + Object_equals -> (() => equalsClassMethod) + ) + // methods for case objects only + def objectMethods = Map( + Object_toString -> (() => moduleToStringMethod) + ) + // methods for both classes and objects + def everywhereMethods = { val accessors = clazz.caseFieldAccessors - if (!hasOverridingImplementation(Product_productArity)) - ts += productArityMethod(accessors.length) - if (!hasOverridingImplementation(Product_productElement)) - ts += productElementMethod(accessors) + Map( + Product_productPrefix -> (() => productPrefixMethod), + Product_productArity -> (() => productArityMethod(accessors.length)), + Product_productElement -> (() => productElementMethod(accessors)) + ) } - if (clazz.isModuleClass && hasSerializableAnnotation(clazz)) { - // If you serialize a singleton and then deserialize it twice, - // you will have two instances of your singleton, unless you implement - // the readResolve() method (see http://www.javaworld.com/javaworld/ - // jw-04-2003/jw-0425-designpatterns_p.html) - // question: should we do this for all serializable singletons, or (as currently done) - // only for those that carry a @serializable annotation? - if (!hasImplementation(nme.readResolve)) ts += readResolveMethod + if (clazz.isModuleClass) { + // if there's a synthetic method in a parent case class, override its equality + // with eq (see #883) + val otherEquals = clazz.info.nonPrivateMember(Object_equals.name) + if (otherEquals.owner != clazz && (otherEquals hasFlag SYNTHETICMETH)) ts += equalsModuleMethod } - } catch { - case ex: TypeError => - if (!reporter.hasErrors) throw ex + + val map = (if (clazz.isModuleClass) objectMethods else classMethods) ++ everywhereMethods + for ((m, impl) <- map ; if !hasOverridingImplementation(m)) + ts += impl() + } + + if (clazz.isModuleClass && hasSerializableAnnotation(clazz)) { + // If you serialize a singleton and then deserialize it twice, + // you will have two instances of your singleton, unless you implement + // the readResolve() method (see http://www.javaworld.com/javaworld/ + // jw-04-2003/jw-0425-designpatterns_p.html) + // question: should we do this for all serializable singletons, or (as currently done) + // only for those that carry a @serializable annotation? + if (!hasImplementation(nme.readResolve)) ts += readResolveMethod } + } catch { + case ex: TypeError => + if (!reporter.hasErrors) throw ex } + val synthetics = ts.toList treeCopy.Template( templ, templ.parents, templ.self, if (synthetics.isEmpty) templ.body else templ.body ::: synthetics) -- cgit v1.2.3