From d0c5ee4c031a126cee4c552a34cf732716568076 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Sat, 28 Jan 2012 15:37:07 -0800 Subject: More method synthesis unification. --- .../scala/reflect/internal/Definitions.scala | 37 +++-- src/compiler/scala/reflect/internal/Symbols.scala | 12 ++ .../scala/tools/nsc/transform/Erasure.scala | 2 +- src/compiler/scala/tools/nsc/transform/Mixin.scala | 2 +- .../tools/nsc/typechecker/MethodSynthesis.scala | 133 ++++++++++++++++ .../tools/nsc/typechecker/SyntheticMethods.scala | 167 ++------------------- 6 files changed, 182 insertions(+), 171 deletions(-) (limited to 'src/compiler') diff --git a/src/compiler/scala/reflect/internal/Definitions.scala b/src/compiler/scala/reflect/internal/Definitions.scala index 19726b694d..8114be20d5 100644 --- a/src/compiler/scala/reflect/internal/Definitions.scala +++ b/src/compiler/scala/reflect/internal/Definitions.scala @@ -421,24 +421,27 @@ trait Definitions extends reflect.api.StandardDefinitions { * information into the toString method. */ def manifestToType(m: OptManifest[_]): Type = m match { - case x: AnyValManifest[_] => - getClassIfDefined("scala." + x).tpe case m: ClassManifest[_] => - val name = m.erasure.getName - if (name endsWith nme.MODULE_SUFFIX_STRING) - getModuleIfDefined(name stripSuffix nme.MODULE_SUFFIX_STRING).tpe - else { - val sym = getClassIfDefined(name) - val args = m.typeArguments - - if (sym eq NoSymbol) NoType - else if (args.isEmpty) sym.tpe - else appliedType(sym.typeConstructor, args map manifestToType) - } + val sym = manifestToSymbol(m) + val args = m.typeArguments + + if ((sym eq NoSymbol) || args.isEmpty) sym.tpe + else appliedType(sym.typeConstructor, args map manifestToType) case _ => NoType } + def manifestToSymbol(m: ClassManifest[_]): Symbol = m match { + case x: scala.reflect.AnyValManifest[_] => + getMember(ScalaPackageClass, newTypeName("" + x)) + case _ => + val name = m.erasure.getName + if (name endsWith nme.MODULE_SUFFIX_STRING) + getModuleIfDefined(name stripSuffix nme.MODULE_SUFFIX_STRING) + else + getClassIfDefined(name) + } + // The given symbol represents either String.+ or StringAdd.+ def isStringAddition(sym: Symbol) = sym == String_+ || sym == StringAdd_+ def isArrowAssoc(sym: Symbol) = ArrowAssocClass.tpe.decls.toList contains sym @@ -586,6 +589,14 @@ trait Definitions extends reflect.api.StandardDefinitions { case _ => NoType } + /** To avoid unchecked warnings on polymorphic classes, translate + * a Foo[T] into a Foo[_] for use in the pattern matcher. + */ + def typeCaseType(clazz: Symbol) = clazz.tpe.normalize match { + case TypeRef(_, sym, args) if args.nonEmpty => newExistentialType(sym.typeParams, clazz.tpe) + case tp => tp + } + def seqType(arg: Type) = appliedType(SeqClass.typeConstructor, List(arg)) def arrayType(arg: Type) = appliedType(ArrayClass.typeConstructor, List(arg)) def byNameType(arg: Type) = appliedType(ByNameParamClass.typeConstructor, List(arg)) diff --git a/src/compiler/scala/reflect/internal/Symbols.scala b/src/compiler/scala/reflect/internal/Symbols.scala index e3355345f0..b3df2b0498 100644 --- a/src/compiler/scala/reflect/internal/Symbols.scala +++ b/src/compiler/scala/reflect/internal/Symbols.scala @@ -1832,6 +1832,18 @@ trait Symbols extends api.Symbols { self: SymbolTable => } } + /** Remove any access boundary and clear flags PROTECTED | PRIVATE. + */ + def makePublic = this setPrivateWithin NoSymbol resetFlag AccessFlags + + /** The first parameter to the first argument list of this method, + * or NoSymbol if inapplicable. + */ + def firstParam = info.params match { + case p :: _ => p + case _ => NoSymbol + } + /** change name by appending $$ * Do the same for any accessed symbols or setters/getters */ diff --git a/src/compiler/scala/tools/nsc/transform/Erasure.scala b/src/compiler/scala/tools/nsc/transform/Erasure.scala index b342b95742..fe479a5375 100644 --- a/src/compiler/scala/tools/nsc/transform/Erasure.scala +++ b/src/compiler/scala/tools/nsc/transform/Erasure.scala @@ -797,7 +797,7 @@ abstract class Erasure extends AddInterfaces // && (bridge.paramss.nonEmpty && bridge.paramss.head.nonEmpty && bridge.paramss.head.tail.isEmpty) // does the first argument list has exactly one argument -- for user-defined unapplies we can't be sure && !(atPhase(phase.next)(member.tpe <:< other.tpe))) { // no static guarantees (TODO: is the subtype test ever true?) import CODE._ - val typeTest = gen.mkIsInstanceOf(REF(bridge.paramss.head.head), member.tpe.params.head.tpe, any = true, wrapInApply = true) // any = true since we're before erasure (?), wrapInapply is true since we're after uncurry + val typeTest = gen.mkIsInstanceOf(REF(bridge.firstParam), member.tpe.params.head.tpe, any = true, wrapInApply = true) // any = true since we're before erasure (?), wrapInapply is true since we're after uncurry // println("unapp type test: "+ typeTest) IF (typeTest) THEN bridgingCall ELSE REF(NoneModule) } else bridgingCall diff --git a/src/compiler/scala/tools/nsc/transform/Mixin.scala b/src/compiler/scala/tools/nsc/transform/Mixin.scala index bd29336703..b3b7596f9a 100644 --- a/src/compiler/scala/tools/nsc/transform/Mixin.scala +++ b/src/compiler/scala/tools/nsc/transform/Mixin.scala @@ -1053,7 +1053,7 @@ abstract class Mixin extends InfoTransform with ast.TreeDSL { else accessedRef match { case Literal(_) => accessedRef case _ => - val init = Assign(accessedRef, Ident(sym.paramss.head.head)) + val init = Assign(accessedRef, Ident(sym.firstParam)) val getter = sym.getter(clazz) if (!needsInitFlag(getter)) init diff --git a/src/compiler/scala/tools/nsc/typechecker/MethodSynthesis.scala b/src/compiler/scala/tools/nsc/typechecker/MethodSynthesis.scala index c6ca9870c3..0c32ff32c0 100644 --- a/src/compiler/scala/tools/nsc/typechecker/MethodSynthesis.scala +++ b/src/compiler/scala/tools/nsc/typechecker/MethodSynthesis.scala @@ -17,6 +17,139 @@ trait MethodSynthesis { import global._ import definitions._ + import CODE._ + + object synthesisUtil { + type M[T] = Manifest[T] + type CM[T] = ClassManifest[T] + + def ValOrDefDef(sym: Symbol, body: Tree) = + if (sym.isLazy) ValDef(sym, body) + else DefDef(sym, body) + + def applyTypeInternal(manifests: List[M[_]]): Type = { + val symbols = manifests map manifestToSymbol + val container :: args = symbols + val tparams = container.typeConstructor.typeParams + + // Conservative at present - if manifests were more usable this could do a lot more. + require(symbols forall (_ ne NoSymbol), "Must find all manifests: " + symbols) + require(container.owner.isPackageClass, "Container must be a top-level class in a package: " + container) + require(tparams.size == args.size, "Arguments must match type constructor arity: " + tparams + ", " + args) + + typeRef(container.typeConstructor.prefix, container, args map (_.tpe)) + } + + def companionType[T](implicit m: M[T]) = + getRequiredModule(m.erasure.getName).tpe + + // Use these like `applyType[List, Int]` or `applyType[Map, Int, String]` + def applyType[CC](implicit m1: M[CC]): Type = + applyTypeInternal(List(m1)) + + def applyType[CC[X1], X1](implicit m1: M[CC[_]], m2: M[X1]): Type = + applyTypeInternal(List(m1, m2)) + + def applyType[CC[X1, X2], X1, X2](implicit m1: M[CC[_,_]], m2: M[X1], m3: M[X2]): Type = + applyTypeInternal(List(m1, m2, m3)) + + def applyType[CC[X1, X2, X3], X1, X2, X3](implicit m1: M[CC[_,_,_]], m2: M[X1], m3: M[X2], m4: M[X3]): Type = + applyTypeInternal(List(m1, m2, m3, m4)) + + def newMethodType[F](owner: Symbol)(implicit m: Manifest[F]): Type = { + val fnSymbol = manifestToSymbol(m) + assert(fnSymbol isSubClass FunctionClass(m.typeArguments.size - 1), (owner, m)) + val symbols = m.typeArguments map (m => manifestToSymbol(m)) + val formals = symbols.init map (_.typeConstructor) + val params = owner newSyntheticValueParams formals + + MethodType(params, symbols.last.typeConstructor) + } + } + import synthesisUtil._ + + class ClassMethodSynthesis(val clazz: Symbol, localTyper: Typer) { + private def isOverride(name: TermName) = + clazzMember(name).alternatives exists (sym => !sym.isDeferred && (sym.owner != clazz)) + + def newMethodFlags(name: TermName) = { + val overrideFlag = if (isOverride(name)) OVERRIDE else 0L + overrideFlag | SYNTHETIC + } + def newMethodFlags(method: Symbol) = { + val overrideFlag = if (isOverride(method.name)) OVERRIDE else 0L + (method.flags | overrideFlag | SYNTHETIC) & ~DEFERRED + } + + private def finishMethod(method: Symbol, f: Symbol => Tree): Tree = + logResult("finishMethod")(localTyper typed ValOrDefDef(method, f(method))) + + private def createInternal(name: Name, f: Symbol => Tree, info: Type): Tree = { + val m = clazz.newMethod(name.toTermName, clazz.pos.focus, newMethodFlags(name)) + finishMethod(m setInfoAndEnter info, f) + } + private def createInternal(name: Name, f: Symbol => Tree, infoFn: Symbol => Type): Tree = { + val m = clazz.newMethod(name.toTermName, clazz.pos.focus, newMethodFlags(name)) + finishMethod(m setInfoAndEnter infoFn(m), f) + } + private def cloneInternal(original: Symbol, f: Symbol => Tree, name: Name): Tree = { + val m = original.cloneSymbol(clazz, newMethodFlags(original)) setPos clazz.pos.focus + m.name = name + finishMethod(clazz.info.decls enter m, f) + } + + private def cloneInternal(original: Symbol, f: Symbol => Tree): Tree = + cloneInternal(original, f, original.name) + + def clazzMember(name: Name) = clazz.info nonPrivateMember name + def typeInClazz(sym: Symbol) = clazz.thisType memberType sym + + /** Function argument takes the newly created method symbol of + * the same type as `name` in clazz, and returns the tree to be + * added to the template. + */ + def overrideMethod(name: Name)(f: Symbol => Tree): Tree = + overrideMethod(clazzMember(name))(f) + + def overrideMethod(original: Symbol)(f: Symbol => Tree): Tree = + cloneInternal(original, sym => f(sym setFlag OVERRIDE)) + + def deriveMethod(original: Symbol, nameFn: Name => Name)(f: Symbol => Tree): Tree = + cloneInternal(original, f, nameFn(original.name)) + + def createMethod(name: Name, paramTypes: List[Type], returnType: Type)(f: Symbol => Tree): Tree = + createInternal(name, f, (m: Symbol) => MethodType(m newSyntheticValueParams paramTypes, returnType)) + + def createMethod(name: Name, returnType: Type)(f: Symbol => Tree): Tree = + createInternal(name, f, NullaryMethodType(returnType)) + + def createMethod(original: Symbol)(f: Symbol => Tree): Tree = + createInternal(original.name, f, original.info) + + def forwardMethod(original: Symbol, newMethod: Symbol)(transformArgs: List[Tree] => List[Tree]): Tree = + createMethod(original)(m => gen.mkMethodCall(newMethod, transformArgs(m.paramss.head map Ident))) + + def createSwitchMethod(name: Name, range: Seq[Int], returnType: Type)(f: Int => Tree) = { + createMethod(name, List(IntClass.tpe), returnType) { m => + val arg0 = Ident(m.firstParam) + val default = DEFAULT ==> THROW(IndexOutOfBoundsExceptionClass, arg0) + val cases = range.map(num => CASE(LIT(num)) ==> f(num)).toList :+ default + + Match(arg0, cases) + } + } + + // def foo() = constant + def constantMethod(name: Name, value: Any): Tree = { + val constant = Constant(value) + createMethod(name, Nil, constant.tpe)(_ => Literal(constant)) + } + // def foo = constant + def constantNullary(name: Name, value: Any): Tree = { + val constant = Constant(value) + createMethod(name, constant.tpe)(_ => Literal(constant)) + } + } /** There are two key methods in here. * diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index 4e986dc5aa..4ea21b1c44 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -36,158 +36,13 @@ trait SyntheticMethods extends ast.TreeDSL { import definitions._ import CODE._ - private object util { - private type CM[T] = ClassManifest[T] - - def ValOrDefDef(sym: Symbol, body: Tree) = - if (sym.isLazy) ValDef(sym, body) - else DefDef(sym, body) - - /** To avoid unchecked warnings on polymorphic classes. - */ - def clazzTypeToTest(clazz: Symbol) = clazz.tpe.normalize match { - case TypeRef(_, sym, args) if args.nonEmpty => newExistentialType(sym.typeParams, clazz.tpe) - case tp => tp - } - - def makeMethodPublic(method: Symbol): Symbol = ( - method setPrivateWithin NoSymbol resetFlag AccessFlags - ) - - def methodArg(method: Symbol, idx: Int): Tree = Ident(method.paramss.head(idx)) - - private def applyTypeInternal(manifests: List[CM[_]]): Type = { - val symbols = manifests map manifestToSymbol - val container :: args = symbols - val tparams = container.typeConstructor.typeParams - - // Overly conservative at present - if manifests were more usable - // this could do a lot more. - require(symbols forall (_ ne NoSymbol), "Must find all manifests: " + symbols) - require(container.owner.isPackageClass, "Container must be a top-level class in a package: " + container) - require(tparams.size == args.size, "Arguments must match type constructor arity: " + tparams + ", " + args) - require(args forall (_.typeConstructor.typeParams.isEmpty), "Arguments must be unparameterized: " + args) - - typeRef(container.typeConstructor.prefix, container, args map (_.tpe)) - } - - def manifestToSymbol(m: CM[_]): Symbol = m match { - case x: scala.reflect.AnyValManifest[_] => getMember(ScalaPackageClass, newTermName("" + x)) - case _ => getClassIfDefined(m.erasure.getName) - } - def companionType[T](implicit m: CM[T]) = - getRequiredModule(m.erasure.getName).tpe - - // Use these like `applyType[List, Int]` or `applyType[Map, Int, String]` - def applyType[M](implicit m1: CM[M]): Type = - applyTypeInternal(List(m1)) - - def applyType[M[X1], X1](implicit m1: CM[M[_]], m2: CM[X1]): Type = - applyTypeInternal(List(m1, m2)) - - def applyType[M[X1, X2], X1, X2](implicit m1: CM[M[_,_]], m2: CM[X1], m3: CM[X2]): Type = - applyTypeInternal(List(m1, m2, m3)) - - def applyType[M[X1, X2, X3], X1, X2, X3](implicit m1: CM[M[_,_,_]], m2: CM[X1], m3: CM[X2], m4: CM[X3]): Type = - applyTypeInternal(List(m1, m2, m3, m4)) - } - import util._ - - class MethodSynthesis(val clazz: Symbol, localTyper: Typer) { - private def isOverride(method: Symbol) = - clazzMember(method.name).alternatives exists (sym => (sym != method) && !sym.isDeferred) - - private def setMethodFlags(method: Symbol): Symbol = { - val overrideFlag = if (isOverride(method)) OVERRIDE else 0L - - method setFlag (overrideFlag | SYNTHETIC) resetFlag DEFERRED - } - - private def finishMethod(method: Symbol, f: Symbol => Tree): Tree = { - setMethodFlags(method) - clazz.info.decls enter method - logResult("finishMethod")(localTyper typed ValOrDefDef(method, f(method))) - } - - private def createInternal(name: Name, f: Symbol => Tree, info: Type): Tree = { - val m = clazz.newMethod(name.toTermName, clazz.pos.focus) - m setInfo info - finishMethod(m, f) - } - private def createInternal(name: Name, f: Symbol => Tree, infoFn: Symbol => Type): Tree = { - val m = clazz.newMethod(name.toTermName, clazz.pos.focus) - m setInfo infoFn(m) - finishMethod(m, f) - } - private def cloneInternal(original: Symbol, f: Symbol => Tree, name: Name): Tree = { - val m = original.cloneSymbol(clazz) setPos clazz.pos.focus - m.name = name - finishMethod(m, f) - } - - private def cloneInternal(original: Symbol, f: Symbol => Tree): Tree = - cloneInternal(original, f, original.name) - - def clazzMember(name: Name) = clazz.info nonPrivateMember name match { - case NoSymbol => log("In " + clazz + ", " + name + " not found: " + clazz.info) ; NoSymbol - case sym => sym - } - def typeInClazz(sym: Symbol) = clazz.thisType memberType sym - - /** Function argument takes the newly created method symbol of - * the same type as `name` in clazz, and returns the tree to be - * added to the template. - */ - def overrideMethod(name: Name)(f: Symbol => Tree): Tree = - overrideMethod(clazzMember(name))(f) - - def overrideMethod(original: Symbol)(f: Symbol => Tree): Tree = - cloneInternal(original, sym => f(sym setFlag OVERRIDE)) - - def deriveMethod(original: Symbol, nameFn: Name => Name)(f: Symbol => Tree): Tree = - cloneInternal(original, f, nameFn(original.name)) - - def createMethod(name: Name, paramTypes: List[Type], returnType: Type)(f: Symbol => Tree): Tree = - createInternal(name, f, (m: Symbol) => MethodType(m newSyntheticValueParams paramTypes, returnType)) - - def createMethod(name: Name, returnType: Type)(f: Symbol => Tree): Tree = - createInternal(name, f, NullaryMethodType(returnType)) - - def createMethod(original: Symbol)(f: Symbol => Tree): Tree = - createInternal(original.name, f, original.info) - - def forwardMethod(original: Symbol, newMethod: Symbol)(transformArgs: List[Tree] => List[Tree]): Tree = - createMethod(original)(m => gen.mkMethodCall(newMethod, transformArgs(m.paramss.head map Ident))) - - def createSwitchMethod(name: Name, range: Seq[Int], returnType: Type)(f: Int => Tree) = { - createMethod(name, List(IntClass.tpe), returnType) { m => - val arg0 = methodArg(m, 0) - val default = DEFAULT ==> THROW(IndexOutOfBoundsExceptionClass, arg0) - val cases = range.map(num => CASE(LIT(num)) ==> f(num)).toList :+ default - - Match(arg0, cases) - } - } - - // def foo() = constant - def constantMethod(name: Name, value: Any): Tree = { - val constant = Constant(value) - createMethod(name, Nil, constant.tpe)(_ => Literal(constant)) - } - // def foo = constant - def constantNullary(name: Name, value: Any): Tree = { - val constant = Constant(value) - createMethod(name, constant.tpe)(_ => Literal(constant)) - } - } - /** Add the synthetic methods to case classes. */ def addSyntheticMethods(templ: Template, clazz0: Symbol, context: Context): Template = { if (phase.erasedTypes) return templ - val synthesizer = new MethodSynthesis( + val synthesizer = new ClassMethodSynthesis( clazz0, newTyper( if (reporter.hasErrors) context makeSilent false else context ) ) @@ -212,11 +67,12 @@ trait SyntheticMethods extends ast.TreeDSL { // like Manifests and Arrays which are not robust and infer things // which they shouldn't. val accessorLub = ( - if (opt.experimental) + if (opt.experimental) { global.weakLub(accessors map (_.tpe.finalResultType))._1 match { case RefinedType(parents, decls) if !decls.isEmpty => intersectionType(parents) case tp => tp } + } else AnyClass.tpe ) @@ -258,11 +114,10 @@ trait SyntheticMethods extends ast.TreeDSL { /** The canEqual method for case classes. * def canEqual(that: Any) = that.isInstanceOf[This] */ - def canEqualMethod: Tree = { - createMethod(nme.canEqual_, List(AnyClass.tpe), BooleanClass.tpe)(m => - methodArg(m, 0) IS_OBJ clazzTypeToTest(clazz) - ) - } + def canEqualMethod: Tree = ( + createMethod(nme.canEqual_, List(AnyClass.tpe), BooleanClass.tpe)(m => + Ident(m.firstParam) IS_OBJ typeCaseType(clazz)) + ) /** The equality method for case classes. * 0 args: @@ -276,8 +131,8 @@ trait SyntheticMethods extends ast.TreeDSL { * } */ def equalsClassMethod: Tree = createMethod(nme.equals_, List(AnyClass.tpe), BooleanClass.tpe) { m => - val arg0 = methodArg(m, 0) - val thatTest = gen.mkIsInstanceOf(arg0, clazzTypeToTest(clazz), true, false) + val arg0 = Ident(m.firstParam) + val thatTest = gen.mkIsInstanceOf(arg0, typeCaseType(clazz), true, false) val thatCast = gen.mkCast(arg0, clazz.tpe) def argsBody: Tree = { @@ -331,7 +186,7 @@ trait SyntheticMethods extends ast.TreeDSL { Object_hashCode -> (() => constantMethod(nme.hashCode_, clazz.name.decode.hashCode)), Object_toString -> (() => constantMethod(nme.toString_, clazz.name.decode)) // Not needed, as reference equality is the default. - // Object_equals -> (() => createMethod(Object_equals)(m => This(clazz) ANY_EQ methodArg(m, 0))) + // Object_equals -> (() => createMethod(Object_equals)(m => This(clazz) ANY_EQ Ident(m.firstParam))) ) /** If you serialize a singleton and then deserialize it twice, @@ -381,7 +236,7 @@ trait SyntheticMethods extends ast.TreeDSL { for (ddef @ DefDef(_, _, _, _, _, _) <- templ.body ; if isRewrite(ddef.symbol)) { val original = ddef.symbol val newAcc = deriveMethod(ddef.symbol, name => context.unit.freshTermName(name + "$")) { newAcc => - makeMethodPublic(newAcc) + newAcc.makePublic newAcc resetFlag (ACCESSOR | PARAMACCESSOR) ddef.rhs.duplicate } -- cgit v1.2.3