-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala  | 131
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala  |   9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala  |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala  | 129
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala  |  47
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala  |  65
6 files changed, 315 insertions(+), 68 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 65859865c8..7265d6a4de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -43,7 +43,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
l,
projects,
filters,
- (a, f) => toCatalystRDD(l, a, t.buildScan(a, f))) :: Nil
+ (requestedColumns, allPredicates, _) =>
+ toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _)) =>
pruneFilterProject(
@@ -266,47 +267,81 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
relation,
projects,
filterPredicates,
- (requestedColumns, pushedFilters) => {
- scanBuilder(requestedColumns, selectFilters(pushedFilters).toArray)
+ (requestedColumns, _, pushedFilters) => {
+ scanBuilder(requestedColumns, pushedFilters.toArray)
})
}
- // Based on Catalyst expressions.
+ // Based on Catalyst expressions. The `scanBuilder` function accepts three arguments:
+ //
+ // 1. A `Seq[Attribute]`, containing all required column attributes. Used to handle relation
+ // traits that support column pruning (e.g. `PrunedScan` and `PrunedFilteredScan`).
+ //
+ // 2. A `Seq[Expression]`, containing all gathered Catalyst filter expressions, only used for
+ // `CatalystScan`.
+ //
+ // 3. A `Seq[Filter]`, containing all data source `Filter`s that are converted from (possibly a
+ // subset of) Catalyst filter expressions and can be handled by `relation`. Used to handle
+ // relation traits (`CatalystScan` excluded) that support filter push-down (e.g.
+ // `PrunedFilteredScan` and `HadoopFsRelation`).
+ //
+ // Note that 2 and 3 shouldn't be used together.
protected def pruneFilterProjectRaw(
- relation: LogicalRelation,
- projects: Seq[NamedExpression],
- filterPredicates: Seq[Expression],
- scanBuilder: (Seq[Attribute], Seq[Expression]) => RDD[InternalRow]) = {
+ relation: LogicalRelation,
+ projects: Seq[NamedExpression],
+ filterPredicates: Seq[Expression],
+ scanBuilder: (Seq[Attribute], Seq[Expression], Seq[Filter]) => RDD[InternalRow]) = {
val projectSet = AttributeSet(projects.flatMap(_.references))
val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
- val filterCondition = filterPredicates.reduceLeftOption(expressions.And)
- val pushedFilters = filterPredicates.map { _ transform {
+ val candidatePredicates = filterPredicates.map { _ transform {
case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
}}
+ val (unhandledPredicates, pushedFilters) =
+ selectFilters(relation.relation, candidatePredicates)
+
+ // A set of column attributes that are only referenced by pushed down filters. We can eliminate
+ // them from requested columns.
+ val handledSet = {
+ val handledPredicates = filterPredicates.filterNot(unhandledPredicates.contains)
+ val unhandledSet = AttributeSet(unhandledPredicates.flatMap(_.references))
+ AttributeSet(handledPredicates.flatMap(_.references)) --
+ (projectSet ++ unhandledSet).map(relation.attributeMap)
+ }
+
+ // Combines all Catalyst filter `Expression`s that are either not convertible to data source
+ // `Filter`s or cannot be handled by `relation`.
+ val filterCondition = unhandledPredicates.reduceLeftOption(expressions.And)
+
if (projects.map(_.toAttribute) == projects &&
projectSet.size == projects.size &&
filterSet.subsetOf(projectSet)) {
// When it is possible to just use column pruning to get the right projection and
// when the columns of this projection are enough to evaluate all filter conditions,
// just do a scan followed by a filter, with no extra project.
- val requestedColumns =
- projects.asInstanceOf[Seq[Attribute]] // Safe due to if above.
- .map(relation.attributeMap) // Match original case of attributes.
+ val requestedColumns = projects
+ // Safe due to if above.
+ .asInstanceOf[Seq[Attribute]]
+ // Match original case of attributes.
+ .map(relation.attributeMap)
+ // Don't request columns that are only referenced by pushed filters.
+ .filterNot(handledSet.contains)
val scan = execution.PhysicalRDD.createFromDataSource(
projects.map(_.toAttribute),
- scanBuilder(requestedColumns, pushedFilters),
+ scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation)
filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
} else {
- val requestedColumns = (projectSet ++ filterSet).map(relation.attributeMap).toSeq
+ // Don't request columns that are only referenced by pushed filters.
+ val requestedColumns =
+ (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq
val scan = execution.PhysicalRDD.createFromDataSource(
requestedColumns,
- scanBuilder(requestedColumns, pushedFilters),
+ scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation)
execution.Project(projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
}
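
To make the column-elimination in the hunk above concrete, here is a minimal, self-contained sketch of the same set arithmetic using plain attribute names in place of Catalyst `AttributeReference`s. The query shape (`SELECT c FROM t WHERE a > 1 AND udf(b)`), the table `t`, and the UDF are hypothetical, not part of the patch:

    // Hypothetical query: SELECT c FROM t WHERE a > 1 AND udf(b)
    // The source handles `a > 1`; `udf(b)` is not convertible to a data source Filter.
    object HandledSetSketch extends App {
      val projectSet    = Set("c")   // columns referenced by the projection
      val handledRefs   = Set("a")   // columns referenced by handled (pushed) predicates
      val unhandledRefs = Set("b")   // columns referenced by predicates Spark must re-evaluate

      // Mirrors `handledSet` above: drop anything the projection or unhandled predicates still need.
      val handledSet = handledRefs -- (projectSet ++ unhandledRefs)

      // Mirrors `requestedColumns`: the projection plus every column a residual filter needs.
      val requestedColumns = (projectSet ++ handledRefs ++ unhandledRefs) -- handledSet

      println(handledSet)        // Set(a): only referenced by a pushed filter
      println(requestedColumns)  // Set(c, b): `a` no longer needs to be read from the source
    }
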
@@ -334,11 +369,12 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
}
/**
- * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s,
- * and convert them.
+ * Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
+ *
+ * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
*/
- protected[sql] def selectFilters(filters: Seq[Expression]) = {
- def translate(predicate: Expression): Option[Filter] = predicate match {
+ protected[sql] def translateFilter(predicate: Expression): Option[Filter] = {
+ predicate match {
case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
Some(sources.EqualTo(a.name, convertToScala(v, t)))
case expressions.EqualTo(Literal(v, t), a: Attribute) =>
@@ -387,16 +423,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
Some(sources.IsNotNull(a.name))
case expressions.And(left, right) =>
- (translate(left) ++ translate(right)).reduceOption(sources.And)
+ (translateFilter(left) ++ translateFilter(right)).reduceOption(sources.And)
case expressions.Or(left, right) =>
for {
- leftFilter <- translate(left)
- rightFilter <- translate(right)
+ leftFilter <- translateFilter(left)
+ rightFilter <- translateFilter(right)
} yield sources.Or(leftFilter, rightFilter)
case expressions.Not(child) =>
- translate(child).map(sources.Not)
+ translateFilter(child).map(sources.Not)
case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) =>
Some(sources.StringStartsWith(a.name, v.toString))
@@ -409,7 +445,52 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
case _ => None
}
+ }
+
+ /**
+ * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s
+ * and can be handled by `relation`.
+ *
+ * @return A pair of `Seq[Expression]` and `Seq[Filter]`. The first element contains all Catalyst
+ * predicate [[Expression]]s that are either not convertible or cannot be handled by
+ * `relation`. The second element contains all converted data source [[Filter]]s that can
+ * be handled by `relation`.
+ */
+ protected[sql] def selectFilters(
+ relation: BaseRelation,
+ predicates: Seq[Expression]): (Seq[Expression], Seq[Filter]) = {
+
+ // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are
+ // called `predicate`s, while all data source filters of type `sources.Filter` are simply called
+ // `filter`s.
+
+ val translated: Seq[(Expression, Filter)] =
+ for {
+ predicate <- predicates
+ filter <- translateFilter(predicate)
+ } yield predicate -> filter
+
+ // A map from original Catalyst expressions to corresponding translated data source filters.
+ val translatedMap: Map[Expression, Filter] = translated.toMap
+
+ // Catalyst predicate expressions that cannot be translated to data source filters.
+ val unrecognizedPredicates = predicates.filterNot(translatedMap.contains)
+
+ // Data source filters that cannot be handled by `relation`
+ val unhandledFilters = relation.unhandledFilters(translatedMap.values.toArray).toSet
+
+ val (unhandled, handled) = translated.partition {
+ case (predicate, filter) =>
+ unhandledFilters.contains(filter)
+ }
+
+ // Catalyst predicate expressions that can be translated to data source filters, but cannot be
+ // handled by `relation`.
+ val (unhandledPredicates, _) = unhandled.unzip
+
+ // Translated data source filters that can be handled by `relation`
+ val (_, handledFilters) = handled.unzip
- filters.flatMap(translate)
+ (unrecognizedPredicates ++ unhandledPredicates, handledFilters)
}
}
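
As a rough stand-alone sketch of the `selectFilters` contract introduced above, the snippet below uses plain strings in place of Catalyst `Expression`s and data source `Filter`s. The three predicates mirror the query added to `FilteredScanSuite` further down (a UDF filter, an unhandled filter on `b`, and a handled `IN` filter on `c`); the stand-in `translate` and `unhandledFilters` functions are illustrative only:

    object SelectFiltersSketch extends App {
      val predicates = Seq("udf_gt3(b)", "b < 16", "c IN (...)")

      // translateFilter stand-in: UDF calls are not convertible.
      def translate(p: String): Option[String] =
        if (p.startsWith("udf")) None else Some(p)

      // unhandledFilters stand-in: this source cannot evaluate filters on column `b`.
      def unhandledFilters(filters: Seq[String]): Set[String] =
        filters.filter(_.startsWith("b")).toSet

      val translated   = predicates.flatMap(p => translate(p).map(p -> _))
      val unrecognized = predicates.filterNot(translated.toMap.contains)
      val unhandled    = unhandledFilters(translated.map(_._2))
      val (residual, pushed) = translated.partition { case (_, f) => unhandled(f) }

      // First element of the returned pair: predicates Spark must still evaluate after the scan.
      println(unrecognized ++ residual.map(_._1))  // List(udf_gt3(b), b < 16)
      // Second element: filters handed to the data source.
      println(pushed.map(_._2))                    // List(c IN (...))
    }
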
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 7a55351148..e296d631f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -233,6 +233,15 @@ abstract class BaseRelation {
* @since 1.4.0
*/
def needConversion: Boolean = true
+
+ /**
+ * Given an array of [[Filter]]s, returns an array of [[Filter]]s that this data source relation
+ * cannot handle. Spark SQL will apply all returned [[Filter]]s against rows returned by this
+ * data source relation.
+ *
+ * @since 1.6.0
+ */
+ def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
}
/**
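
As a usage illustration of the new `unhandledFilters` hook, a hypothetical relation (not part of this patch) that natively evaluates only equality filters on an `id` column might override it as sketched below; Spark SQL then re-applies every filter the relation reports back:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.sql.sources._
    import org.apache.spark.sql.types._

    // Hypothetical example relation: only EqualTo("id", _) is evaluated natively.
    class KeyValueRelation(val sqlContext: SQLContext)
      extends BaseRelation with PrunedFilteredScan {

      override def schema: StructType = StructType(Seq(
        StructField("id", IntegerType),
        StructField("payload", StringType)))

      override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
        filters.filterNot {
          case EqualTo("id", _) => true   // handled by the source itself
          case _                => false  // reported back; Spark applies it after the scan
        }

      // Scan implementation elided; it would honor the filters kept out of the array above.
      override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = ???
    }
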
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index f88ddc77a6..c24c9f025d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -59,7 +59,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}.flatten
assert(analyzedPredicate.nonEmpty)
- val selectedFilters = DataSourceStrategy.selectFilters(analyzedPredicate)
+ val selectedFilters = analyzedPredicate.flatMap(DataSourceStrategy.translateFilter)
assert(selectedFilters.nonEmpty)
selectedFilters.foreach { pred =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 68ce37c000..7541e72302 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.sources
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+
import scala.language.existentials
import org.apache.spark.rdd.RDD
@@ -44,16 +46,39 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
StructField("b", IntegerType, nullable = false) ::
StructField("c", StringType, nullable = false) :: Nil)
+ override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+ def unhandled(filter: Filter): Boolean = {
+ filter match {
+ case EqualTo(col, v) => col == "b"
+ case EqualNullSafe(col, v) => col == "b"
+ case LessThan(col, v: Int) => col == "b"
+ case LessThanOrEqual(col, v: Int) => col == "b"
+ case GreaterThan(col, v: Int) => col == "b"
+ case GreaterThanOrEqual(col, v: Int) => col == "b"
+ case In(col, values) => col == "b"
+ case IsNull(col) => col == "b"
+ case IsNotNull(col) => col == "b"
+ case Not(pred) => unhandled(pred)
+ case And(left, right) => unhandled(left) || unhandled(right)
+ case Or(left, right) => unhandled(left) || unhandled(right)
+ case _ => false
+ }
+ }
+
+ filters.filter(unhandled)
+ }
+
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val rowBuilders = requiredColumns.map {
case "a" => (i: Int) => Seq(i)
case "b" => (i: Int) => Seq(i * 2)
case "c" => (i: Int) =>
val c = (i - 1 + 'a').toChar.toString
- Seq(c * 5 + c.toUpperCase() * 5)
+ Seq(c * 5 + c.toUpperCase * 5)
}
FiltersPushed.list = filters
+ ColumnsRequired.set = requiredColumns.toSet
// Predicate test on integer column
def translateFilterOnA(filter: Filter): Int => Boolean = filter match {
@@ -86,9 +111,8 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
}
def eval(a: Int) = {
- val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase() * 5
- !filters.map(translateFilterOnA(_)(a)).contains(false) &&
- !filters.map(translateFilterOnC(_)(c)).contains(false)
+ val c = (a - 1 + 'a').toChar.toString * 5 + (a - 1 + 'a').toChar.toString.toUpperCase * 5
+ filters.forall(translateFilterOnA(_)(a)) && filters.forall(translateFilterOnC(_)(c))
}
sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i =>
@@ -101,6 +125,11 @@ object FiltersPushed {
var list: Seq[Filter] = Nil
}
+// Used together with `SimpleFilteredScan` to check requested columns.
+object ColumnsRequired {
+ var set: Set[String] = Set.empty
+}
+
class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
protected override lazy val sql = caseInsensitiveContext.sql _
@@ -115,12 +144,15 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
| to '10'
|)
""".stripMargin)
+
+ // UDF for testing filter push-down
+ caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3)
}
sqlTest(
"SELECT * FROM oneToTenFiltered",
(1 to 10).map(i => Row(i, i * 2, (i - 1 + 'a').toChar.toString * 5
- + (i - 1 + 'a').toChar.toString.toUpperCase() * 5)).toSeq)
+ + (i - 1 + 'a').toChar.toString.toUpperCase * 5)).toSeq)
sqlTest(
"SELECT a, b FROM oneToTenFiltered",
@@ -202,49 +234,64 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
"SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'",
Seq(Row(5, 5 * 2, "e" * 5 + "E" * 5)))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1, Set("a", "b", "c"))
+ testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1, Set("a"))
+ testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1, Set("b"))
+ testPushDown("SELECT a, b FROM oneToTenFiltered WHERE A = 1", 1, Set("a", "b"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 1", 1, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 = a", 1, Set("a", "b", "c"))
+
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1", 9)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a >= 2", 9)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 < a", 9)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 <= a", 9)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 1 > a", 0)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE 2 >= a", 2)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 1", 0)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a <= 2", 2)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a > 1 AND a < 10", 8)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a IN (1,3,5)", 3)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4, Set("a", "b", "c"))
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5, Set("a", "b", "c"))
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4)
- testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5)
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1, Set("a", "b", "c"))
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0, Set("a", "b", "c"))
- testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'c%'", 1)
- testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like 'C%'", 0)
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1, Set("a", "b", "c"))
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0, Set("a", "b", "c"))
- testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%D'", 1)
- testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%d'", 0)
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1, Set("a", "b", "c"))
+ testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0, Set("a", "b", "c"))
- testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
- testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)
+ testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c"))
+ testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c"))
- testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
- testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1)
+ // Columns only referenced by the UDF filter must still be required, as UDF filters can't be pushed down.
+ testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c"))
- def testPushDown(sqlString: String, expectedCount: Int): Unit = {
+ // A query with an unconvertible filter, an unhandled filter, and a handled filter.
+ testPushDown(
+ """SELECT a
+ | FROM oneToTenFiltered
+ | WHERE udf_gt3(b)
+ | AND b < 16
+ | AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo')
+ """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b"))
+
+ def testPushDown(
+ sqlString: String,
+ expectedCount: Int,
+ requiredColumnNames: Set[String]): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution
val rawPlan = queryExecution.executedPlan.collect {
@@ -254,6 +301,17 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
}
val rawCount = rawPlan.execute().count()
+ assert(ColumnsRequired.set === requiredColumnNames)
+
+ assert {
+ val table = caseInsensitiveContext.table("oneToTenFiltered")
+ val relation = table.queryExecution.logical.collectFirst {
+ case LogicalRelation(r, _) => r
+ }.get
+
+ // `relation` should be able to handle all pushed filters
+ relation.unhandledFilters(FiltersPushed.list.toArray).isEmpty
+ }
if (rawCount != expectedCount) {
fail(
@@ -264,4 +322,3 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
}
}
}
-
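
For readers of the `SimpleFilteredScan.unhandledFilters` override above, here is a condensed, runnable restatement of its `unhandled` helper for a few leaf cases (illustrative only): a compound filter is reported back to Spark as soon as any leaf touches column `b`.

    import org.apache.spark.sql.sources._

    object UnhandledSketch extends App {
      // Condensed version of SimpleFilteredScan.unhandledFilters: anything touching "b" is unhandled.
      def unhandled(filter: Filter): Boolean = filter match {
        case EqualTo(col, _)  => col == "b"
        case Not(pred)        => unhandled(pred)
        case And(left, right) => unhandled(left) || unhandled(right)
        case Or(left, right)  => unhandled(left) || unhandled(right)
        case _                => false
      }

      println(unhandled(EqualTo("a", 1)))                        // false: fully handled by the source
      println(unhandled(And(EqualTo("a", 1), EqualTo("b", 2))))  // true: Spark must filter again
    }
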
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index a3a124488d..d945408341 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -18,11 +18,16 @@
package org.apache.spark.sql.sources
import org.apache.hadoop.fs.Path
-
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.execution.PhysicalRDD
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
+ import testImplicits._
+
override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName
// We have a very limited number of supported types at here since it is just for a
@@ -64,4 +69,44 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
.load(file.getCanonicalPath))
}
}
+
+ private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName)
+ private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName)
+
+ test("unhandledFilters") {
+ withTempPath { dir =>
+
+ val path = dir.getCanonicalPath
+ writer.save(s"$path/p=0")
+ writer.save(s"$path/p=1")
+
+ val isOdd = udf((_: Int) % 2 == 1)
+ val df = reader.load(path)
+ .filter(
+ // This filter is inconvertible
+ isOdd('a) &&
+ // This filter is convertible but unhandled
+ 'a > 1 &&
+ // This filter is convertible and handled
+ 'b > "val_1" &&
+ // This filter references a partition column and won't be pushed down
+ 'p === 1
+ ).select('a, 'p)
+ val rawScan = df.queryExecution.executedPlan collect {
+ case p: PhysicalRDD => p
+ } match {
+ case Seq(p) => p
+ }
+
+ val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType)
+
+ assertResult(Set((2, 1), (3, 1))) {
+ rawScan.execute().collect()
+ .map { CatalystTypeConverters.convertToScala(_, outputSchema) }
+ .map { case Row(a, p) => (a, p) }.toSet
+ }
+
+ checkAnswer(df, Row(3, 1))
+ }
+ }
}
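
To spell out why the test above expects raw rows (2, 1) and (3, 1) but a final answer of Row(3, 1), the sketch below simulates the two filtering stages on a local collection. It assumes the shared `testDF` (defined outside this diff in `HadoopFsRelationTest`) holds the rows (1, "val_1") through (3, "val_3"); that assumption is not confirmed by the patch itself.

    object TwoStageFilterSketch extends App {
      // Assumption: testDF holds (1, "val_1"), (2, "val_2"), (3, "val_3"),
      // and the same rows were written under both p=0 and p=1.
      val partitionP1 = Seq((1, "val_1"), (2, "val_2"), (3, "val_3"))

      // Stage 1, the raw scan: partition pruning (p = 1) plus the only handled filter, 'b > "val_1".
      val rawScan = partitionP1.filter { case (_, b) => b > "val_1" }
      println(rawScan.map(_._1))   // List(2, 3): matches the expected raw rows (2, 1) and (3, 1)

      // Stage 2, applied by Spark after the scan: isOdd('a) and 'a > 1.
      val finalRows = rawScan.filter { case (a, _) => a % 2 == 1 && a > 1 }
      println(finalRows.map(_._1)) // List(3): matches checkAnswer(df, Row(3, 1))
    }
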
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index aeaaa3e1c5..da09e1b00a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.sources
import java.text.NumberFormat
-import java.util.UUID
import com.google.common.base.Objects
import org.apache.hadoop.fs.{FileStatus, Path}
@@ -26,12 +25,12 @@ import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.spark.rdd.RDD
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions}
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext, sources}
/**
* A simple example [[HadoopFsRelationProvider]].
@@ -124,6 +123,53 @@ class SimpleTextRelation(
}
}
+ override def buildScan(
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ inputFiles: Array[FileStatus]): RDD[Row] = {
+
+ val fields = this.dataSchema.map(_.dataType)
+ val inputAttributes = this.dataSchema.toAttributes
+ val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name))
+ val dataSchema = this.dataSchema
+
+ val inputPaths = inputFiles.map(_.getPath).mkString(",")
+ sparkContext.textFile(inputPaths).mapPartitions { iterator =>
+ // Constructs a filter predicate to simulate filter push-down
+ val predicate = {
+ val filterCondition: Expression = filters.collect {
+ // According to `unhandledFilters`, `SimpleTextRelation` only handles the `GreaterThan` filter
+ case sources.GreaterThan(column, value) =>
+ val dataType = dataSchema(column).dataType
+ val literal = Literal.create(value, dataType)
+ val attribute = inputAttributes.find(_.name == column).get
+ expressions.GreaterThan(attribute, literal)
+ }.reduceOption(expressions.And).getOrElse(Literal(true))
+ InterpretedPredicate.create(filterCondition, inputAttributes)
+ }
+
+ // Uses a simple projection to simulate column pruning
+ val projection = new InterpretedMutableProjection(outputAttributes, inputAttributes)
+ val toScala = {
+ val requiredSchema = StructType.fromAttributes(outputAttributes)
+ CatalystTypeConverters.createToScalaConverter(requiredSchema)
+ }
+
+ iterator.map { record =>
+ new GenericInternalRow(record.split(",", -1).zip(fields).map {
+ case (v, dataType) =>
+ val value = if (v == "") null else v
+ // `Cast`ed values are always of internal types (e.g. UTF8String instead of String)
+ Cast(Literal(value), dataType).eval()
+ })
+ }.filter { row =>
+ predicate(row)
+ }.map { row =>
+ toScala(projection(row)).asInstanceOf[Row]
+ }
+ }
+ }
+
override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory {
job.setOutputFormatClass(classOf[TextOutputFormat[_, _]])
@@ -134,6 +180,15 @@ class SimpleTextRelation(
new SimpleTextOutputWriter(path, context)
}
}
+
+ // `SimpleTextRelation` only handles the `GreaterThan` filter. This is used to test filter push-down
+ // and `BaseRelation.unhandledFilters()`.
+ override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
+ filters.filter {
+ case _: GreaterThan => false
+ case _ => true
+ }
+ }
}
/**