aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala67
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala52
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala27
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala61
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala5
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java6
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java2
10 files changed, 139 insertions, 135 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
index 611e02d8fb..6a2356f1f9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
@@ -155,27 +155,6 @@ public abstract class BaseRow extends InternalRow {
throw new UnsupportedOperationException();
}
- /**
- * A generic version of Row.equals(Row), which is used for tests.
- */
- @Override
- public boolean equals(Object other) {
- if (other instanceof Row) {
- Row row = (Row) other;
- int n = size();
- if (n != row.size()) {
- return false;
- }
- for (int i = 0; i < n; i ++) {
- if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
- return false;
- }
- }
- return true;
- }
- return false;
- }
-
@Override
public InternalRow copy() {
final int n = size();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 8aaf5d7d89..e99d5c87a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql
-import scala.util.hashing.MurmurHash3
-
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType
@@ -365,36 +363,6 @@ trait Row extends Serializable {
false
}
- override def equals(that: Any): Boolean = that match {
- case null => false
- case that: Row =>
- if (this.length != that.length) {
- return false
- }
- var i = 0
- val len = this.length
- while (i < len) {
- if (apply(i) != that.apply(i)) {
- return false
- }
- i += 1
- }
- true
- case _ => false
- }
-
- override def hashCode: Int = {
- // Using Scala's Seq hash code implementation.
- var n = 0
- var h = MurmurHash3.seqSeed
- val len = length
- while (n < len) {
- h = MurmurHash3.mix(h, apply(n).##)
- n += 1
- }
- MurmurHash3.finalizeHash(h, n)
- }
-
/* ---------------------- utility methods for Scala ---------------------- */
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index e3c2cc2433..d7b537a9fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.expressions._
/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -26,7 +26,70 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow
*/
abstract class InternalRow extends Row {
// A default implementation to change the return type
- override def copy(): InternalRow = {this}
+ override def copy(): InternalRow = this
+
+ override def equals(o: Any): Boolean = { // field-by-field equality against any Row
+ if (!o.isInstanceOf[Row]) { // isInstanceOf is false for null, so null is rejected here too
+ return false
+ }
+
+ val other = o.asInstanceOf[Row]
+ if (length != other.length) {
+ return false
+ }
+
+ var i = 0
+ while (i < length) {
+ if (isNullAt(i) != other.isNullAt(i)) { // a null field only equals another null field
+ return false
+ }
+ if (!isNullAt(i)) {
+ val o1 = apply(i)
+ val o2 = other.apply(i)
+ if (o1.isInstanceOf[Array[Byte]]) {
+ // == on arrays is reference equality, so byte arrays are compared by content instead.
+ val b1 = o1.asInstanceOf[Array[Byte]]
+ if (!o2.isInstanceOf[Array[Byte]] ||
+ !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+ return false
+ }
+ } else if (o1 != o2) { // every other field type is compared with ==
+ return false
+ }
+ }
+ i += 1
+ }
+ true
+ }
+
+ // Custom hashCode whose per-type formulas must stay in sync with the code-generated version.
+ override def hashCode: Int = {
+ var result: Int = 37 // seed; every field is folded in below as 37 * result + update
+ var i = 0
+ while (i < length) {
+ val update: Int =
+ if (isNullAt(i)) {
+ 0 // null fields contribute 0
+ } else {
+ apply(i) match {
+ case b: Boolean => if (b) 0 else 1 // NOTE(review): true hashes to 0, the same as null
+ case b: Byte => b.toInt
+ case s: Short => s.toInt
+ case i: Int => i // this `i` is the matched value; it shadows the loop index
+ case l: Long => (l ^ (l >>> 32)).toInt // fold the high 32 bits into the low 32 bits
+ case f: Float => java.lang.Float.floatToIntBits(f)
+ case d: Double =>
+ val b = java.lang.Double.doubleToLongBits(d)
+ (b ^ (b >>> 32)).toInt
+ case a: Array[Byte] => java.util.Arrays.hashCode(a) // content-based hash
+ case other => other.hashCode()
+ }
+ }
+ result = 37 * result + update
+ i += 1
+ }
+ result
+ }
}
object InternalRow {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
index 2e20eda1a3..e362625469 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala
@@ -127,6 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
+ case BinaryType => s"java.util.Arrays.hashCode($col)"
case _ => s"$col.hashCode()"
}
s"isNullAt($i) ? 0 : ($nonNull)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 1098962ddc..0d4c9ace5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -121,58 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
}
}
- // TODO(davies): add getDate and getDecimal
-
- // Custom hashCode function that matches the efficient code generated version.
- override def hashCode: Int = {
- var result: Int = 37
-
- var i = 0
- while (i < values.length) {
- val update: Int =
- if (isNullAt(i)) {
- 0
- } else {
- apply(i) match {
- case b: Boolean => if (b) 0 else 1
- case b: Byte => b.toInt
- case s: Short => s.toInt
- case i: Int => i
- case l: Long => (l ^ (l >>> 32)).toInt
- case f: Float => java.lang.Float.floatToIntBits(f)
- case d: Double =>
- val b = java.lang.Double.doubleToLongBits(d)
- (b ^ (b >>> 32)).toInt
- case other => other.hashCode()
- }
- }
- result = 37 * result + update
- i += 1
- }
- result
- }
-
- override def equals(o: Any): Boolean = o match {
- case other: InternalRow =>
- if (values.length != other.length) {
- return false
- }
-
- var i = 0
- while (i < values.length) {
- if (isNullAt(i) != other.isNullAt(i)) {
- return false
- }
- if (apply(i) != other.apply(i)) {
- return false
- }
- i += 1
- }
- true
-
- case _ => false
- }
-
override def copy(): InternalRow = this
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 12d2da8b33..158f54af13 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -38,10 +38,23 @@ trait ExpressionEvalHelper {
protected def checkEvaluation(
expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
- checkEvaluationWithoutCodegen(expression, expected, inputRow)
- checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
- checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
- checkEvaluationWithOptimization(expression, expected, inputRow)
+ val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
+ checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
+ checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
+ checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
+ checkEvaluationWithOptimization(expression, catalystValue, inputRow)
+ }
+
+ /**
+ * Checks whether the result of an expression equals the expected value. Two Array[Byte]
+ * values are compared by content; any other pair of values is compared with ==.
+ */
+ protected def checkResult(result: Any, expected: Any): Boolean = {
+ (result, expected) match {
+ case (result: Array[Byte], expected: Array[Byte]) =>
+ java.util.Arrays.equals(result, expected)
+ case _ => result == expected
+ }
+ }
protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
@@ -55,7 +68,7 @@ trait ExpressionEvalHelper {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
- if (actual != expected) {
+ if (!checkResult(actual, expected)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (codegen off): $expression, " +
s"actual: $actual, " +
@@ -83,7 +96,7 @@ trait ExpressionEvalHelper {
}
val actual = plan(inputRow).apply(0)
- if (actual != expected) {
+ if (!checkResult(actual, expected)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
@@ -109,7 +122,7 @@ trait ExpressionEvalHelper {
}
val actual = plan(inputRow)
- val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
+ val expectedRow = new GenericRow(Array[Any](expected))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index f44f55dfb9..d924ff7a10 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -18,12 +18,26 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types._
class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
- // TODO: Add tests for all data types.
+ test("null") { // a null literal of each listed data type must evaluate to null
+ checkEvaluation(Literal.create(null, BooleanType), null)
+ checkEvaluation(Literal.create(null, ByteType), null)
+ checkEvaluation(Literal.create(null, ShortType), null)
+ checkEvaluation(Literal.create(null, IntegerType), null)
+ checkEvaluation(Literal.create(null, LongType), null)
+ checkEvaluation(Literal.create(null, FloatType), null)
+ checkEvaluation(Literal.create(null, DoubleType), null)
+ checkEvaluation(Literal.create(null, StringType), null)
+ checkEvaluation(Literal.create(null, BinaryType), null)
+ checkEvaluation(Literal.create(null, DecimalType()), null)
+ checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
+ checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
+ checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
+ }
test("boolean literals") {
checkEvaluation(Literal(true), true)
@@ -31,25 +45,52 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("int literals") {
- checkEvaluation(Literal(1), 1)
- checkEvaluation(Literal(0L), 0L)
+ List(0, 1, Int.MinValue, Int.MaxValue).foreach { d =>
+ checkEvaluation(Literal(d), d)
+ checkEvaluation(Literal(d.toLong), d.toLong)
+ checkEvaluation(Literal(d.toShort), d.toShort)
+ checkEvaluation(Literal(d.toByte), d.toByte)
+ }
+ checkEvaluation(Literal(Long.MinValue), Long.MinValue)
+ checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
}
test("double literals") {
- List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
- d => {
- checkEvaluation(Literal(d), d)
- checkEvaluation(Literal(d.toFloat), d.toFloat)
- }
+ List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
+ checkEvaluation(Literal(d), d)
+ checkEvaluation(Literal(d.toFloat), d.toFloat)
}
+ checkEvaluation(Literal(Double.MinValue), Double.MinValue)
+ checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
+ checkEvaluation(Literal(Float.MinValue), Float.MinValue)
+ checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
+
}
test("string literals") {
+ checkEvaluation(Literal(""), "")
checkEvaluation(Literal("test"), "test")
- checkEvaluation(Literal.create(null, StringType), null)
+ checkEvaluation(Literal("\0"), "\0")
}
test("sum two literals") {
checkEvaluation(Add(Literal(1), Literal(1)), 2)
}
+
+ test("binary literals") { // expected arrays are separate instances, so content must be compared
+ checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
+ checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
+ }
+
+ test("decimal") {
+ List(0.0, 1.2, 1.1111, 5).foreach { d => // the literal 5 widens, so every d is a Double
+ checkEvaluation(Literal(Decimal(d)), Decimal(d))
+ checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt)) // toInt drops the fraction
+ checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong)) // toLong drops the fraction
+ checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)), // presumably (unscaled, precision, scale) - confirm
+ Decimal((d * 1000L).toLong, 10, 1))
+ }
+ }
+
+ // TODO(davies): add tests for ArrayType, MapType and StructType
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index d363e63154..5dbb1d562c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
checkEvaluation(StringLength(regEx), 0, create_row(""))
checkEvaluation(StringLength(regEx), null, create_row(null))
- // TODO currently bug in codegen, let's temporally disable this
- // checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
+ checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}
-
-
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 9871a70a40..9302b47292 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -17,10 +17,10 @@
package org.apache.spark.unsafe.types;
+import javax.annotation.Nonnull;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
-import javax.annotation.Nonnull;
import org.apache.spark.unsafe.PlatformDependent;
@@ -202,10 +202,6 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
public boolean equals(final Object other) {
if (other instanceof UTF8String) {
return Arrays.equals(bytes, ((UTF8String) other).getBytes());
- } else if (other instanceof String) {
- // Used only in unit tests.
- String s = (String) other;
- return bytes.length >= s.length() && length() == s.length() && toString().equals(s);
} else {
return false;
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 80c179a1b5..796cdc9dbe 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -28,8 +28,6 @@ public class UTF8StringSuite {
Assert.assertEquals(UTF8String.fromString(str).length(), len);
Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).length(), len);
- Assert.assertEquals(UTF8String.fromString(str), str);
- Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), str);
Assert.assertEquals(UTF8String.fromString(str).toString(), str);
Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")).toString(), str);
Assert.assertEquals(UTF8String.fromBytes(str.getBytes("utf8")), UTF8String.fromString(str));