aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala2
2 files changed, 7 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b86825902a..b87016d5a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -65,6 +65,11 @@ case class InMemoryTableScanExec(
case EqualTo(l: Literal, a: AttributeReference) =>
statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+ case EqualNullSafe(a: AttributeReference, l: Literal) =>
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+ case EqualNullSafe(l: Literal, a: AttributeReference) =>
+ statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+
case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index b99cd67a63..9d862cfdec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -85,6 +85,8 @@ class PartitionBatchPruningSuite
// Comparisons
checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1))
checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1))
+ checkBatchPruning("SELECT key FROM pruningData WHERE key <=> 1", 1, 1)(Seq(1))
+ checkBatchPruning("SELECT key FROM pruningData WHERE 1 <=> key", 1, 1)(Seq(1))
checkBatchPruning("SELECT key FROM pruningData WHERE key < 12", 1, 2)(1 to 11)
checkBatchPruning("SELECT key FROM pruningData WHERE key <= 11", 1, 2)(1 to 11)
checkBatchPruning("SELECT key FROM pruningData WHERE key > 88", 1, 2)(89 to 100)