aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-07-02 10:06:38 -0700
committerDavies Liu <davies@databricks.com>2015-07-02 10:06:38 -0700
commitafa021e03f0a1a326be2ed742332845b77f94c55 (patch)
tree2a704afb247c38dfd2b3e92e4891d2d2720a96d6 /sql
parent5b3338130dfd9db92c4894a348839a62ebb57ef3 (diff)
downloadspark-afa021e03f0a1a326be2ed742332845b77f94c55.tar.gz
spark-afa021e03f0a1a326be2ed742332845b77f94c55.tar.bz2
spark-afa021e03f0a1a326be2ed742332845b77f94c55.zip
[SPARK-8747] [SQL] fix EqualNullSafe for binary type
also improve tests for binary comparison. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7143 from cloud-fan/binary and squashes the following commits: 28a5b76 [Wenchen Fan] improve test 04ef4b0 [Wenchen Fan] fix equalNullSafe
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala122
2 files changed, 78 insertions, 47 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 34df89a163..d4569241e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -302,7 +302,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
} else if (l == null || r == null) {
false
} else {
- l == r
+ if (left.dataType != BinaryType) l == r
+ else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 72fec3b86e..188ecef9e7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -17,14 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
-import java.sql.{Date, Timestamp}
-
import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.types.{IntegerType, BooleanType}
+import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType}
class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -66,12 +63,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
* Unknown Unknown
*/
// scalastyle:on
- val notTrueTable =
- (true, false) ::
- (false, true) ::
- (null, null) :: Nil
test("3VL Not") {
+ val notTrueTable =
+ (true, false) ::
+ (false, true) ::
+ (null, null) :: Nil
notTrueTable.foreach { case (v, answer) =>
checkEvaluation(Not(Literal.create(v, BooleanType)), answer)
}
@@ -126,8 +123,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
val two = Literal(2)
val three = Literal(3)
val nl = Literal(null)
- val s = Seq(one, two)
- val nullS = Seq(one, two, null)
checkEvaluation(InSet(one, hS), true)
checkEvaluation(InSet(two, hS), true)
checkEvaluation(InSet(two, nS), true)
@@ -137,43 +132,78 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
}
+ private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
+ private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))
- test("BinaryComparison") {
- val row = create_row(1, 2, 3, null, 3, null)
- val c1 = 'a.int.at(0)
- val c2 = 'a.int.at(1)
- val c3 = 'a.int.at(2)
- val c4 = 'a.int.at(3)
- val c5 = 'a.int.at(4)
- val c6 = 'a.int.at(5)
+ private val equalValues1 = smallValues
+ private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
- checkEvaluation(LessThan(c1, c4), null, row)
- checkEvaluation(LessThan(c1, c2), true, row)
- checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row)
- checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row)
- checkEvaluation(
- LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row)
-
- checkEvaluation(c1 < c2, true, row)
- checkEvaluation(c1 <= c2, true, row)
- checkEvaluation(c1 > c2, false, row)
- checkEvaluation(c1 >= c2, false, row)
- checkEvaluation(c1 === c2, false, row)
- checkEvaluation(c1 !== c2, true, row)
- checkEvaluation(c4 <=> c1, false, row)
- checkEvaluation(c1 <=> c4, false, row)
- checkEvaluation(c4 <=> c6, true, row)
- checkEvaluation(c3 <=> c5, true, row)
- checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row)
- checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row)
-
- val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))
- val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-02"))
- checkEvaluation(Literal(d1) < Literal(d2), true)
-
- val ts1 = new Timestamp(12)
- val ts2 = new Timestamp(123)
- checkEvaluation(Literal("ab") < Literal("abc"), true)
- checkEvaluation(Literal(ts1) < Literal(ts2), true)
+ test("BinaryComparison: <") {
+ for (i <- 0 until smallValues.length) {
+ checkEvaluation(smallValues(i) < largeValues(i), true)
+ checkEvaluation(equalValues1(i) < equalValues2(i), false)
+ checkEvaluation(largeValues(i) < smallValues(i), false)
+ }
+ }
+
+ test("BinaryComparison: <=") {
+ for (i <- 0 until smallValues.length) {
+ checkEvaluation(smallValues(i) <= largeValues(i), true)
+ checkEvaluation(equalValues1(i) <= equalValues2(i), true)
+ checkEvaluation(largeValues(i) <= smallValues(i), false)
+ }
+ }
+
+ test("BinaryComparison: >") {
+ for (i <- 0 until smallValues.length) {
+ checkEvaluation(smallValues(i) > largeValues(i), false)
+ checkEvaluation(equalValues1(i) > equalValues2(i), false)
+ checkEvaluation(largeValues(i) > smallValues(i), true)
+ }
+ }
+
+ test("BinaryComparison: >=") {
+ for (i <- 0 until smallValues.length) {
+ checkEvaluation(smallValues(i) >= largeValues(i), false)
+ checkEvaluation(equalValues1(i) >= equalValues2(i), true)
+ checkEvaluation(largeValues(i) >= smallValues(i), true)
+ }
+ }
+
+ test("BinaryComparison: ===") {
+ for (i <- 0 until smallValues.length) {
+ checkEvaluation(smallValues(i) === largeValues(i), false)
+ checkEvaluation(equalValues1(i) === equalValues2(i), true)
+ checkEvaluation(largeValues(i) === smallValues(i), false)
+ }
+ }
+
+ test("BinaryComparison: <=>") {
+ for (i <- 0 until smallValues.length) {
+ checkEvaluation(smallValues(i) <=> largeValues(i), false)
+ checkEvaluation(equalValues1(i) <=> equalValues2(i), true)
+ checkEvaluation(largeValues(i) <=> smallValues(i), false)
+ }
+ }
+
+ test("BinaryComparison: null test") {
+ val normalInt = Literal(1)
+ val nullInt = Literal.create(null, IntegerType)
+
+ def nullTest(op: (Expression, Expression) => Expression): Unit = {
+ checkEvaluation(op(normalInt, nullInt), null)
+ checkEvaluation(op(nullInt, normalInt), null)
+ checkEvaluation(op(nullInt, nullInt), null)
+ }
+
+ nullTest(LessThan)
+ nullTest(LessThanOrEqual)
+ nullTest(GreaterThan)
+ nullTest(GreaterThanOrEqual)
+ nullTest(EqualTo)
+
+ checkEvaluation(normalInt <=> nullInt, false)
+ checkEvaluation(nullInt <=> normalInt, false)
+ checkEvaluation(nullInt <=> nullInt, true)
}
}