diff options
-rw-r--r-- | src/compiler/scala/tools/nsc/symtab/Definitions.scala | 2 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/symtab/StdNames.scala | 1 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala | 32 | ||||
-rw-r--r-- | src/library/scala/Product.scala | 6 | ||||
-rw-r--r-- | test/files/run/canEqualCaseClasses.scala | 23 | ||||
-rw-r--r-- | test/files/scalap/caseClass/result.test | 1 | ||||
-rw-r--r-- | test/files/scalap/caseObject/result.test | 1 |
7 files changed, 58 insertions, 8 deletions
diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala index cec92d82c7..6f51638045 100644 --- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala +++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala @@ -190,6 +190,8 @@ trait Definitions { def Product_productArity = getMember(ProductRootClass, nme.productArity) def Product_productElement = getMember(ProductRootClass, nme.productElement) def Product_productPrefix = getMember(ProductRootClass, nme.productPrefix) + def Product_canEqual = getMember(ProductRootClass, nme.canEqual_) + val MaxProductArity = 22 /* <unapply> */ val ProductClass: Array[Symbol] = new Array(MaxProductArity + 1) diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala index cc150d5809..65445b60d4 100644 --- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala +++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala @@ -266,6 +266,7 @@ trait StdNames { val box = newTermName("box") val boxArray = newTermName("boxArray") val forceBoxedArray = newTermName("forceBoxedArray") + val canEqual_ = newTermName("canEqual") val checkInitialized = newTermName("checkInitialized") val classOf = newTermName("classOf") val copy = newTermName("copy") diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index b48aed76ab..2fdd049bc5 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -118,7 +118,7 @@ trait SyntheticMethods extends ast.TreeDSL { typer typed { DEF(method) === { - Apply(gen.mkAttributedRef(target), This(clazz) :: (method ARGNAMES)) + Apply(REF(target), This(clazz) :: (method ARGNAMES)) } } } @@ -141,6 +141,15 @@ trait SyntheticMethods extends ast.TreeDSL { } } + /** The canEqual method for case classes. + */ + def canEqualMethod: Tree = { + val method = syntheticMethod(nme.canEqual_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe)) + val that = method ARG 0 + + typer typed (DEF(method) === (that IS_OBJ clazz.tpe)) + } + /** The equality method for case classes. The argument is an Any, * but because of boxing it will always be an Object, so a check * is neither necessary nor useful before the cast. @@ -148,8 +157,8 @@ trait SyntheticMethods extends ast.TreeDSL { * def equals(that: Any) = * (this eq that.asInstanceOf[AnyRef]) || * (that match { - * case this.C(this.arg_1, ..., this.arg_n) => true - * case _ => false + * case x @ this.C(this.arg_1, ..., this.arg_n) => x canEqual this + * case _ => false * }) */ def equalsClassMethod: Tree = { @@ -170,13 +179,21 @@ trait SyntheticMethods extends ast.TreeDSL { // Creates list of parameters and a guard for each val (guards, params) = List.map2(clazz.caseFieldAccessors, constrParamTypes)(makeTrees) unzip + // Verify with canEqual method before returning true. + def canEqualCheck() = { + val that: Tree = typer typed ((method ARG 0) AS clazz.tpe) + val canEqualOther: Symbol = clazz.info nonPrivateMember nme.canEqual_ + + (that DOT canEqualOther)(This(clazz)) + } + // Pattern is classname applied to parameters, and guards are all logical and-ed val (guard, pat) = (AND(guards: _*), clazz.name.toTermName APPLY params) localTyper typed { DEF(method) === { (This(clazz) ANY_EQ that) OR (that MATCH( - (CASE(pat) IF guard) ==> TRUE , + (CASE(pat) IF guard) ==> canEqualCheck() , DEFAULT ==> FALSE )) } @@ -191,9 +208,7 @@ trait SyntheticMethods extends ast.TreeDSL { def readResolveMethod: Tree = { // !!! the synthetic method "readResolve" should be private, but then it is renamed !!! val method = newSyntheticMethod(nme.readResolve, PROTECTED, makeNoArgConstructor(ObjectClass.tpe)) - typer typed { - DEF(method) === gen.mkAttributedRef(clazz.sourceModule) - } + typer typed (DEF(method) === REF(clazz.sourceModule)) } def newAccessorMethod(tree: Tree): Tree = tree match { @@ -240,7 +255,8 @@ trait SyntheticMethods extends ast.TreeDSL { List( Product_productPrefix -> (() => productPrefixMethod), Product_productArity -> (() => productArityMethod(accessors.length)), - Product_productElement -> (() => productElementMethod(accessors)) + Product_productElement -> (() => productElementMethod(accessors)), + Product_canEqual -> (() => canEqualMethod) ) } diff --git a/src/library/scala/Product.scala b/src/library/scala/Product.scala index 2e0a0b2816..bae88f52d7 100644 --- a/src/library/scala/Product.scala +++ b/src/library/scala/Product.scala @@ -49,4 +49,10 @@ trait Product extends AnyRef { */ def productPrefix = "" + /** + * An equality helper method to assist in maintaining reflexivity + * in the face of subtyping. For more, see + * http://www.artima.com/lejava/articles/equality.html + */ + def canEqual(other: Any) = true } diff --git a/test/files/run/canEqualCaseClasses.scala b/test/files/run/canEqualCaseClasses.scala new file mode 100644 index 0000000000..b1c3eb91de --- /dev/null +++ b/test/files/run/canEqualCaseClasses.scala @@ -0,0 +1,23 @@ +object Test { + case class Foo(x: Int) + case class Bar(y: Int) extends Foo(y) + case class Nutty() { + override def canEqual(other: Any) = true + } + + def assertEqual(x: AnyRef, y: AnyRef) = + assert((x == y) && (y == x) && (x.hashCode == y.hashCode)) + + def assertUnequal(x: AnyRef, y: AnyRef) = + assert((x != y) && (y != x)) + + def main(args: Array[String]): Unit = { + assertEqual(Foo(5), Foo(5)) + assertEqual(Bar(5), Bar(5)) + assertUnequal(Foo(5), Bar(5)) + + // in case there's an overriding implementation + assert(Nutty() canEqual (new AnyRef)) + assert(!(Foo(5) canEqual (new AnyRef))) + } +} diff --git a/test/files/scalap/caseClass/result.test b/test/files/scalap/caseClass/result.test index 3daa036d18..be64349cb1 100644 --- a/test/files/scalap/caseClass/result.test +++ b/test/files/scalap/caseClass/result.test @@ -11,4 +11,5 @@ case class CaseClass[A >: scala.Nothing <: scala.Seq[scala.Int]] extends java.la override def productPrefix : java.lang.String = { /* compiled code */ } override def productArity : scala.Int = { /* compiled code */ } override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ } + override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ } }
\ No newline at end of file diff --git a/test/files/scalap/caseObject/result.test b/test/files/scalap/caseObject/result.test index 2097c5a71d..f81e6ce3cb 100644 --- a/test/files/scalap/caseObject/result.test +++ b/test/files/scalap/caseObject/result.test @@ -4,5 +4,6 @@ case object CaseObject extends java.lang.Object with scala.ScalaObject with scal override def productPrefix : java.lang.String = { /* compiled code */ } override def productArity : scala.Int = { /* compiled code */ } override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ } + override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ } protected def readResolve() : java.lang.Object = { /* compiled code */ } }
\ No newline at end of file |