diff options
author | Joan <joan@goyeau.com> | 2016-04-22 12:24:12 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-04-22 12:24:12 +0100 |
commit | bf95b8da2774620cd62fa36bd8bf37725ad3fc7d (patch) | |
tree | b257a13641f72ed5b0b0eff34ef0bf64374c7c1d /sql | |
parent | e09ab5da8b02da98d7b2496d549c1d53cceb8728 (diff) | |
download | spark-bf95b8da2774620cd62fa36bd8bf37725ad3fc7d.tar.gz spark-bf95b8da2774620cd62fa36bd8bf37725ad3fc7d.tar.bz2 spark-bf95b8da2774620cd62fa36bd8bf37725ad3fc7d.zip |
[SPARK-6429] Implement hashCode and equals together
## What changes were proposed in this pull request?
Implement some `hashCode` and `equals` together in order to enable the corresponding scalastyle rule.
This is a first batch, I will continue to implement them but I wanted to know your thoughts.
Author: Joan <joan@goyeau.com>
Closes #12157 from joan38/SPARK-6429-HashCode-Equals.
Diffstat (limited to 'sql')
13 files changed, 57 insertions, 18 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 8bdf9b29c9..b77f93373e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -60,6 +60,8 @@ object AttributeSet { class AttributeSet private (val baseSet: Set[AttributeEquals]) extends Traversable[Attribute] with Serializable { + override def hashCode: Int = baseSet.hashCode() + /** Returns true if the members of this AttributeSet and other are the same. */ override def equals(other: Any): Boolean = other match { case otherSet: AttributeSet => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 607c7c877c..d0ad7a05a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -35,7 +35,8 @@ class EquivalentExpressions { case other: Expr => e.semanticEquals(other.e) case _ => false } - override val hashCode: Int = e.semanticHash() + + override def hashCode: Int = e.semanticHash() } // For each expression, the set of equivalent expressions. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e9dda588de..7e3683e482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Objects import org.json4s.JsonAST._ @@ -170,6 +171,8 @@ case class Literal protected (value: Any, dataType: DataType) override def toString: String = if (value != null) value.toString else "null" + override def hashCode(): Int = 31 * (31 * Objects.hashCode(dataType)) + Objects.hashCode(value) + override def equals(other: Any): Boolean = other match { case o: Literal => dataType.equals(o.dataType) && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c083f12724..8b38838537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.util.UUID +import java.util.{Objects, UUID} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -175,6 +175,11 @@ case class Alias(child: Expression, name: String)( exprId :: qualifier :: explicitMetadata :: isGenerated :: Nil } + override def hashCode(): Int = { + val state = Seq(name, exprId, child, qualifier, explicitMetadata) + state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + override def equals(other: Any): Boolean = other 
match { case a: Alias => name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index fb7251d71b..71a9b9f808 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.util.Objects + import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ @@ -83,6 +85,8 @@ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializa override def sql: String = sqlType.sql + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other match { case that: UserDefinedType[_] => this.acceptsType(that) case _ => false @@ -115,7 +119,9 @@ private[sql] class PythonUserDefinedType( } override def equals(other: Any): Boolean = other match { - case that: PythonUserDefinedType => this.pyUDT.equals(that.pyUDT) + case that: PythonUserDefinedType => pyUDT == that.pyUDT case _ => false } + + override def hashCode(): Int = Objects.hashCode(pyUDT) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 18752014ea..c3b20e2cc0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -35,6 +35,9 @@ import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType} case class RepeatedStruct(s: Seq[PrimitiveData]) case class NestedArray(a: Array[Array[Int]]) { + override def hashCode(): Int = + 
java.util.Arrays.deepHashCode(a.asInstanceOf[Array[AnyRef]]) + override def equals(other: Any): Boolean = other match { case NestedArray(otherArray) => java.util.Arrays.deepEquals( @@ -64,15 +67,21 @@ case class SpecificCollection(l: List[Int]) /** For testing Kryo serialization based encoder. */ class KryoSerializable(val value: Int) { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[KryoSerializable].value + override def hashCode(): Int = value + + override def equals(other: Any): Boolean = other match { + case that: KryoSerializable => this.value == that.value + case _ => false } } /** For testing Java serialization based encoder. */ class JavaSerializable(val value: Int) extends Serializable { - override def equals(other: Any): Boolean = { - this.value == other.asInstanceOf[JavaSerializable].value + override def hashCode(): Int = value + + override def equals(other: Any): Boolean = other match { + case that: JavaSerializable => this.value == that.value + case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index 42891287a3..e81cd28ea3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -33,7 +33,10 @@ private final class ShuffledRowRDDPartition( val startPreShufflePartitionIndex: Int, val endPreShufflePartitionIndex: Int) extends Partition { override val index: Int = postShufflePartitionIndex + override def hashCode(): Int = postShufflePartitionIndex + + override def equals(other: Any): Boolean = super.equals(other) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 34db10f822..61ec7ed2b1 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -44,6 +44,8 @@ class DefaultSource extends FileFormat with DataSourceRegister { override def toString: String = "CSV" + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] override def inferSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 7364a1dc06..7773ff550f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -154,6 +154,9 @@ class DefaultSource extends FileFormat with DataSourceRegister { } override def toString: String = "JSON" + + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index bfe7aefe41..38c0084952 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -60,6 +60,8 @@ private[sql] class DefaultSource override def toString: String = "ParquetFormat" + override def hashCode(): Int = getClass.hashCode() + override def equals(other: Any): Boolean = other.isInstanceOf[DefaultSource] override def prepareWrite( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 92c31eac95..930adabc48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -82,12 +82,12 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr override def value: Long = _value // Needed for SQLListenerSuite - override def equals(other: Any): Boolean = { - other match { - case o: LongSQLMetricValue => value == o.value - case _ => false - } + override def equals(other: Any): Boolean = other match { + case o: LongSQLMetricValue => value == o.value + case _ => false } + + override def hashCode(): Int = _value.hashCode() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 695a5ad78a..a73e427295 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -27,6 +27,9 @@ import org.apache.spark.sql.types._ */ @SQLUserDefinedType(udt = classOf[ExamplePointUDT]) private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializable { + + override def hashCode(): Int = 31 * (31 * x.hashCode()) + y.hashCode() + override def equals(other: Any): Boolean = other match { case that: ExamplePoint => this.x == that.x && this.y == that.y case _ => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index acc9f48d7e..a49aaa8b73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -37,9 +37,10 @@ object UDT { @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class 
MyDenseVector(val data: Array[Double]) extends Serializable { + override def hashCode(): Int = java.util.Arrays.hashCode(data) + override def equals(other: Any): Boolean = other match { - case v: MyDenseVector => - java.util.Arrays.equals(this.data, v.data) + case v: MyDenseVector => java.util.Arrays.equals(this.data, v.data) case _ => false } } @@ -63,10 +64,9 @@ object UDT { private[spark] override def asNullable: MyDenseVectorUDT = this - override def equals(other: Any): Boolean = other match { - case _: MyDenseVectorUDT => true - case _ => false - } + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other.isInstanceOf[MyDenseVectorUDT] } } |