-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala     4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala  23
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 183e4947b6..67a410f539 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -79,6 +79,10 @@ private[sql] case class InMemoryTableScanExec(
case IsNull(a: Attribute) => statsFor(a).nullCount > 0
case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
+
+ case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) =>
+ list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] &&
+ l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _)
}
val partitionFilters: Seq[Expression] = {
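
The new `case In(...)` arm extends the stats-based batch filter: a cached batch is kept only if at least one literal in the IN list can fall inside that batch's [lowerBound, upperBound] range, OR-reducing one bounds check per literal. Below is a minimal standalone sketch of that logic; `BatchStats` and `keepBatchForIn` are illustrative stand-ins for Spark's per-batch column statistics and Catalyst expressions, not actual Spark APIs.

    // Simplified model of the pruning rule above, outside Spark's expression machinery.
    // BatchStats stands in for the per-batch column statistics (lowerBound/upperBound).
    case class BatchStats(lowerBound: Int, upperBound: Int)

    // A batch survives `col IN (v1, v2, ...)` if at least one value can lie within the
    // batch's bounds -- the same OR-reduction the patch builds from Catalyst expressions.
    def keepBatchForIn(stats: BatchStats, inList: Seq[Int]): Boolean =
      inList.map(v => stats.lowerBound <= v && v <= stats.upperBound).reduce(_ || _)

    // Example: a batch holding keys 1..10 survives IN (1, 21, 41) but is pruned for IN (15, 16).
    assert(keepBatchForIn(BatchStats(1, 10), Seq(1, 21, 41)))
    assert(!keepBatchForIn(BatchStats(1, 10), Seq(15, 16)))
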
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 7ca8e047f0..b99cd67a63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -65,11 +65,18 @@ class PartitionBatchPruningSuite
}, 5).toDF()
pruningData.createOrReplaceTempView("pruningData")
spark.catalog.cacheTable("pruningData")
+
+ val pruningStringData = sparkContext.makeRDD((100 to 200).map { key =>
+ StringData(key.toString)
+ }, 5).toDF()
+ pruningStringData.createOrReplaceTempView("pruningStringData")
+ spark.catalog.cacheTable("pruningStringData")
}
override protected def afterEach(): Unit = {
try {
spark.catalog.uncacheTable("pruningData")
+ spark.catalog.uncacheTable("pruningStringData")
} finally {
super.afterEach()
}
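
The new `pruningStringData` table caches the keys 100..200 as strings, so the second hunk below can exercise pruning against string column stats, where bounds compare lexicographically rather than numerically. A quick plain-Scala illustration (nothing Spark-specific) of why the values in the `IN ('99', '150', '201')` test prune down to a single batch:

    // String column stats order values lexicographically:
    println("99" > "200")   // true  -- '9' > '2', so "99" is above every batch's upperBound
    println("201" > "200")  // true  -- above the global maximum key as well
    println("100" <= "150" && "150" <= "200")  // true -- only the batch containing "150" survives
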
@@ -110,9 +117,23 @@ class PartitionBatchPruningSuite
88 to 100
}
- // With unsupported predicate
+ // Support `IN` predicate
+ checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1)", 1, 1)(Seq(1))
+ checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 2)", 1, 1)(Seq(1, 2))
+ checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 11)", 1, 2)(Seq(1, 11))
+ checkBatchPruning("SELECT key FROM pruningData WHERE key IN (1, 21, 41, 61, 81)", 5, 5)(
+ Seq(1, 21, 41, 61, 81))
+ checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s = '100'", 1, 1)(Seq(100))
+ checkBatchPruning("SELECT CAST(s AS INT) FROM pruningStringData WHERE s < '102'", 1, 1)(
+ Seq(100, 101))
+ checkBatchPruning(
+ "SELECT CAST(s AS INT) FROM pruningStringData WHERE s IN ('99', '150', '201')", 1, 1)(
+ Seq(150))
+
+ // With unsupported `InSet` predicate
{
val seq = (1 to 30).mkString(", ")
+ checkBatchPruning(s"SELECT key FROM pruningData WHERE key IN ($seq)", 5, 10)(1 to 30)
checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq))", 5, 10)(31 to 100)
checkBatchPruning(s"SELECT key FROM pruningData WHERE NOT (key IN ($seq)) AND key > 88", 1, 2) {
89 to 100
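
Note on the `InSet` block: once the IN list grows past `spark.sql.optimizer.inSetConversionThreshold` (10 by default), Catalyst's OptimizeIn rule rewrites `In` into `InSet`, which the new filter case does not match, so the 30-value queries above still read all 10 batches. As a rough sanity check on the supported path: the suite caches keys 1 to 100 with a column batch size of 10 across 5 partitions, so `key IN (1, 21, 41, 61, 81)` should leave exactly 5 batches unpruned. A hypothetical standalone sketch of that arithmetic (not using the test harness):

    // Keys 1..100 split into 10 batches of 10; count the batches whose [min, max]
    // range contains at least one value from the IN list.
    val inList = Seq(1, 21, 41, 61, 81)
    val batchRanges = (1 to 100).grouped(10).map(b => (b.min, b.max)).toSeq
    val survivingBatches = batchRanges.count { case (lo, hi) => inList.exists(v => lo <= v && v <= hi) }
    println(survivingBatches)  // 5 -- matching the (5 partitions, 5 batches) expectation above
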