From f7f5b50848c8c0efd30abe17467d1ef443760ec5 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Fri, 5 Aug 2011 17:29:28 +0000 Subject: Rewrote the case class synthetic equals method ... Rewrote the case class synthetic equals method to be more efficient and to cause fewer problems for compiler hackers who are always saying stuff like "the only place this comes up is case class equals..." No review. --- src/compiler/scala/tools/nsc/ast/TreeGen.scala | 15 +++- .../scala/tools/nsc/transform/UnCurry.scala | 13 +-- .../tools/nsc/typechecker/SyntheticMethods.scala | 92 ++++++++++------------ test/files/pos/polymorphic-case-class.flags | 1 + test/files/pos/polymorphic-case-class.scala | 2 + 5 files changed, 59 insertions(+), 64 deletions(-) create mode 100644 test/files/pos/polymorphic-case-class.flags create mode 100644 test/files/pos/polymorphic-case-class.scala diff --git a/src/compiler/scala/tools/nsc/ast/TreeGen.scala b/src/compiler/scala/tools/nsc/ast/TreeGen.scala index 10762bfa98..acca7dac7d 100644 --- a/src/compiler/scala/tools/nsc/ast/TreeGen.scala +++ b/src/compiler/scala/tools/nsc/ast/TreeGen.scala @@ -32,7 +32,20 @@ abstract class TreeGen extends reflect.internal.TreeGen { } // wrap the given expression in a SoftReference so it can be gc-ed - def mkSoftRef(expr: Tree): Tree = New(TypeTree(SoftReferenceClass.tpe), List(List(expr))) + def mkSoftRef(expr: Tree): Tree = atPos(expr.pos) { + New(SoftReferenceClass, expr) + } + // annotate the expression with @unchecked + def mkUnchecked(expr: Tree): Tree = atPos(expr.pos) { + // This can't be "Annotated(New(UncheckedClass), expr)" because annotations + // are very pick about things and it crashes the compiler with "unexpected new". + Annotated(New(scalaDot(UncheckedClass.name), List(Nil)), expr) + } + // if it's a Match, mark the selector unchecked; otherwise nothing. + def mkUncheckedMatch(tree: Tree) = tree match { + case Match(selector, cases) => atPos(tree.pos)(Match(mkUnchecked(selector), cases)) + case _ => tree + } def mkCached(cvar: Symbol, expr: Tree): Tree = { val cvarRef = mkUnattributedRef(cvar) diff --git a/src/compiler/scala/tools/nsc/transform/UnCurry.scala b/src/compiler/scala/tools/nsc/transform/UnCurry.scala index 7a0cde8630..1e7df19afe 100644 --- a/src/compiler/scala/tools/nsc/transform/UnCurry.scala +++ b/src/compiler/scala/tools/nsc/transform/UnCurry.scala @@ -263,17 +263,8 @@ abstract class UnCurry extends InfoTransform fun.vparams foreach (_.symbol.owner = applyMethod) new ChangeOwnerTraverser(fun.symbol, applyMethod) traverse fun.body - - def mkUnchecked(tree: Tree) = { - def newUnchecked(expr: Tree) = Annotated(New(gen.scalaDot(UncheckedClass.name), List(Nil)), expr) - tree match { - case Match(selector, cases) => atPos(tree.pos) { Match(newUnchecked(selector), cases) } - case _ => tree - } - } - def applyMethodDef() = { - val body = if (isPartial) mkUnchecked(fun.body) else fun.body + val body = if (isPartial) gen.mkUncheckedMatch(fun.body) else fun.body DefDef(Modifiers(FINAL), nme.apply, Nil, List(fun.vparams), TypeTree(restpe), body) setSymbol applyMethod } def isDefinedAtMethodDef() = { @@ -291,7 +282,7 @@ abstract class UnCurry extends InfoTransform substTree(CaseDef(cdef.pat.duplicate, cdef.guard.duplicate, Literal(Constant(true)))) def defaultCase = CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))) - DefDef(m, mkUnchecked( + DefDef(m, gen.mkUncheckedMatch( if (cases exists treeInfo.isDefaultCase) Literal(Constant(true)) else Match(substTree(selector.duplicate), (cases map transformCase) :+ defaultCase) )) diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index 4972c6295a..9613866986 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -44,7 +44,7 @@ trait SyntheticMethods extends ast.TreeDSL { if (reporter.hasErrors) context makeSilent false else context ) def accessorTypes = clazz.caseFieldAccessors map (_.tpe.finalResultType) - def accessorLub = global.weakLub(accessorTypes)._1 + def accessorLub = global.weakLub(accessorTypes)._1 def hasOverridingImplementation(meth: Symbol): Boolean = { val sym = clazz.info nonPrivateMember meth.name @@ -144,71 +144,59 @@ trait SyntheticMethods extends ast.TreeDSL { newTypedMethod(method)(This(clazz) ANY_EQ (method ARG 0)) } + /** To avoid unchecked warnings on polymorphic classes. + */ + def clazzTypeToTest = clazz.tpe.normalize match { + case TypeRef(pre, sym, args) if args.nonEmpty => ExistentialType(sym.typeParams, clazz.tpe) + case tp => tp + } + /** The canEqual method for case classes. * def canEqual(that: Any) = that.isInstanceOf[This] */ def canEqualMethod: Tree = { val method = makeEqualityMethod(nme.canEqual_) - newTypedMethod(method)((method ARG 0) IS_OBJ clazz.tpe) + newTypedMethod(method)((method ARG 0) IS_OBJ clazzTypeToTest) } - /** The equality method for case classes. The argument is an Any, - * but because of boxing it will always be an Object, so a check - * is neither necessary nor useful before the cast. - * - * def equals(that: Any) = - * (this eq that.asInstanceOf[AnyRef]) || - * (that match { - * case x @ this.C(this.arg_1, ..., this.arg_n) => x canEqual this - * case _ => false - * }) + /** The equality method for case classes. + * 0 args: + * def equals(that: Any) = that.isInstanceOf[this.C] && that.asInstanceOf[this.C].canEqual(this) + * 1+ args: + * def equals(that: Any) = (this eq that.asInstanceOf[AnyRef]) || { + * (that.isInstanceOf[this.C]) && { + * val x$1 = that.asInstanceOf[this.C] + * (this.arg_1 == x$1.arg_1) && (this.arg_2 == x$1.arg_2) && ... && (x$1 canEqual this) + * } + * } */ def equalsClassMethod: Tree = { - val method = makeEqualityMethod(nme.equals_) - val that = method ARG 0 - val constrParamTypes = clazz.primaryConstructor.tpe.paramTypes - - // returns (Apply, Bind) - def makeTrees(acc: Symbol, cpt: Type): (Tree, Bind) = { - val varName = context.unit.freshTermName(acc.name + "$") - val isRepeated = isRepeatedParamType(cpt) - val binding = Bind( - varName, - if (isRepeated) Star(WILD.empty) else WILD.empty - ) - val eqMethod: Tree = - if (isRepeated) gen.mkRuntimeCall(nme.sameElements, List(Ident(varName), Ident(acc))) - else (Ident(varName) DOT nme.EQ)(Ident(acc)) - - (eqMethod, binding) - } - - // Creates list of parameters and a guard for each - val (guards, params) = (clazz.caseFieldAccessors, constrParamTypes).zipped map makeTrees unzip - - // Verify with canEqual method before returning true. - def canEqualCheck() = { - val that: Tree = (method ARG 0) AS clazz.tpe - val canEqualOther: Symbol = clazz.info nonPrivateMember nme.canEqual_ + val method = makeEqualityMethod(nme.equals_) + val that = Ident(method.paramss.head.head) + val typeTest = gen.mkIsInstanceOf(that, clazzTypeToTest, true, false) + val typeCast = that AS_ATTR clazz.tpe + + def argsBody = { + val valName = context.unit.freshTermName(clazz.name + "$") + val valSym = method.newValue(method.pos, valName) setInfo clazz.tpe setFlag SYNTHETIC + val valDef = ValDef(valSym, typeCast) + val canEqualCheck = gen.mkMethodCall(valSym, nme.canEqual_, Nil, List(This(clazz))) + val pairwise = clazz.caseFieldAccessors map { accessor => + (Select(This(clazz), accessor) MEMBER_== Select(Ident(valSym), accessor)) + } - typer typed { (that DOT canEqualOther)(This(clazz)) } + ( + (This(clazz) ANY_EQ that) OR + (typeTest AND Block(valDef, AND(pairwise :+ canEqualCheck: _*))) + ) } - // Pattern is classname applied to parameters, and guards are all logical and-ed - val guard = AND(guards: _*) - // You might think this next line could be written - // val pat = Ident(clazz.companionModule) APPLY params - // However this prevents the compiler from bootstrapping, - // possibly due to a bug in how case classes are handled when - // read from classfiles as opposed to from source. - val pat = Ident(clazz.name.toTermName) APPLY params - newTypedMethod(method) { - (This(clazz) ANY_EQ that) OR (that MATCH( - (CASE(pat) IF guard) ==> canEqualCheck() , - DEFAULT ==> FALSE - )) + if (clazz.caseFieldAccessors.isEmpty) + typeTest AND ((typeCast DOT nme.canEqual_)(This(clazz))) + else + argsBody } } diff --git a/test/files/pos/polymorphic-case-class.flags b/test/files/pos/polymorphic-case-class.flags new file mode 100644 index 0000000000..464cc20ea6 --- /dev/null +++ b/test/files/pos/polymorphic-case-class.flags @@ -0,0 +1 @@ +-Xfatal-warnings -unchecked \ No newline at end of file diff --git a/test/files/pos/polymorphic-case-class.scala b/test/files/pos/polymorphic-case-class.scala new file mode 100644 index 0000000000..5ed5eeddc9 --- /dev/null +++ b/test/files/pos/polymorphic-case-class.scala @@ -0,0 +1,2 @@ +// no unchecked warnings +case class Bippy[T, -U, +V](x: T, z: V) { } -- cgit v1.2.3