diff options
author | Paul Phillips <paulp@improving.org> | 2009-07-09 18:06:38 +0000 |
---|---|---|
committer | Paul Phillips <paulp@improving.org> | 2009-07-09 18:06:38 +0000 |
commit | 79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3 (patch) | |
tree | 82b7ad10b7c86db1ad1221cdfd4519ed664ada08 | |
parent | ccfb3b9c1697379335992b085368557297c72e2d (diff) | |
download | scala-79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3.tar.gz scala-79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3.tar.bz2 scala-79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3.zip |
Implementation and test cases for canEqual meth...
Implementation and test cases for canEqual method in case classes. Now
the autogenerated equality method inquires with the argument as to
whether other.canEqual(this) before returning true.
-rw-r--r-- | src/compiler/scala/tools/nsc/symtab/Definitions.scala | 2 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/symtab/StdNames.scala | 1 | ||||
-rw-r--r-- | src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala | 32 | ||||
-rw-r--r-- | src/library/scala/Product.scala | 6 | ||||
-rw-r--r-- | test/files/run/canEqualCaseClasses.scala | 23 | ||||
-rw-r--r-- | test/files/scalap/caseClass/result.test | 1 | ||||
-rw-r--r-- | test/files/scalap/caseObject/result.test | 1 |
7 files changed, 58 insertions, 8 deletions
diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala index cec92d82c7..6f51638045 100644 --- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala +++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala @@ -190,6 +190,8 @@ trait Definitions { def Product_productArity = getMember(ProductRootClass, nme.productArity) def Product_productElement = getMember(ProductRootClass, nme.productElement) def Product_productPrefix = getMember(ProductRootClass, nme.productPrefix) + def Product_canEqual = getMember(ProductRootClass, nme.canEqual_) + val MaxProductArity = 22 /* <unapply> */ val ProductClass: Array[Symbol] = new Array(MaxProductArity + 1) diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala index cc150d5809..65445b60d4 100644 --- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala +++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala @@ -266,6 +266,7 @@ trait StdNames { val box = newTermName("box") val boxArray = newTermName("boxArray") val forceBoxedArray = newTermName("forceBoxedArray") + val canEqual_ = newTermName("canEqual") val checkInitialized = newTermName("checkInitialized") val classOf = newTermName("classOf") val copy = newTermName("copy") diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala index b48aed76ab..2fdd049bc5 100644 --- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala +++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala @@ -118,7 +118,7 @@ trait SyntheticMethods extends ast.TreeDSL { typer typed { DEF(method) === { - Apply(gen.mkAttributedRef(target), This(clazz) :: (method ARGNAMES)) + Apply(REF(target), This(clazz) :: (method ARGNAMES)) } } } @@ -141,6 +141,15 @@ trait SyntheticMethods extends ast.TreeDSL { } } + /** The canEqual method for case classes. + */ + def canEqualMethod: Tree = { + val method = syntheticMethod(nme.canEqual_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe)) + val that = method ARG 0 + + typer typed (DEF(method) === (that IS_OBJ clazz.tpe)) + } + /** The equality method for case classes. The argument is an Any, * but because of boxing it will always be an Object, so a check * is neither necessary nor useful before the cast. @@ -148,8 +157,8 @@ trait SyntheticMethods extends ast.TreeDSL { * def equals(that: Any) = * (this eq that.asInstanceOf[AnyRef]) || * (that match { - * case this.C(this.arg_1, ..., this.arg_n) => true - * case _ => false + * case x @ this.C(this.arg_1, ..., this.arg_n) => x canEqual this + * case _ => false * }) */ def equalsClassMethod: Tree = { @@ -170,13 +179,21 @@ trait SyntheticMethods extends ast.TreeDSL { // Creates list of parameters and a guard for each val (guards, params) = List.map2(clazz.caseFieldAccessors, constrParamTypes)(makeTrees) unzip + // Verify with canEqual method before returning true. + def canEqualCheck() = { + val that: Tree = typer typed ((method ARG 0) AS clazz.tpe) + val canEqualOther: Symbol = clazz.info nonPrivateMember nme.canEqual_ + + (that DOT canEqualOther)(This(clazz)) + } + // Pattern is classname applied to parameters, and guards are all logical and-ed val (guard, pat) = (AND(guards: _*), clazz.name.toTermName APPLY params) localTyper typed { DEF(method) === { (This(clazz) ANY_EQ that) OR (that MATCH( - (CASE(pat) IF guard) ==> TRUE , + (CASE(pat) IF guard) ==> canEqualCheck() , DEFAULT ==> FALSE )) } @@ -191,9 +208,7 @@ trait SyntheticMethods extends ast.TreeDSL { def readResolveMethod: Tree = { // !!! the synthetic method "readResolve" should be private, but then it is renamed !!! val method = newSyntheticMethod(nme.readResolve, PROTECTED, makeNoArgConstructor(ObjectClass.tpe)) - typer typed { - DEF(method) === gen.mkAttributedRef(clazz.sourceModule) - } + typer typed (DEF(method) === REF(clazz.sourceModule)) } def newAccessorMethod(tree: Tree): Tree = tree match { @@ -240,7 +255,8 @@ trait SyntheticMethods extends ast.TreeDSL { List( Product_productPrefix -> (() => productPrefixMethod), Product_productArity -> (() => productArityMethod(accessors.length)), - Product_productElement -> (() => productElementMethod(accessors)) + Product_productElement -> (() => productElementMethod(accessors)), + Product_canEqual -> (() => canEqualMethod) ) } diff --git a/src/library/scala/Product.scala b/src/library/scala/Product.scala index 2e0a0b2816..bae88f52d7 100644 --- a/src/library/scala/Product.scala +++ b/src/library/scala/Product.scala @@ -49,4 +49,10 @@ trait Product extends AnyRef { */ def productPrefix = "" + /** + * An equality helper method to assist in maintaining reflexivity + * in the face of subtyping. For more, see + * http://www.artima.com/lejava/articles/equality.html + */ + def canEqual(other: Any) = true } diff --git a/test/files/run/canEqualCaseClasses.scala b/test/files/run/canEqualCaseClasses.scala new file mode 100644 index 0000000000..b1c3eb91de --- /dev/null +++ b/test/files/run/canEqualCaseClasses.scala @@ -0,0 +1,23 @@ +object Test { + case class Foo(x: Int) + case class Bar(y: Int) extends Foo(y) + case class Nutty() { + override def canEqual(other: Any) = true + } + + def assertEqual(x: AnyRef, y: AnyRef) = + assert((x == y) && (y == x) && (x.hashCode == y.hashCode)) + + def assertUnequal(x: AnyRef, y: AnyRef) = + assert((x != y) && (y != x)) + + def main(args: Array[String]): Unit = { + assertEqual(Foo(5), Foo(5)) + assertEqual(Bar(5), Bar(5)) + assertUnequal(Foo(5), Bar(5)) + + // in case there's an overriding implementation + assert(Nutty() canEqual (new AnyRef)) + assert(!(Foo(5) canEqual (new AnyRef))) + } +} diff --git a/test/files/scalap/caseClass/result.test b/test/files/scalap/caseClass/result.test index 3daa036d18..be64349cb1 100644 --- a/test/files/scalap/caseClass/result.test +++ b/test/files/scalap/caseClass/result.test @@ -11,4 +11,5 @@ case class CaseClass[A >: scala.Nothing <: scala.Seq[scala.Int]] extends java.la override def productPrefix : java.lang.String = { /* compiled code */ } override def productArity : scala.Int = { /* compiled code */ } override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ } + override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ } }
\ No newline at end of file diff --git a/test/files/scalap/caseObject/result.test b/test/files/scalap/caseObject/result.test index 2097c5a71d..f81e6ce3cb 100644 --- a/test/files/scalap/caseObject/result.test +++ b/test/files/scalap/caseObject/result.test @@ -4,5 +4,6 @@ case object CaseObject extends java.lang.Object with scala.ScalaObject with scal override def productPrefix : java.lang.String = { /* compiled code */ } override def productArity : scala.Int = { /* compiled code */ } override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ } + override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ } protected def readResolve() : java.lang.Object = { /* compiled code */ } }
\ No newline at end of file |