author      Davies Liu <davies@databricks.com>          2015-06-23 11:55:47 -0700
committer   Michael Armbrust <michael@databricks.com>   2015-06-23 11:55:47 -0700
commit      6f4cadf5ee81467d077febc53d36571dd232295d
tree        f5886b0ed4c10a4f7338ef73c128aced4f6d77cb /sql
parent      7b1450b666f88452e7fe969a6d59e8b24842ea39
[SPARK-8432] [SQL] fix hashCode() and equals() of BinaryType in Row
Also added more tests in LiteralExpressionSuite
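The root cause: a BinaryType value is stored as an Array[Byte], and JVM arrays compare by reference and hash by identity, so Row.equals()/hashCode() gave wrong answers for binary columns. A minimal standalone Scala sketch (illustrative only, not part of this patch) of the pitfall and of the content-based comparison the patch switches to:

    object BinaryEqualityDemo {
      def main(args: Array[String]): Unit = {
        val a = Array[Byte](1, 2, 3)
        val b = Array[Byte](1, 2, 3)

        // Arrays compare by reference and hash by identity on the JVM,
        // so two equal binary values look different to equals()/hashCode().
        println(a == b)                    // false
        println(a.hashCode == b.hashCode)  // false in practice (identity hash codes)

        // Content-based comparison, as the patch does for BinaryType columns.
        println(java.util.Arrays.equals(a, b))                                 // true
        println(java.util.Arrays.hashCode(a) == java.util.Arrays.hashCode(b))  // true
      }
    }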
Author: Davies Liu <davies@databricks.com>
Closes #6876 from davies/fix_hashcode and squashes the following commits:
429c2c0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_hashcode
32d9811 [Davies Liu] fix test
a0626ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_hashcode
89c2432 [Davies Liu] fix style
bd20780 [Davies Liu] check with catalyst types
41caec6 [Davies Liu] change for to while
d96929b [Davies Liu] address comment
6ad2a90 [Davies Liu] fix style
5819d33 [Davies Liu] unify equals() and hashCode()
0fff25d [Davies Liu] fix style
53c38b1 [Davies Liu] fix hashCode() and equals() of BinaryType in Row
Diffstat (limited to 'sql')
8 files changed, 138 insertions, 128 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
index 611e02d8fb..6a2356f1f9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
@@ -155,27 +155,6 @@ public abstract class BaseRow extends InternalRow {
     throw new UnsupportedOperationException();
   }
 
-  /**
-   * A generic version of Row.equals(Row), which is used for tests.
-   */
-  @Override
-  public boolean equals(Object other) {
-    if (other instanceof Row) {
-      Row row = (Row) other;
-      int n = size();
-      if (n != row.size()) {
-        return false;
-      }
-      for (int i = 0; i < n; i ++) {
-        if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
-          return false;
-        }
-      }
-      return true;
-    }
-    return false;
-  }
-
   @Override
   public InternalRow copy() {
     final int n = size();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 8aaf5d7d89..e99d5c87a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql
 
-import scala.util.hashing.MurmurHash3
-
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.types.StructType
 
@@ -365,36 +363,6 @@ trait Row extends Serializable {
     false
   }
 
-  override def equals(that: Any): Boolean = that match {
-    case null => false
-    case that: Row =>
-      if (this.length != that.length) {
-        return false
-      }
-      var i = 0
-      val len = this.length
-      while (i < len) {
-        if (apply(i) != that.apply(i)) {
-          return false
-        }
-        i += 1
-      }
-      true
-    case _ => false
-  }
-
-  override def hashCode: Int = {
-    // Using Scala's Seq hash code implementation.
-    var n = 0
-    var h = MurmurHash3.seqSeed
-    val len = length
-    while (n < len) {
-      h = MurmurHash3.mix(h, apply(n).##)
-      n += 1
-    }
-    MurmurHash3.finalizeHash(h, n)
-  }
-
   /* ---------------------- utility methods for Scala ---------------------- */
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index e3c2cc2433..d7b537a9fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions._
 
 /**
  * An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -26,7 +26,70 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow
  */
 abstract class InternalRow extends Row {
   // A default implementation to change the return type
-  override def copy(): InternalRow = {this}
+  override def copy(): InternalRow = this
+
+  override def equals(o: Any): Boolean = {
+    if (!o.isInstanceOf[Row]) {
+      return false
+    }
+
+    val other = o.asInstanceOf[Row]
+    if (length != other.length) {
+      return false
+    }
+
+    var i = 0
+    while (i < length) {
+      if (isNullAt(i) != other.isNullAt(i)) {
+        return false
+      }
+      if (!isNullAt(i)) {
+        val o1 = apply(i)
+        val o2 = other.apply(i)
+        if (o1.isInstanceOf[Array[Byte]]) {
+          // handle equality of Array[Byte]
+          val b1 = o1.asInstanceOf[Array[Byte]]
+          if (!o2.isInstanceOf[Array[Byte]] ||
+            !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+            return false
+          }
+        } else if (o1 != o2) {
+          return false
+        }
+      }
+      i += 1
+    }
+    true
+  }
+
+  // Custom hashCode function that matches the efficient code generated version.
+  override def hashCode: Int = {
+    var result: Int = 37
+    var i = 0
+    while (i < length) {
+      val update: Int =
+        if (isNullAt(i)) {
+          0
+        } else {
+          apply(i) match {
+            case b: Boolean => if (b) 0 else 1
+            case b: Byte => b.toInt
+            case s: Short => s.toInt
+            case i: Int => i
+            case l: Long => (l ^ (l >>> 32)).toInt
+            case f: Float => java.lang.Float.floatToIntBits(f)
+            case d: Double =>
+              val b = java.lang.Double.doubleToLongBits(d)
+              (b ^ (b >>> 32)).toInt
+            case a: Array[Byte] => java.util.Arrays.hashCode(a)
+            case other => other.hashCode()
+          }
+        }
+      result = 37 * result + update
+      i += 1
+    }
+    result
+  }
 }
 
 object InternalRow {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 2e20eda1a3..e362625469 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -127,6 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
         case FloatType => s"Float.floatToIntBits($col)"
         case DoubleType =>
           s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
+        case BinaryType => s"java.util.Arrays.hashCode($col)"
         case _ => s"$col.hashCode()"
       }
       s"isNullAt($i) ? 0 : ($nonNull)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 1098962ddc..0d4c9ace5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -121,58 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
     }
   }
 
-  // TODO(davies): add getDate and getDecimal
-
-  // Custom hashCode function that matches the efficient code generated version.
-  override def hashCode: Int = {
-    var result: Int = 37
-
-    var i = 0
-    while (i < values.length) {
-      val update: Int =
-        if (isNullAt(i)) {
-          0
-        } else {
-          apply(i) match {
-            case b: Boolean => if (b) 0 else 1
-            case b: Byte => b.toInt
-            case s: Short => s.toInt
-            case i: Int => i
-            case l: Long => (l ^ (l >>> 32)).toInt
-            case f: Float => java.lang.Float.floatToIntBits(f)
-            case d: Double =>
-              val b = java.lang.Double.doubleToLongBits(d)
-              (b ^ (b >>> 32)).toInt
-            case other => other.hashCode()
-          }
-        }
-      result = 37 * result + update
-      i += 1
-    }
-    result
-  }
-
-  override def equals(o: Any): Boolean = o match {
-    case other: InternalRow =>
-      if (values.length != other.length) {
-        return false
-      }
-
-      var i = 0
-      while (i < values.length) {
-        if (isNullAt(i) != other.isNullAt(i)) {
-          return false
-        }
-        if (apply(i) != other.apply(i)) {
-          return false
-        }
-        i += 1
-      }
-      true
-
-    case _ => false
-  }
-
   override def copy(): InternalRow = this
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 12d2da8b33..158f54af13 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -38,10 +38,23 @@ trait ExpressionEvalHelper {
 
   protected def checkEvaluation(
       expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
-    checkEvaluationWithoutCodegen(expression, expected, inputRow)
-    checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
-    checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
-    checkEvaluationWithOptimization(expression, expected, inputRow)
+    val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
+    checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
+    checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
+    checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
+    checkEvaluationWithOptimization(expression, catalystValue, inputRow)
+  }
+
+  /**
+   * Check the equality between result of expression and expected value, it will handle
+   * Array[Byte].
+   */
+  protected def checkResult(result: Any, expected: Any): Boolean = {
+    (result, expected) match {
+      case (result: Array[Byte], expected: Array[Byte]) =>
+        java.util.Arrays.equals(result, expected)
+      case _ => result == expected
+    }
   }
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
@@ -55,7 +68,7 @@ trait ExpressionEvalHelper {
     val actual = try evaluate(expression, inputRow) catch {
      case e: Exception => fail(s"Exception evaluating $expression", e)
     }
-    if (actual != expected) {
+    if (!checkResult(actual, expected)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect evaluation (codegen off): $expression, " +
         s"actual: $actual, " +
@@ -83,7 +96,7 @@ trait ExpressionEvalHelper {
     }
 
     val actual = plan(inputRow).apply(0)
-    if (actual != expected) {
+    if (!checkResult(actual, expected)) {
       val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
       fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
     }
@@ -109,7 +122,7 @@ trait ExpressionEvalHelper {
     }
 
     val actual = plan(inputRow)
-    val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
+    val expectedRow = new GenericRow(Array[Any](expected))
     if (actual.hashCode() != expectedRow.hashCode()) {
       fail(
         s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index f44f55dfb9..d924ff7a10 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,12 +18,26 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
 
 class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  // TODO: Add tests for all data types.
+  test("null") {
+    checkEvaluation(Literal.create(null, BooleanType), null)
+    checkEvaluation(Literal.create(null, ByteType), null)
+    checkEvaluation(Literal.create(null, ShortType), null)
+    checkEvaluation(Literal.create(null, IntegerType), null)
+    checkEvaluation(Literal.create(null, LongType), null)
+    checkEvaluation(Literal.create(null, FloatType), null)
+    checkEvaluation(Literal.create(null, LongType), null)
+    checkEvaluation(Literal.create(null, StringType), null)
+    checkEvaluation(Literal.create(null, BinaryType), null)
+    checkEvaluation(Literal.create(null, DecimalType()), null)
+    checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
+    checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
+    checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
+  }
 
   test("boolean literals") {
     checkEvaluation(Literal(true), true)
@@ -31,25 +45,52 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("int literals") {
-    checkEvaluation(Literal(1), 1)
-    checkEvaluation(Literal(0L), 0L)
+    List(0, 1, Int.MinValue, Int.MaxValue).foreach { d =>
+      checkEvaluation(Literal(d), d)
+      checkEvaluation(Literal(d.toLong), d.toLong)
+      checkEvaluation(Literal(d.toShort), d.toShort)
+      checkEvaluation(Literal(d.toByte), d.toByte)
+    }
+    checkEvaluation(Literal(Long.MinValue), Long.MinValue)
+    checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
   }
 
   test("double literals") {
-    List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
-      d => {
-        checkEvaluation(Literal(d), d)
-        checkEvaluation(Literal(d.toFloat), d.toFloat)
-      }
+    List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
+      checkEvaluation(Literal(d), d)
+      checkEvaluation(Literal(d.toFloat), d.toFloat)
     }
+    checkEvaluation(Literal(Double.MinValue), Double.MinValue)
+    checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
+    checkEvaluation(Literal(Float.MinValue), Float.MinValue)
+    checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
+
   }
 
   test("string literals") {
+    checkEvaluation(Literal(""), "")
     checkEvaluation(Literal("test"), "test")
-    checkEvaluation(Literal.create(null, StringType), null)
+    checkEvaluation(Literal("\0"), "\0")
   }
 
   test("sum two literals") {
     checkEvaluation(Add(Literal(1), Literal(1)), 2)
   }
+
+  test("binary literals") {
+    checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
+    checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
+  }
+
+  test("decimal") {
+    List(0.0, 1.2, 1.1111, 5).foreach { d =>
+      checkEvaluation(Literal(Decimal(d)), Decimal(d))
+      checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt))
+      checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong))
+      checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)),
+        Decimal((d * 1000L).toLong, 10, 1))
+    }
+  }
+
+  // TODO(davies): add tests for ArrayType, MapType and StructType
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index d363e63154..5dbb1d562c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
     checkEvaluation(StringLength(regEx), 0, create_row(""))
     checkEvaluation(StringLength(regEx), null, create_row(null))
-    // TODO currently bug in codegen, let's temporally disable this
-    // checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
+    checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
   }
-
-
 }
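For reference, a self-contained sketch of the comparison rule that the new checkResult test helper encodes: byte arrays are compared by content, everything else falls back to ==. The ResultCheck wrapper and the assertions below are illustrative, not part of the patch.

    object ResultCheck {
      // Mirrors the checkResult helper added to ExpressionEvalHelper above.
      def checkResult(result: Any, expected: Any): Boolean = (result, expected) match {
        case (r: Array[Byte], e: Array[Byte]) => java.util.Arrays.equals(r, e)
        case _ => result == expected
      }

      def main(args: Array[String]): Unit = {
        assert(checkResult(Array[Byte](1, 2), Array[Byte](1, 2)))   // binary compared by content
        assert(!checkResult(Array[Byte](1, 2), Array[Byte](1, 3)))
        assert(checkResult("abc", "abc"))                            // everything else uses ==
      }
    }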