summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2009-07-09 18:06:38 +0000
committerPaul Phillips <paulp@improving.org>2009-07-09 18:06:38 +0000
commit79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3 (patch)
tree82b7ad10b7c86db1ad1221cdfd4519ed664ada08
parentccfb3b9c1697379335992b085368557297c72e2d (diff)
downloadscala-79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3.tar.gz
scala-79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3.tar.bz2
scala-79dc3b49f0ed25fcc7cb33fc8fe1c13a6fdc21b3.zip
Implementation and test cases for canEqual meth...
Implementation and test cases for canEqual method in case classes. Now the autogenerated equality method inquires with the argument as to whether other.canEqual(this) before returning true.
-rw-r--r--src/compiler/scala/tools/nsc/symtab/Definitions.scala2
-rw-r--r--src/compiler/scala/tools/nsc/symtab/StdNames.scala1
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala32
-rw-r--r--src/library/scala/Product.scala6
-rw-r--r--test/files/run/canEqualCaseClasses.scala23
-rw-r--r--test/files/scalap/caseClass/result.test1
-rw-r--r--test/files/scalap/caseObject/result.test1
7 files changed, 58 insertions, 8 deletions
diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
index cec92d82c7..6f51638045 100644
--- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala
+++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
@@ -190,6 +190,8 @@ trait Definitions {
def Product_productArity = getMember(ProductRootClass, nme.productArity)
def Product_productElement = getMember(ProductRootClass, nme.productElement)
def Product_productPrefix = getMember(ProductRootClass, nme.productPrefix)
+ def Product_canEqual = getMember(ProductRootClass, nme.canEqual_)
+
val MaxProductArity = 22
/* <unapply> */
val ProductClass: Array[Symbol] = new Array(MaxProductArity + 1)
diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
index cc150d5809..65445b60d4 100644
--- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala
+++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
@@ -266,6 +266,7 @@ trait StdNames {
val box = newTermName("box")
val boxArray = newTermName("boxArray")
val forceBoxedArray = newTermName("forceBoxedArray")
+ val canEqual_ = newTermName("canEqual")
val checkInitialized = newTermName("checkInitialized")
val classOf = newTermName("classOf")
val copy = newTermName("copy")
diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
index b48aed76ab..2fdd049bc5 100644
--- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
@@ -118,7 +118,7 @@ trait SyntheticMethods extends ast.TreeDSL {
typer typed {
DEF(method) === {
- Apply(gen.mkAttributedRef(target), This(clazz) :: (method ARGNAMES))
+ Apply(REF(target), This(clazz) :: (method ARGNAMES))
}
}
}
@@ -141,6 +141,15 @@ trait SyntheticMethods extends ast.TreeDSL {
}
}
+ /** The canEqual method for case classes.
+ */
+ def canEqualMethod: Tree = {
+ val method = syntheticMethod(nme.canEqual_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe))
+ val that = method ARG 0
+
+ typer typed (DEF(method) === (that IS_OBJ clazz.tpe))
+ }
+
/** The equality method for case classes. The argument is an Any,
* but because of boxing it will always be an Object, so a check
* is neither necessary nor useful before the cast.
@@ -148,8 +157,8 @@ trait SyntheticMethods extends ast.TreeDSL {
* def equals(that: Any) =
* (this eq that.asInstanceOf[AnyRef]) ||
* (that match {
- * case this.C(this.arg_1, ..., this.arg_n) => true
- * case _ => false
+ * case x @ this.C(this.arg_1, ..., this.arg_n) => x canEqual this
+ * case _ => false
* })
*/
def equalsClassMethod: Tree = {
@@ -170,13 +179,21 @@ trait SyntheticMethods extends ast.TreeDSL {
// Creates list of parameters and a guard for each
val (guards, params) = List.map2(clazz.caseFieldAccessors, constrParamTypes)(makeTrees) unzip
+ // Verify with canEqual method before returning true.
+ def canEqualCheck() = {
+ val that: Tree = typer typed ((method ARG 0) AS clazz.tpe)
+ val canEqualOther: Symbol = clazz.info nonPrivateMember nme.canEqual_
+
+ (that DOT canEqualOther)(This(clazz))
+ }
+
// Pattern is classname applied to parameters, and guards are all logical and-ed
val (guard, pat) = (AND(guards: _*), clazz.name.toTermName APPLY params)
localTyper typed {
DEF(method) === {
(This(clazz) ANY_EQ that) OR (that MATCH(
- (CASE(pat) IF guard) ==> TRUE ,
+ (CASE(pat) IF guard) ==> canEqualCheck() ,
DEFAULT ==> FALSE
))
}
@@ -191,9 +208,7 @@ trait SyntheticMethods extends ast.TreeDSL {
def readResolveMethod: Tree = {
// !!! the synthetic method "readResolve" should be private, but then it is renamed !!!
val method = newSyntheticMethod(nme.readResolve, PROTECTED, makeNoArgConstructor(ObjectClass.tpe))
- typer typed {
- DEF(method) === gen.mkAttributedRef(clazz.sourceModule)
- }
+ typer typed (DEF(method) === REF(clazz.sourceModule))
}
def newAccessorMethod(tree: Tree): Tree = tree match {
@@ -240,7 +255,8 @@ trait SyntheticMethods extends ast.TreeDSL {
List(
Product_productPrefix -> (() => productPrefixMethod),
Product_productArity -> (() => productArityMethod(accessors.length)),
- Product_productElement -> (() => productElementMethod(accessors))
+ Product_productElement -> (() => productElementMethod(accessors)),
+ Product_canEqual -> (() => canEqualMethod)
)
}
diff --git a/src/library/scala/Product.scala b/src/library/scala/Product.scala
index 2e0a0b2816..bae88f52d7 100644
--- a/src/library/scala/Product.scala
+++ b/src/library/scala/Product.scala
@@ -49,4 +49,10 @@ trait Product extends AnyRef {
*/
def productPrefix = ""
+ /**
+ * An equality helper method to assist in maintaining reflexivity
+ * in the face of subtyping. For more, see
+ * http://www.artima.com/lejava/articles/equality.html
+ */
+ def canEqual(other: Any) = true
}
diff --git a/test/files/run/canEqualCaseClasses.scala b/test/files/run/canEqualCaseClasses.scala
new file mode 100644
index 0000000000..b1c3eb91de
--- /dev/null
+++ b/test/files/run/canEqualCaseClasses.scala
@@ -0,0 +1,23 @@
+object Test {
+ case class Foo(x: Int)
+ case class Bar(y: Int) extends Foo(y)
+ case class Nutty() {
+ override def canEqual(other: Any) = true
+ }
+
+ def assertEqual(x: AnyRef, y: AnyRef) =
+ assert((x == y) && (y == x) && (x.hashCode == y.hashCode))
+
+ def assertUnequal(x: AnyRef, y: AnyRef) =
+ assert((x != y) && (y != x))
+
+ def main(args: Array[String]): Unit = {
+ assertEqual(Foo(5), Foo(5))
+ assertEqual(Bar(5), Bar(5))
+ assertUnequal(Foo(5), Bar(5))
+
+ // in case there's an overriding implementation
+ assert(Nutty() canEqual (new AnyRef))
+ assert(!(Foo(5) canEqual (new AnyRef)))
+ }
+}
diff --git a/test/files/scalap/caseClass/result.test b/test/files/scalap/caseClass/result.test
index 3daa036d18..be64349cb1 100644
--- a/test/files/scalap/caseClass/result.test
+++ b/test/files/scalap/caseClass/result.test
@@ -11,4 +11,5 @@ case class CaseClass[A >: scala.Nothing <: scala.Seq[scala.Int]] extends java.la
override def productPrefix : java.lang.String = { /* compiled code */ }
override def productArity : scala.Int = { /* compiled code */ }
override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ }
+ override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ }
} \ No newline at end of file
diff --git a/test/files/scalap/caseObject/result.test b/test/files/scalap/caseObject/result.test
index 2097c5a71d..f81e6ce3cb 100644
--- a/test/files/scalap/caseObject/result.test
+++ b/test/files/scalap/caseObject/result.test
@@ -4,5 +4,6 @@ case object CaseObject extends java.lang.Object with scala.ScalaObject with scal
override def productPrefix : java.lang.String = { /* compiled code */ }
override def productArity : scala.Int = { /* compiled code */ }
override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ }
+ override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ }
protected def readResolve() : java.lang.Object = { /* compiled code */ }
} \ No newline at end of file