author     Cheng Lian <lian@databricks.com>  2015-02-16 12:48:55 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-16 12:48:55 -0800
commit     6f54dee66100e5e58f6649158db257eb5009bd6a (patch)
tree       837a9fde3122cd1e7f26d485b43217db4a8ec7dd /sql
parent     b4d7c7032d755de42951f92d9535287ef6230b9b (diff)
[SPARK-5296] [SQL] Add more filter types for data sources API
This PR adds the following filter types for the data sources API:

- `IsNull`
- `IsNotNull`
- `Not`
- `And`
- `Or`

The code that converts Catalyst predicate expressions into data source filters is very similar to the filter conversion logic in `ParquetFilters`, which converts Catalyst predicates into Parquet filter predicates. This way we can support nested AND/OR/NOT predicates without changing the current `BaseScan` type hierarchy.

Author: Cheng Lian <lian@databricks.com>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <michael@databricks.com>

Closes #4623 from liancheng/more-fiters and squashes the following commits:

1b296f4 [Cheng Lian] Add more filter types for data sources API
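For context, here is a minimal, self-contained Scala sketch (not part of this patch) of how composite filters of this shape can be evaluated recursively against a single Int column. The case classes below are simplified stand-ins for the `org.apache.spark.sql.sources.Filter` hierarchy extended in filters.scala further down, and the `toPredicate` helper is a hypothetical analogue of the `translateFilter` function added to `FilteredScanSuite`:

    // Simplified stand-ins for the data source Filter hierarchy added by this patch.
    sealed trait Filter
    case class EqualTo(attribute: String, value: Any) extends Filter
    case class LessThan(attribute: String, value: Any) extends Filter
    case class GreaterThan(attribute: String, value: Any) extends Filter
    case class IsNull(attribute: String) extends Filter
    case class IsNotNull(attribute: String) extends Filter
    case class And(left: Filter, right: Filter) extends Filter
    case class Or(left: Filter, right: Filter) extends Filter
    case class Not(child: Filter) extends Filter

    object FilterEvalSketch {
      // Recursively turn a filter tree into a predicate over an Int column "a",
      // in the same spirit as the translateFilter helper in FilteredScanSuite.
      def toPredicate(filter: Filter): Int => Boolean = filter match {
        case EqualTo("a", v: Int)     => a => a == v
        case LessThan("a", v: Int)    => a => a < v
        case GreaterThan("a", v: Int) => a => a > v
        case IsNull("a")              => _ => false  // a primitive Int column is never null
        case IsNotNull("a")           => _ => true
        case And(left, right)         => a => toPredicate(left)(a) && toPredicate(right)(a)
        case Or(left, right)          => a => toPredicate(left)(a) || toPredicate(right)(a)
        case Not(child)               => a => !toPredicate(child)(a)
        case _                        => _ => true   // unknown filters are simply not applied
      }

      def main(args: Array[String]): Unit = {
        // a < 3 OR a > 8, as in one of the new test cases below.
        val filter = Or(LessThan("a", 3), GreaterThan("a", 8))
        val kept = (1 to 10).filter(toPredicate(filter))
        println(kept.mkString(", "))  // 1, 2, 9, 10
      }
    }

Running the sketch keeps 1, 2, 9 and 10 out of 1 to 10, matching the new `a < 3 OR a > 8` test case added to `FilteredScanSuite` in this patch.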
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala              9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala      5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala  81
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala         5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala   34
5 files changed, 103 insertions, 31 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b42a52ebd2..1442250569 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -28,16 +28,16 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.catalyst.{ScalaReflection, expressions}
+import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.json._
-import org.apache.spark.sql.sources.{BaseRelation, DDLParser, DataSourceStrategy, LogicalRelation, _}
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}
@@ -867,7 +867,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
val projectSet = AttributeSet(projectList.flatMap(_.references))
val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
- val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And)
+ val filterCondition =
+ prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And)
// Right now we still use a projection even if the only evaluation is applying an alias
// to a column. Since this is a no-op, it could be avoided. However, using this
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 9279f5a903..9bb34e2df9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext}
+
import parquet.filter2.predicate.FilterApi
import parquet.format.converter.ParquetMetadataConverter
import parquet.hadoop.metadata.CompressionCodecName
@@ -42,6 +43,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD}
+import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.parquet.ParquetTypesConverter._
import org.apache.spark.sql.sources._
@@ -497,7 +499,8 @@ case class ParquetRelation2(
_.references.map(_.name).toSet.subsetOf(partitionColumnNames)
}
- val rawPredicate = partitionPruningPredicates.reduceOption(And).getOrElse(Literal(true))
+ val rawPredicate =
+ partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true))
val boundPredicate = InterpretedPredicate(rawPredicate transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 624369afe8..a853385fda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.{Row, Strategy, execution}
+import org.apache.spark.sql.{Row, Strategy, execution, sources}
/**
* A Strategy for planning scans over data sources defined using the sources API.
@@ -88,7 +88,7 @@ private[sql] object DataSourceStrategy extends Strategy {
val projectSet = AttributeSet(projectList.flatMap(_.references))
val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
- val filterCondition = filterPredicates.reduceLeftOption(And)
+ val filterCondition = filterPredicates.reduceLeftOption(expressions.And)
val pushedFilters = filterPredicates.map { _ transform {
case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
@@ -118,27 +118,60 @@ private[sql] object DataSourceStrategy extends Strategy {
}
}
- /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */
- protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect {
- case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v)
- case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v)
-
- case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v)
- case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v)
-
- case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v)
- case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
-
- case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
- GreaterThanOrEqual(a.name, v)
- case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
- LessThanOrEqual(a.name, v)
-
- case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
- LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
- GreaterThanOrEqual(a.name, v)
+ /**
+ * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s,
+ * and convert them.
+ */
+ protected[sql] def selectFilters(filters: Seq[Expression]) = {
+ def translate(predicate: Expression): Option[Filter] = predicate match {
+ case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
+ Some(sources.EqualTo(a.name, v))
+ case expressions.EqualTo(Literal(v, _), a: Attribute) =>
+ Some(sources.EqualTo(a.name, v))
+
+ case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
+ Some(sources.GreaterThan(a.name, v))
+ case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
+ Some(sources.LessThan(a.name, v))
+
+ case expressions.LessThan(a: Attribute, Literal(v, _)) =>
+ Some(sources.LessThan(a.name, v))
+ case expressions.LessThan(Literal(v, _), a: Attribute) =>
+ Some(sources.GreaterThan(a.name, v))
+
+ case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
+ Some(sources.GreaterThanOrEqual(a.name, v))
+ case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
+ Some(sources.LessThanOrEqual(a.name, v))
+
+ case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
+ Some(sources.LessThanOrEqual(a.name, v))
+ case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
+ Some(sources.GreaterThanOrEqual(a.name, v))
+
+ case expressions.InSet(a: Attribute, set) =>
+ Some(sources.In(a.name, set.toArray))
+
+ case expressions.IsNull(a: Attribute) =>
+ Some(sources.IsNull(a.name))
+ case expressions.IsNotNull(a: Attribute) =>
+ Some(sources.IsNotNull(a.name))
+
+ case expressions.And(left, right) =>
+ (translate(left) ++ translate(right)).reduceOption(sources.And)
+
+ case expressions.Or(left, right) =>
+ for {
+ leftFilter <- translate(left)
+ rightFilter <- translate(right)
+ } yield sources.Or(leftFilter, rightFilter)
+
+ case expressions.Not(child) =>
+ translate(child).map(sources.Not)
+
+ case _ => None
+ }
- case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
+ filters.flatMap(translate)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 4a9fefc12b..1e4505e36d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -25,3 +25,8 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
case class LessThan(attribute: String, value: Any) extends Filter
case class LessThanOrEqual(attribute: String, value: Any) extends Filter
case class In(attribute: String, values: Array[Any]) extends Filter
+case class IsNull(attribute: String) extends Filter
+case class IsNotNull(attribute: String) extends Filter
+case class And(left: Filter, right: Filter) extends Filter
+case class Or(left: Filter, right: Filter) extends Filter
+case class Not(child: Filter) extends Filter
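As a rough, hypothetical illustration (not part of the patch) of how the recursive translate function in DataSourceStrategy.scala above composes these new case classes: a Catalyst predicate of the shape NOT (a <= 10 OR b IS NULL) would be converted leaf by leaf and then recombined into a nested filter value, assuming a build of spark-sql that includes this patch:

    import org.apache.spark.sql.sources

    object TranslateSketch {
      // Hypothetical result of translate() for the Catalyst predicate
      // NOT (a <= 10 OR b IS NULL): each convertible leaf is translated,
      // then recombined using the new composite filter types.
      val pushed: sources.Filter =
        sources.Not(
          sources.Or(
            sources.LessThanOrEqual("a", 10),
            sources.IsNull("b")))
    }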
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 390538d35a..41cd35683c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -47,16 +47,22 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
FiltersPushed.list = filters
- val filterFunctions = filters.collect {
+ def translateFilter(filter: Filter): Int => Boolean = filter match {
case EqualTo("a", v) => (a: Int) => a == v
case LessThan("a", v: Int) => (a: Int) => a < v
case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
case GreaterThan("a", v: Int) => (a: Int) => a > v
case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
+ case IsNull("a") => (a: Int) => false // Int can't be null
+ case IsNotNull("a") => (a: Int) => true
+ case Not(pred) => (a: Int) => !translateFilter(pred)(a)
+ case And(left, right) => (a: Int) => translateFilter(left)(a) && translateFilter(right)(a)
+ case Or(left, right) => (a: Int) => translateFilter(left)(a) || translateFilter(right)(a)
+ case _ => (a: Int) => true
}
- def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
+ def eval(a: Int) = !filters.map(translateFilter(_)(a)).contains(false)
sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i =>
Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
@@ -136,6 +142,26 @@ class FilteredScanSuite extends DataSourceTest {
"SELECT * FROM oneToTenFiltered WHERE b = 2",
Seq(1).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a IS NULL",
+ Seq.empty[Row])
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a IS NOT NULL",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1",
+ (2 to 4).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8",
+ Seq(1, 2, 9, 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)",
+ (6 to 10).map(i => Row(i, i * 2)).toSeq)
+
testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
@@ -162,6 +188,10 @@ class FilteredScanSuite extends DataSourceTest {
testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5)
+
def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution