summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/scala/tools/nsc/symtab/Definitions.scala2
-rw-r--r--src/compiler/scala/tools/nsc/symtab/StdNames.scala1
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala32
-rw-r--r--src/library/scala/Product.scala6
-rw-r--r--test/files/run/canEqualCaseClasses.scala23
-rw-r--r--test/files/scalap/caseClass/result.test1
-rw-r--r--test/files/scalap/caseObject/result.test1
7 files changed, 58 insertions, 8 deletions
diff --git a/src/compiler/scala/tools/nsc/symtab/Definitions.scala b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
index cec92d82c7..6f51638045 100644
--- a/src/compiler/scala/tools/nsc/symtab/Definitions.scala
+++ b/src/compiler/scala/tools/nsc/symtab/Definitions.scala
@@ -190,6 +190,8 @@ trait Definitions {
def Product_productArity = getMember(ProductRootClass, nme.productArity)
def Product_productElement = getMember(ProductRootClass, nme.productElement)
def Product_productPrefix = getMember(ProductRootClass, nme.productPrefix)
+ def Product_canEqual = getMember(ProductRootClass, nme.canEqual_)
+
val MaxProductArity = 22
/* <unapply> */
val ProductClass: Array[Symbol] = new Array(MaxProductArity + 1)
diff --git a/src/compiler/scala/tools/nsc/symtab/StdNames.scala b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
index cc150d5809..65445b60d4 100644
--- a/src/compiler/scala/tools/nsc/symtab/StdNames.scala
+++ b/src/compiler/scala/tools/nsc/symtab/StdNames.scala
@@ -266,6 +266,7 @@ trait StdNames {
val box = newTermName("box")
val boxArray = newTermName("boxArray")
val forceBoxedArray = newTermName("forceBoxedArray")
+ val canEqual_ = newTermName("canEqual")
val checkInitialized = newTermName("checkInitialized")
val classOf = newTermName("classOf")
val copy = newTermName("copy")
diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
index b48aed76ab..2fdd049bc5 100644
--- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
@@ -118,7 +118,7 @@ trait SyntheticMethods extends ast.TreeDSL {
typer typed {
DEF(method) === {
- Apply(gen.mkAttributedRef(target), This(clazz) :: (method ARGNAMES))
+ Apply(REF(target), This(clazz) :: (method ARGNAMES))
}
}
}
@@ -141,6 +141,15 @@ trait SyntheticMethods extends ast.TreeDSL {
}
}
+ /** The canEqual method for case classes.
+ */
+ def canEqualMethod: Tree = {
+ val method = syntheticMethod(nme.canEqual_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe))
+ val that = method ARG 0
+
+ typer typed (DEF(method) === (that IS_OBJ clazz.tpe))
+ }
+
/** The equality method for case classes. The argument is an Any,
* but because of boxing it will always be an Object, so a check
* is neither necessary nor useful before the cast.
@@ -148,8 +157,8 @@ trait SyntheticMethods extends ast.TreeDSL {
* def equals(that: Any) =
* (this eq that.asInstanceOf[AnyRef]) ||
* (that match {
- * case this.C(this.arg_1, ..., this.arg_n) => true
- * case _ => false
+ * case x @ this.C(this.arg_1, ..., this.arg_n) => x canEqual this
+ * case _ => false
* })
*/
def equalsClassMethod: Tree = {
@@ -170,13 +179,21 @@ trait SyntheticMethods extends ast.TreeDSL {
// Creates list of parameters and a guard for each
val (guards, params) = List.map2(clazz.caseFieldAccessors, constrParamTypes)(makeTrees) unzip
+ // Verify with canEqual method before returning true.
+ def canEqualCheck() = {
+ val that: Tree = typer typed ((method ARG 0) AS clazz.tpe)
+ val canEqualOther: Symbol = clazz.info nonPrivateMember nme.canEqual_
+
+ (that DOT canEqualOther)(This(clazz))
+ }
+
// Pattern is classname applied to parameters, and guards are all logical and-ed
val (guard, pat) = (AND(guards: _*), clazz.name.toTermName APPLY params)
localTyper typed {
DEF(method) === {
(This(clazz) ANY_EQ that) OR (that MATCH(
- (CASE(pat) IF guard) ==> TRUE ,
+ (CASE(pat) IF guard) ==> canEqualCheck() ,
DEFAULT ==> FALSE
))
}
@@ -191,9 +208,7 @@ trait SyntheticMethods extends ast.TreeDSL {
def readResolveMethod: Tree = {
// !!! the synthetic method "readResolve" should be private, but then it is renamed !!!
val method = newSyntheticMethod(nme.readResolve, PROTECTED, makeNoArgConstructor(ObjectClass.tpe))
- typer typed {
- DEF(method) === gen.mkAttributedRef(clazz.sourceModule)
- }
+ typer typed (DEF(method) === REF(clazz.sourceModule))
}
def newAccessorMethod(tree: Tree): Tree = tree match {
@@ -240,7 +255,8 @@ trait SyntheticMethods extends ast.TreeDSL {
List(
Product_productPrefix -> (() => productPrefixMethod),
Product_productArity -> (() => productArityMethod(accessors.length)),
- Product_productElement -> (() => productElementMethod(accessors))
+ Product_productElement -> (() => productElementMethod(accessors)),
+ Product_canEqual -> (() => canEqualMethod)
)
}
diff --git a/src/library/scala/Product.scala b/src/library/scala/Product.scala
index 2e0a0b2816..bae88f52d7 100644
--- a/src/library/scala/Product.scala
+++ b/src/library/scala/Product.scala
@@ -49,4 +49,10 @@ trait Product extends AnyRef {
*/
def productPrefix = ""
+ /**
+ * An equality helper method to assist in maintaining reflexivity
+ * in the face of subtyping. For more, see
+ * http://www.artima.com/lejava/articles/equality.html
+ */
+ def canEqual(other: Any) = true
}
diff --git a/test/files/run/canEqualCaseClasses.scala b/test/files/run/canEqualCaseClasses.scala
new file mode 100644
index 0000000000..b1c3eb91de
--- /dev/null
+++ b/test/files/run/canEqualCaseClasses.scala
@@ -0,0 +1,23 @@
+object Test {
+ case class Foo(x: Int)
+ case class Bar(y: Int) extends Foo(y)
+ case class Nutty() {
+ override def canEqual(other: Any) = true
+ }
+
+ def assertEqual(x: AnyRef, y: AnyRef) =
+ assert((x == y) && (y == x) && (x.hashCode == y.hashCode))
+
+ def assertUnequal(x: AnyRef, y: AnyRef) =
+ assert((x != y) && (y != x))
+
+ def main(args: Array[String]): Unit = {
+ assertEqual(Foo(5), Foo(5))
+ assertEqual(Bar(5), Bar(5))
+ assertUnequal(Foo(5), Bar(5))
+
+ // in case there's an overriding implementation
+ assert(Nutty() canEqual (new AnyRef))
+ assert(!(Foo(5) canEqual (new AnyRef)))
+ }
+}
diff --git a/test/files/scalap/caseClass/result.test b/test/files/scalap/caseClass/result.test
index 3daa036d18..be64349cb1 100644
--- a/test/files/scalap/caseClass/result.test
+++ b/test/files/scalap/caseClass/result.test
@@ -11,4 +11,5 @@ case class CaseClass[A >: scala.Nothing <: scala.Seq[scala.Int]] extends java.la
override def productPrefix : java.lang.String = { /* compiled code */ }
override def productArity : scala.Int = { /* compiled code */ }
override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ }
+ override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ }
} \ No newline at end of file
diff --git a/test/files/scalap/caseObject/result.test b/test/files/scalap/caseObject/result.test
index 2097c5a71d..f81e6ce3cb 100644
--- a/test/files/scalap/caseObject/result.test
+++ b/test/files/scalap/caseObject/result.test
@@ -4,5 +4,6 @@ case object CaseObject extends java.lang.Object with scala.ScalaObject with scal
override def productPrefix : java.lang.String = { /* compiled code */ }
override def productArity : scala.Int = { /* compiled code */ }
override def productElement(x$1 : scala.Int) : scala.Any = { /* compiled code */ }
+ override def canEqual(x$1 : scala.Any) : scala.Boolean = { /* compiled code */ }
protected def readResolve() : java.lang.Object = { /* compiled code */ }
} \ No newline at end of file