author     Joan <joan@goyeau.com>          2016-04-22 12:24:12 +0100
committer  Sean Owen <sowen@cloudera.com>  2016-04-22 12:24:12 +0100
commit     bf95b8da2774620cd62fa36bd8bf37725ad3fc7d
tree       b257a13641f72ed5b0b0eff34ef0bf64374c7c1d
parent     e09ab5da8b02da98d7b2496d549c1d53cceb8728
[SPARK-6429] Implement hashCode and equals together
## What changes were proposed in this pull request?

Implement some `hashCode` and `equals` together in order to enable the corresponding scalastyle rule (`EqualsHashCodeChecker`). This is a first batch; I will continue to implement them, but I wanted to know your thoughts.

Author: Joan <joan@goyeau.com>

Closes #12157 from joan38/SPARK-6429-HashCode-Equals.
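For illustration only (not part of this patch), a minimal sketch of the pattern the `EqualsHashCodeChecker` rule enforces: whenever `equals` is overridden, `hashCode` is overridden alongside it and derived from the same fields, so that equal objects always hash the same. The `IndexedPartition` class and its fields below are hypothetical.

```scala
// Hypothetical illustration of the equals/hashCode contract enforced by this patch.
class IndexedPartition(val rddId: Int, val index: Int) {

  override def equals(other: Any): Boolean = other match {
    case that: IndexedPartition => rddId == that.rddId && index == that.index
    case _ => false
  }

  // hashCode mixes exactly the fields compared in equals, using the same
  // 31-based scheme as the partition classes touched by this patch, so that
  // a == b implies a.hashCode == b.hashCode.
  override def hashCode(): Int = 31 * (31 + rddId) + index
}
```

With `EqualsHashCodeChecker` enabled at level `error` (see the `scalastyle-config.xml` change below), defining one of the two methods without the other fails the style check.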
-rw-r--r--  core/src/main/scala/org/apache/spark/Partition.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala | 4
-rw-r--r--  mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala | 2
-rw-r--r--  mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 4
-rw-r--r--  project/MimaExcludes.scala | 4
-rw-r--r--  scalastyle-config.xml | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala | 8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala | 12
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala | 10
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala | 1
32 files changed, 136 insertions(+), 40 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala
index dd3f28e419..e10660793d 100644
--- a/core/src/main/scala/org/apache/spark/Partition.scala
+++ b/core/src/main/scala/org/apache/spark/Partition.scala
@@ -28,4 +28,6 @@ trait Partition extends Serializable {
// A better default implementation of HashCode
override def hashCode(): Int = index
+
+ override def equals(other: Any): Boolean = super.equals(other)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 7bc1eb0436..2381f54ee3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -58,10 +58,10 @@ private[spark] case class NarrowCoGroupSplitDep(
* narrowDeps should always be equal to the number of parents.
*/
private[spark] class CoGroupPartition(
- idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
+ override val index: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
extends Partition with Serializable {
- override val index: Int = idx
- override def hashCode(): Int = idx
+ override def hashCode(): Int = index
+ override def equals(other: Any): Boolean = super.equals(other)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 6b1e15572c..b22134af45 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -53,14 +53,14 @@ import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownH
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
-private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit)
+private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: InputSplit)
extends Partition {
val inputSplit = new SerializableWritable[InputSplit](s)
- override def hashCode(): Int = 41 * (41 + rddId) + idx
+ override def hashCode(): Int = 31 * (31 + rddId) + index
- override val index: Int = idx
+ override def equals(other: Any): Boolean = super.equals(other)
/**
* Get any environment variables that should be added to the users environment when running pipes
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index a71c191b31..ad7c2216a0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -45,7 +45,10 @@ private[spark] class NewHadoopPartition(
extends Partition {
val serializableHadoopSplit = new SerializableWritable(rawSplit)
- override def hashCode(): Int = 41 * (41 + rddId) + index
+
+ override def hashCode(): Int = 31 * (31 + rddId) + index
+
+ override def equals(other: Any): Boolean = super.equals(other)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 0abba15bec..b6366f3e68 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -31,12 +31,13 @@ import org.apache.spark.util.Utils
private[spark]
class PartitionerAwareUnionRDDPartition(
@transient val rdds: Seq[RDD[_]],
- val idx: Int
+ override val index: Int
) extends Partition {
- var parents = rdds.map(_.partitions(idx)).toArray
+ var parents = rdds.map(_.partitions(index)).toArray
- override val index = idx
- override def hashCode(): Int = idx
+ override def hashCode(): Int = index
+
+ override def equals(other: Any): Boolean = super.equals(other)
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 800b42505d..29d5d74650 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -25,7 +25,10 @@ import org.apache.spark.serializer.Serializer
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index: Int = idx
- override def hashCode(): Int = idx
+
+ override def hashCode(): Int = index
+
+ override def equals(other: Any): Boolean = super.equals(other)
}
/**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
index d8d818ceed..8386869237 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/CustomShuffledRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.util.Arrays
+import java.util.Objects
import org.apache.spark._
import org.apache.spark.rdd.RDD
@@ -53,6 +54,9 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A
parentPartitionMapping(parent.getPartition(key))
}
+ override def hashCode(): Int =
+ 31 * Objects.hashCode(parent) + Arrays.hashCode(partitionStartIndices)
+
override def equals(other: Any): Boolean = other match {
case c: CoalescedPartitioner =>
c.parent == parent && Arrays.equals(c.partitionStartIndices, partitionStartIndices)
@@ -66,6 +70,8 @@ private[spark] class CustomShuffledRDDPartition(
extends Partition {
override def hashCode(): Int = index
+
+ override def equals(other: Any): Boolean = super.equals(other)
}
/**
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 27d063630b..57a8231200 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -476,6 +476,9 @@ object KryoTest {
class ClassWithNoArgConstructor {
var x: Int = 0
+
+ override def hashCode(): Int = x
+
override def equals(other: Any): Boolean = other match {
case c: ClassWithNoArgConstructor => x == c.x
case _ => false
@@ -483,6 +486,8 @@ object KryoTest {
}
class ClassWithoutNoArgConstructor(val x: Int) {
+ override def hashCode(): Int = x
+
override def equals(other: Any): Boolean = other match {
case c: ClassWithoutNoArgConstructor => x == c.x
case _ => false
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 932704c1a3..4920b7ee8b 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -124,6 +124,8 @@ class ClosureCleanerSuite extends SparkFunSuite {
// A non-serializable class we create in closures to make sure that we aren't
// keeping references to unneeded variables from our outer closures.
class NonSerializable(val id: Int = -1) {
+ override def hashCode(): Int = id
+
override def equals(other: Any): Boolean = {
other match {
case o: NonSerializable => id == o.id
diff --git a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
index c787b5f066..ea22db3555 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
@@ -22,4 +22,8 @@ package org.apache.spark.util.collection
*/
case class FixedHashObject(v: Int, h: Int) extends Serializable {
override def hashCode(): Int = h
+ override def equals(other: Any): Boolean = other match {
+ case that: FixedHashObject => v == that.v && h == that.h
+ case _ => false
+ }
}
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
index baa04fb0fd..8204b5af02 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
@@ -458,6 +458,8 @@ class SparseMatrix (
rowIndices: Array[Int],
values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false)
+ override def hashCode(): Int = toBreeze.hashCode()
+
override def equals(o: Any): Boolean = o match {
case m: Matrix => toBreeze == m.toBreeze
case _ => false
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
index fd4ce9adb8..4275a22ae0 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
@@ -476,6 +476,8 @@ class DenseVector (val values: Array[Double]) extends Vector {
}
}
+ override def equals(other: Any): Boolean = super.equals(other)
+
override def hashCode(): Int = {
var result: Int = 31 + size
var i = 0
@@ -602,6 +604,8 @@ class SparseVector (
}
}
+ override def equals(other: Any): Boolean = super.equals(other)
+
override def hashCode(): Int = {
var result: Int = 31 + size
val end = values.length
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
index 9d895b8fac..5d11ed0971 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.tree
+import java.util.Objects
+
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
@@ -112,12 +114,15 @@ final class CategoricalSplit private[ml] (
}
}
- override def equals(o: Any): Boolean = {
- o match {
- case other: CategoricalSplit => featureIndex == other.featureIndex &&
- isLeft == other.isLeft && categories == other.categories
- case _ => false
- }
+ override def hashCode(): Int = {
+ val state = Seq(featureIndex, isLeft, categories)
+ state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+ }
+
+ override def equals(o: Any): Boolean = o match {
+ case other: CategoricalSplit => featureIndex == other.featureIndex &&
+ isLeft == other.isLeft && categories == other.categories
+ case _ => false
}
override private[tree] def toOld: OldSplit = {
@@ -181,6 +186,11 @@ final class ContinuousSplit private[ml] (override val featureIndex: Int, val thr
}
}
+ override def hashCode(): Int = {
+ val state = Seq(featureIndex, threshold)
+ state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+ }
+
override private[tree] def toOld: OldSplit = {
OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index bb5d6d9d51..90fa4fbbc6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -606,6 +606,8 @@ class SparseMatrix @Since("1.3.0") (
case _ => false
}
+ override def hashCode(): Int = toBreeze.hashCode
+
private[mllib] def toBreeze: BM[Double] = {
if (!isTransposed) {
new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 5ec83e8d5c..6e3da6b701 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -628,6 +628,8 @@ class DenseVector @Since("1.0.0") (
}
}
+ override def equals(other: Any): Boolean = super.equals(other)
+
override def hashCode(): Int = {
var result: Int = 31 + size
var i = 0
@@ -775,6 +777,8 @@ class SparseVector @Since("1.0.0") (
}
}
+ override def equals(other: Any): Boolean = super.equals(other)
+
override def hashCode(): Int = {
var result: Int = 31 + size
val end = values.length
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index c98a39dc0c..27838167fd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -114,6 +114,10 @@ object MimaExcludes {
"org.apache.spark.api.java.function.FlatMapGroupsFunction.call")
) ++
Seq(
+ // [SPARK-6429] Implement hashCode and equals together
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Partition.org$apache$spark$Partition$$super=uals")
+ ) ++
+ Seq(
// SPARK-4819 replace Guava Optional
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"),
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index e39400e2d1..270104f85b 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -262,7 +262,7 @@ This file is divided into 3 sections:
</check>
<!-- Should turn this on, but we have a few places that need to be fixed first -->
- <check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="false"></check>
+ <check level="error" class="org.scalastyle.scalariform.EqualsHashCodeChecker" enabled="true"></check>
<!-- ================================================================================ -->
<!-- rules we don't want -->
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 8bdf9b29c9..b77f93373e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -60,6 +60,8 @@ object AttributeSet {
class AttributeSet private (val baseSet: Set[AttributeEquals])
extends Traversable[Attribute] with Serializable {
+ override def hashCode: Int = baseSet.hashCode()
+
/** Returns true if the members of this AttributeSet and other are the same. */
override def equals(other: Any): Boolean = other match {
case otherSet: AttributeSet =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 607c7c877c..d0ad7a05a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -35,7 +35,8 @@ class EquivalentExpressions {
case other: Expr => e.semanticEquals(other.e)
case _ => false
}
- override val hashCode: Int = e.semanticHash()
+
+ override def hashCode: Int = e.semanticHash()
}
// For each expression, the set of equivalent expressions.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e9dda588de..7e3683e482 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
+import java.util.Objects
import org.json4s.JsonAST._
@@ -170,6 +171,8 @@ case class Literal protected (value: Any, dataType: DataType)
override def toString: String = if (value != null) value.toString else "null"
+ override def hashCode(): Int = 31 * (31 * Objects.hashCode(dataType)) + Objects.hashCode(value)
+
override def equals(other: Any): Boolean = other match {
case o: Literal =>
dataType.equals(o.dataType) &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index c083f12724..8b38838537 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import java.util.UUID
+import java.util.{Objects, UUID}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -175,6 +175,11 @@ case class Alias(child: Expression, name: String)(
exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil
}
+ override def hashCode(): Int = {
+ val state = Seq(name, exprId, child, qualifier, explicitMetadata)
+ state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+ }
+
override def equals(other: Any): Boolean = other match {
case a: Alias =>
name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index fb7251d71b..71a9b9f808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.types
+import java.util.Objects
+
import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
@@ -83,6 +85,8 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa
override def sql: String = sqlType.sql
+ override def hashCode(): Int = getClass.hashCode()
+
override def equals(other: Any): Boolean = other match {
case that: UserDefinedType[_] => this.acceptsType(that)
case _ => false
@@ -115,7 +119,9 @@ private[sql] class PythonUserDefinedType(
}
override def equals(other: Any): Boolean = other match {
- case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT)
+ case that: PythonUserDefinedType => pyUDT == that.pyUDT
case _ => false
}
+
+ override def hashCode(): Int = Objects.hashCode(pyUDT)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 18752014ea..c3b20e2cc0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -35,6 +35,9 @@ import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
case class RepeatedStruct(s: Seq[PrimitiveData])
case class NestedArray(a: Array[Array[Int]]) {
+ override def hashCode(): Int =
+ java.util.Arrays.deepHashCode(a.asInstanceOf[Array[AnyRef]])
+
override def equals(other: Any): Boolean = other match {
case NestedArray(otherArray) =>
java.util.Arrays.deepEquals(
@@ -64,15 +67,21 @@ case class SpecificCollection(l: List[Int])
/** For testing Kryo serialization based encoder. */
class KryoSerializable(val value: Int) {
- override def equals(other: Any): Boolean = {
- this.value == other.asInstanceOf[KryoSerializable].value
+ override def hashCode(): Int = value
+
+ override def equals(other: Any): Boolean = other match {
+ case that: KryoSerializable => this.value == that.value
+ case _ => false
}
}
/** For testing Java serialization based encoder. */
class JavaSerializable(val value: Int) extends Serializable {
- override def equals(other: Any): Boolean = {
- this.value == other.asInstanceOf[JavaSerializable].value
+ override def hashCode(): Int = value
+
+ override def equals(other: Any): Boolean = other match {
+ case that: JavaSerializable => this.value == that.value
+ case _ => false
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 42891287a3..e81cd28ea3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -33,7 +33,10 @@ private final class ShuffledRowRDDPartition(
val startPreShufflePartitionIndex: Int,
val endPreShufflePartitionIndex: Int) extends Partition {
override val index: Int = postShufflePartitionIndex
+
override def hashCode(): Int = postShufflePartitionIndex
+
+ override def equals(other: Any): Boolean = super.equals(other)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 34db10f822..61ec7ed2b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -44,6 +44,8 @@ class DefaultSource extends FileFormat with DataSourceRegister {
override def toString: String = "CSV"
+ override def hashCode(): Int = getClass.hashCode()
+
override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
override def inferSchema(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 7364a1dc06..7773ff550f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -154,6 +154,9 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
override def toString: String = "JSON"
+
+ override def hashCode(): Int = getClass.hashCode()
+
override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index bfe7aefe41..38c0084952 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -60,6 +60,8 @@ private[sql] class DefaultSource
override def toString: String = "ParquetFormat"
+ override def hashCode(): Int = getClass.hashCode()
+
override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource]
override def prepareWrite(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 92c31eac95..930adabc48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -82,12 +82,12 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr
override def value: Long = _value
// Needed for SQLListenerSuite
- override def equals(other: Any): Boolean = {
- other match {
- case o: LongSQLMetricValue => value == o.value
- case _ => false
- }
+ override def equals(other: Any): Boolean = other match {
+ case o: LongSQLMetricValue => value == o.value
+ case _ => false
}
+
+ override def hashCode(): Int = _value.hashCode()
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
index 695a5ad78a..a73e427295 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala
@@ -27,6 +27,9 @@ import org.apache.spark.sql.types._
*/
@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable {
+
+ override def hashCode(): Int = 31 * (31 * x.hashCode()) + y.hashCode()
+
override def equals(other: Any): Boolean = other match {
case that: ExamplePoint => this.x == that.x && this.y == that.y
case _ => false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index acc9f48d7e..a49aaa8b73 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -37,9 +37,10 @@ object UDT {
@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
+ override def hashCode(): Int = java.util.Arrays.hashCode(data)
+
override def equals(other: Any): Boolean = other match {
- case v: MyDenseVector =>
- java.util.Arrays.equals(this.data, v.data)
+ case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data)
case _ => false
}
}
@@ -63,10 +64,9 @@ object UDT {
private[spark] override def asNullable: MyDenseVectorUDT = this
- override def equals(other: Any): Boolean = other match {
- case _: MyDenseVectorUDT => true
- case _ => false
- }
+ override def hashCode(): Int = getClass.hashCode()
+
+ override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT]
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 8119d808ff..58b7031d5e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -84,15 +84,19 @@ private[streaming] object MapWithStateRDDRecord {
* RDD, and a partitioned keyed-data RDD
*/
private[streaming] class MapWithStateRDDPartition(
- idx: Int,
+ override val index: Int,
@transient private var prevStateRDD: RDD[_],
@transient private var partitionedDataRDD: RDD[_]) extends Partition {
private[rdd] var previousSessionRDDPartition: Partition = null
private[rdd] var partitionedDataRDDPartition: Partition = null
- override def index: Int = idx
- override def hashCode(): Int = idx
+ override def hashCode(): Int = index
+
+ override def equals(other: Any): Boolean = other match {
+ case that: MapWithStateRDDPartition => index == that.index
+ case _ => false
+ }
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index 784c6525e5..6a861d6f66 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -85,6 +85,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
}
class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
+ override def hashCode(): Int = 0
override def equals(other: Any): Boolean = false
}