diff options
Diffstat (limited to 'sql')
2 files changed, 78 insertions, 47 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 34df89a163..d4569241e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -302,7 +302,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (l == null || r == null) { false } else { - l == r + if (left.dataType != BinaryType) l == r + else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 72fec3b86e..188ecef9e7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Date, Timestamp} - import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{IntegerType, BooleanType} +import org.apache.spark.sql.types.{Decimal, IntegerType, BooleanType} class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -66,12 +63,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { * Unknown Unknown */ // scalastyle:on - val notTrueTable = - (true, false) :: - (false, true) :: - (null, null) :: Nil test("3VL Not") { + val notTrueTable = + (true, false) :: + (false, true) :: + (null, null) :: Nil notTrueTable.foreach { case (v, answer) => checkEvaluation(Not(Literal.create(v, BooleanType)), answer) } @@ -126,8 +123,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val two = Literal(2) val three = Literal(3) val nl = Literal(null) - val s = Seq(one, two) - val nullS = Seq(one, two, null) checkEvaluation(InSet(one, hS), true) checkEvaluation(InSet(two, hS), true) checkEvaluation(InSet(two, nS), true) @@ -137,43 +132,78 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) } + private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) + private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_)) - test("BinaryComparison") { - val row = create_row(1, 2, 3, null, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - val c5 = 'a.int.at(4) - val c6 = 'a.int.at(5) + private val equalValues1 = smallValues + private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) - checkEvaluation(LessThan(c1, c4), null, row) - checkEvaluation(LessThan(c1, c2), true, row) - checkEvaluation(LessThan(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(LessThan(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - LessThan(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(c1 < c2, true, row) - checkEvaluation(c1 <= c2, true, row) - checkEvaluation(c1 > c2, false, row) - checkEvaluation(c1 >= c2, false, row) - checkEvaluation(c1 === c2, false, row) - checkEvaluation(c1 !== c2, true, row) - checkEvaluation(c4 <=> c1, false, row) - checkEvaluation(c1 <=> c4, false, row) - checkEvaluation(c4 <=> c6, true, row) - checkEvaluation(c3 <=> c5, true, row) - checkEvaluation(Literal(true) <=> Literal.create(null, BooleanType), false, row) - checkEvaluation(Literal.create(null, BooleanType) <=> Literal(true), false, row) - - val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")) - val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-02")) - checkEvaluation(Literal(d1) < Literal(d2), true) - - val ts1 = new Timestamp(12) - val ts2 = new Timestamp(123) - checkEvaluation(Literal("ab") < Literal("abc"), true) - checkEvaluation(Literal(ts1) < Literal(ts2), true) + test("BinaryComparison: <") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) < largeValues(i), true) + checkEvaluation(equalValues1(i) < equalValues2(i), false) + checkEvaluation(largeValues(i) < smallValues(i), false) + } + } + + test("BinaryComparison: <=") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) <= largeValues(i), true) + checkEvaluation(equalValues1(i) <= equalValues2(i), true) + checkEvaluation(largeValues(i) <= smallValues(i), false) + } + } + + test("BinaryComparison: >") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) > largeValues(i), false) + checkEvaluation(equalValues1(i) > equalValues2(i), false) + checkEvaluation(largeValues(i) > smallValues(i), true) + } + } + + test("BinaryComparison: >=") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) >= largeValues(i), false) + checkEvaluation(equalValues1(i) >= equalValues2(i), true) + checkEvaluation(largeValues(i) >= smallValues(i), true) + } + } + + test("BinaryComparison: ===") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) === largeValues(i), false) + checkEvaluation(equalValues1(i) === equalValues2(i), true) + checkEvaluation(largeValues(i) === smallValues(i), false) + } + } + + test("BinaryComparison: <=>") { + for (i <- 0 until smallValues.length) { + checkEvaluation(smallValues(i) <=> largeValues(i), false) + checkEvaluation(equalValues1(i) <=> equalValues2(i), true) + checkEvaluation(largeValues(i) <=> smallValues(i), false) + } + } + + test("BinaryComparison: null test") { + val normalInt = Literal(1) + val nullInt = Literal.create(null, IntegerType) + + def nullTest(op: (Expression, Expression) => Expression): Unit = { + checkEvaluation(op(normalInt, nullInt), null) + checkEvaluation(op(nullInt, normalInt), null) + checkEvaluation(op(nullInt, nullInt), null) + } + + nullTest(LessThan) + nullTest(LessThanOrEqual) + nullTest(GreaterThan) + nullTest(GreaterThanOrEqual) + nullTest(EqualTo) + + checkEvaluation(normalInt <=> nullInt, false) + checkEvaluation(nullInt <=> normalInt, false) + checkEvaluation(nullInt <=> nullInt, true) } } |