path: root/sql
author: Andre Schumacher <andre.schumacher@iki.fi> 2014-05-16 13:41:41 -0700
committer: Reynold Xin <rxin@apache.org> 2014-05-16 13:41:41 -0700
commit: 40d6acd6ba2feccc600301f5c47d4f90157138b1
tree: f7575f51808dd473f41329c46f3122ce71140167
parent: 032d6632ad4ab88c97c9e568b63169a114220a02
SPARK-1487 [SQL] Support record filtering via predicate pushdown in Parquet
Simple filter predicates such as LessThan, GreaterThan, etc., where one side is a Literal and the other a NamedExpression, are now pushed down to the underlying ParquetTableScan. Below are results from a microbenchmark with a simple schema of six fields of different types, where most records fail the test:

|           | Uncompressed | Compressed |
| --------- | ------------ | ---------- |
| File size | 10 GB        | 2 GB       |
| Speedup   | 2            | 1.8        |

Since mileage may vary, this patch adds a new SparkConf option: `spark.sql.hints.parquetFilterPushdown`. The default value is `true`; setting it to `false` disables the pushdown. When most rows are expected to pass the filter, or when there are only a few fields, performance can be better with pushdown disabled. The default should fit situations with a reasonable number of (possibly nested) fields where not too many records on average pass the filter.

Because of an issue with Parquet ([see here](https://github.com/Parquet/parquet-mr/issues/371)), currently only predicates on non-nullable attributes are pushed down. If one knows that a given table has no optional fields with missing values, it would make sense to allow overriding this.

Author: Andre Schumacher <andre.schumacher@iki.fi>

Closes #511 from AndreSchumacher/parquet_filter and squashes the following commits:

16bfe83 [Andre Schumacher] Removing leftovers from merge during rebase
7b304ca [Andre Schumacher] Fixing formatting
c36d5cb [Andre Schumacher] Scalastyle
3da98db [Andre Schumacher] Second round of review feedback
7a78265 [Andre Schumacher] Fixing broken formatting in ParquetFilter
a86553b [Andre Schumacher] First round of code review feedback
b0f7806 [Andre Schumacher] Optimizing imports in ParquetTestData
85fea2d [Andre Schumacher] Adding SparkConf setting to disable filter predicate pushdown
f0ad3cf [Andre Schumacher] Undoing changes not needed for this PR
210e9cb [Andre Schumacher] Adding disjunctive filter predicates
a93a588 [Andre Schumacher] Adding unit test for filtering
6d22666 [Andre Schumacher] Extending ParquetFilters
93e8192 [Andre Schumacher] First commit Parquet record filtering
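For reference, a minimal, hedged sketch of disabling the pushdown; the key string matches the `ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED` constant introduced below, while the application name is purely illustrative:

```scala
import org.apache.spark.SparkConf

// Disable Parquet filter predicate pushdown for this application.
// "spark.sql.hints.parquetFilterPushdown" is the key defined in ParquetFilters;
// the app name is a placeholder.
val conf = new SparkConf()
  .setAppName("parquet-scan-example")
  .set("spark.sql.hints.parquetFilterPushdown", "false")
```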
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala31
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala436
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala90
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala90
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala135
5 files changed, 731 insertions, 51 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f763106da4..394a59700d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -140,12 +140,35 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
- case PhysicalOperation(projectList, filters, relation: ParquetRelation) =>
- // TODO: Should be pushing down filters as well.
+ case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => {
+ val remainingFilters =
+ if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
+ filters.filter {
+ // Note: filters cannot be pushed down to Parquet if they contain more complex
+ // expressions than simple "Attribute cmp Literal" comparisons. Here we remove
+ // all filters that have been pushed down. Note that a predicate such as
+ // "(A AND B) OR C" can result in "A OR C" being pushed down.
+ filter =>
+ val recordFilter = ParquetFilters.createFilter(filter)
+ if (!recordFilter.isDefined) {
+ // First case: the pushdown did not result in any record filter.
+ true
+ } else {
+ // Second case: a record filter was created; here we are conservative in
+ // the sense that even if "A" was pushed and we check for "A AND B" we
+ // still want to keep "A AND B" in the higher-level filter, not just "B".
+ !ParquetFilters.findExpression(recordFilter.get, filter).isDefined
+ }
+ }
+ } else {
+ filters
+ }
pruneFilterProject(
projectList,
- filters,
- ParquetTableScan(_, relation, None)(sparkContext)) :: Nil
+ remainingFilters,
+ ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+ }
+
case _ => Nil
}
}
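To see the strategy's effect end to end, here is a hedged sketch along the lines of the new ParquetQuerySuite test added further below; it assumes the `TestSQLContext` and the filter test data fixtures from the test suite:

```scala
import org.apache.spark.sql.test.TestSQLContext._

// Write the filter test file and register it, as ParquetQuerySuite.beforeAll does.
ParquetTestData.writeFilterFile()
parquetFile(ParquetTestData.testFilterDir.toString).registerAsTable("testfiltersource")

// Both conjuncts are simple "attribute cmp literal" comparisons on REQUIRED columns,
// so the whole predicate is pushed down and no Filter operator remains above the scan.
val query = sql("SELECT * FROM testfiltersource WHERE myint < 150 AND myint >= 100")
assert(query.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan])
assert(query.collect().size == 50)
```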
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
new file mode 100644
index 0000000000..052b0a9196
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -0,0 +1,436 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import org.apache.hadoop.conf.Configuration
+
+import parquet.filter._
+import parquet.filter.ColumnPredicates._
+import parquet.column.ColumnReader
+
+import com.google.common.io.BaseEncoding
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+
+object ParquetFilters {
+ val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
+ // set this to false if pushdown should be disabled
+ val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown"
+
+ def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = {
+ val filters: Seq[CatalystFilter] = filterExpressions.collect {
+ case (expression: Expression) if createFilter(expression).isDefined =>
+ createFilter(expression).get
+ }
+ if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null
+ }
+
+ def createFilter(expression: Expression): Option[CatalystFilter] = {
+ def createEqualityFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case BooleanType =>
+ ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate)
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x == literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x == literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x == literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x == literal.value.asInstanceOf[Float],
+ predicate)
+ case StringType =>
+ ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate)
+ }
+ def createLessThanFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x < literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x < literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x < literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x < literal.value.asInstanceOf[Float],
+ predicate)
+ }
+ def createLessThanOrEqualFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x <= literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x <= literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x <= literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x <= literal.value.asInstanceOf[Float],
+ predicate)
+ }
+ // TODO: combine these two types somehow?
+ def createGreaterThanFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x > literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x > literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x > literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x > literal.value.asInstanceOf[Float],
+ predicate)
+ }
+ def createGreaterThanOrEqualFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name, (x: Int) => x >= literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x >= literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x >= literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x >= literal.value.asInstanceOf[Float],
+ predicate)
+ }
+
+ /**
+ * TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until
+ * https://github.com/Parquet/parquet-mr/issues/371
+ * has been resolved.
+ */
+ expression match {
+ case p @ Or(left: Expression, right: Expression)
+ if createFilter(left).isDefined && createFilter(right).isDefined => {
+ // If either side of this Or-predicate is empty then this means
+ // it contains a more complex comparison than between attribute and literal
+ // (e.g., it contained a CAST). The only safe thing to do is then to disregard
+ // this disjunction, which could be contained in a conjunction. If it stands
+ // alone then it is also safe to drop it, since a Null return value of this
+ // function is interpreted as having no filters at all.
+ val leftFilter = createFilter(left).get
+ val rightFilter = createFilter(right).get
+ Some(new OrFilter(leftFilter, rightFilter))
+ }
+ case p @ And(left: Expression, right: Expression) => {
+ // This treats nested conjunctions; since either side of the conjunction
+ // may contain more complex filter expressions we may actually generate
+ // strictly weaker filter predicates in the process.
+ val leftFilter = createFilter(left)
+ val rightFilter = createFilter(right)
+ (leftFilter, rightFilter) match {
+ case (None, Some(filter)) => Some(filter)
+ case (Some(filter), None) => Some(filter)
+ case (_, _) =>
+ Some(new AndFilter(leftFilter.get, rightFilter.get))
+ }
+ }
+ case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createEqualityFilter(right.name, left, p))
+ case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createEqualityFilter(left.name, right, p))
+ case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createLessThanFilter(right.name, left, p))
+ case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createLessThanFilter(left.name, right, p))
+ case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createLessThanOrEqualFilter(right.name, left, p))
+ case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createLessThanOrEqualFilter(left.name, right, p))
+ case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createGreaterThanFilter(right.name, left, p))
+ case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createGreaterThanFilter(left.name, right, p))
+ case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createGreaterThanOrEqualFilter(right.name, left, p))
+ case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createGreaterThanOrEqualFilter(left.name, right, p))
+ case _ => None
+ }
+ }
+
+ /**
+ * Note: Inside the Hadoop API we only have access to `Configuration`, not to
+ * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey
+ * the actual filter predicate.
+ */
+ def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = {
+ if (filters.length > 0) {
+ val serialized: Array[Byte] = SparkSqlSerializer.serialize(filters)
+ val encoded: String = BaseEncoding.base64().encode(serialized)
+ conf.set(PARQUET_FILTER_DATA, encoded)
+ }
+ }
+
+ /**
+ * Note: Inside the Hadoop API we only have access to `Configuration`, not to
+ * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey
+ * the actual filter predicate.
+ */
+ def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = {
+ val data = conf.get(PARQUET_FILTER_DATA)
+ if (data != null) {
+ val decoded: Array[Byte] = BaseEncoding.base64().decode(data)
+ SparkSqlSerializer.deserialize(decoded)
+ } else {
+ Seq()
+ }
+ }
+
+ /**
+ * Try to find the given expression in the tree of filters in order to
+ * determine whether it is safe to remove it from the higher level filters. Note
+ * that strictly speaking we could stop the search whenever an expression is found
+ * that contains this expression as subexpression (e.g., when searching for "a"
+ * and "(a or c)" is found) but we don't care about optimizations here since the
+ * filter tree is assumed to be small.
+ *
+ * @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand
+ * and search
+ * @param expression The expression to look for
+ * @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that
+ * contains the expression.
+ */
+ def findExpression(
+ filter: CatalystFilter,
+ expression: Expression): Option[CatalystFilter] = filter match {
+ case f @ OrFilter(_, leftFilter, rightFilter, _) =>
+ if (f.predicate == expression) {
+ Some(f)
+ } else {
+ val left = findExpression(leftFilter, expression)
+ if (left.isDefined) left else findExpression(rightFilter, expression)
+ }
+ case f @ AndFilter(_, leftFilter, rightFilter, _) =>
+ if (f.predicate == expression) {
+ Some(f)
+ } else {
+ val left = findExpression(leftFilter, expression)
+ if (left.isDefined) left else findExpression(rightFilter, expression)
+ }
+ case f @ ComparisonFilter(_, _, predicate) =>
+ if (predicate == expression) Some(f) else None
+ case _ => None
+ }
+}
+
+abstract private[parquet] class CatalystFilter(
+ @transient val predicate: CatalystPredicate) extends UnboundRecordFilter
+
+private[parquet] case class ComparisonFilter(
+ val columnName: String,
+ private var filter: UnboundRecordFilter,
+ @transient override val predicate: CatalystPredicate)
+ extends CatalystFilter(predicate) {
+ override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+ filter.bind(readers)
+ }
+}
+
+private[parquet] case class OrFilter(
+ private var filter: UnboundRecordFilter,
+ @transient val left: CatalystFilter,
+ @transient val right: CatalystFilter,
+ @transient override val predicate: Or)
+ extends CatalystFilter(predicate) {
+ def this(l: CatalystFilter, r: CatalystFilter) =
+ this(
+ OrRecordFilter.or(l, r),
+ l,
+ r,
+ Or(l.predicate, r.predicate))
+
+ override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+ filter.bind(readers)
+ }
+}
+
+private[parquet] case class AndFilter(
+ private var filter: UnboundRecordFilter,
+ @transient val left: CatalystFilter,
+ @transient val right: CatalystFilter,
+ @transient override val predicate: And)
+ extends CatalystFilter(predicate) {
+ def this(l: CatalystFilter, r: CatalystFilter) =
+ this(
+ AndRecordFilter.and(l, r),
+ l,
+ r,
+ And(l.predicate, r.predicate))
+
+ override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+ filter.bind(readers)
+ }
+}
+
+private[parquet] object ComparisonFilter {
+ def createBooleanFilter(
+ columnName: String,
+ value: Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToBoolean(
+ new BooleanPredicateFunction {
+ def functionToApply(input: Boolean): Boolean = input == value
+ }
+ )),
+ predicate)
+
+ def createStringFilter(
+ columnName: String,
+ value: String,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToString (
+ new ColumnPredicates.PredicateFunction[String] {
+ def functionToApply(input: String): Boolean = input == value
+ }
+ )),
+ predicate)
+
+ def createIntFilter(
+ columnName: String,
+ func: Int => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToInteger(
+ new IntegerPredicateFunction {
+ def functionToApply(input: Int) = func(input)
+ }
+ )),
+ predicate)
+
+ def createLongFilter(
+ columnName: String,
+ func: Long => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToLong(
+ new LongPredicateFunction {
+ def functionToApply(input: Long) = func(input)
+ }
+ )),
+ predicate)
+
+ def createDoubleFilter(
+ columnName: String,
+ func: Double => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToDouble(
+ new DoublePredicateFunction {
+ def functionToApply(input: Double) = func(input)
+ }
+ )),
+ predicate)
+
+ def createFloatFilter(
+ columnName: String,
+ func: Float => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToFloat(
+ new FloatPredicateFunction {
+ def functionToApply(input: Float) = func(input)
+ }
+ )),
+ predicate)
+}
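To illustrate the round trip that `serializeFilterExpressions` and `deserializeFilterExpressions` implement, a hedged sketch follows; the predicate construction mirrors the pattern in the ParquetQuerySuite tests below, and running it requires the Spark SQL classpath since `SparkSqlSerializer` is used underneath:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Equals, Literal}
import org.apache.spark.sql.catalyst.types.IntegerType
import org.apache.spark.sql.parquet.ParquetFilters

// A simple "myint = 5" predicate on a non-nullable (Parquet REQUIRED) attribute.
val attribute = AttributeReference("myint", IntegerType, nullable = false)()
val predicate = Equals(attribute, Literal(5, IntegerType))

// Driver side: serialize the expressions into the Hadoop Configuration...
val hadoopConf = new Configuration()
ParquetFilters.serializeFilterExpressions(Seq(predicate), hadoopConf)

// ...task side: FilteringParquetRowInputFormat reads them back and builds a record filter.
val restored = ParquetFilters.deserializeFilterExpressions(hadoopConf)
val recordFilter = ParquetFilters.createRecordFilter(restored)
```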
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index f825ca3c02..65ba1246fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -27,26 +27,27 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter}
-import parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat}
+import parquet.hadoop.{ParquetRecordReader, ParquetInputFormat, ParquetOutputFormat}
+import parquet.hadoop.api.ReadSupport
import parquet.hadoop.util.ContextUtil
import parquet.io.InvalidRecordException
import parquet.schema.MessageType
-import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
/**
* Parquet table scan operator. Imports the file that backs the given
- * [[ParquetRelation]] as a RDD[Row].
+ * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``.
*/
case class ParquetTableScan(
// note: output cannot be transient, see
// https://issues.apache.org/jira/browse/SPARK-1367
output: Seq[Attribute],
relation: ParquetRelation,
- columnPruningPred: Option[Expression])(
+ columnPruningPred: Seq[Expression])(
@transient val sc: SparkContext)
extends LeafNode {
@@ -62,18 +63,30 @@ case class ParquetTableScan(
for (path <- fileList if !path.getName.startsWith("_")) {
NewFileInputFormat.addInputPath(job, path)
}
+
+ // Store Parquet schema in `Configuration`
conf.set(
RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA,
ParquetTypesConverter.convertFromAttributes(output).toString)
- // TODO: think about adding record filters
- /* Comments regarding record filters: it would be nice to push down as much filtering
- to Parquet as possible. However, currently it seems we cannot pass enough information
- to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's
- ``FilteredRecordReader`` (via Configuration, for example). Simple
- filter-rows-by-column-values however should be supported.
- */
- sc.newAPIHadoopRDD(conf, classOf[ParquetInputFormat[Row]], classOf[Void], classOf[Row])
- .map(_._2)
+
+ // Store record filtering predicate in `Configuration`
+ // Note 1: the input format ignores all predicates that cannot be expressed
+ // as simple column predicate filters in Parquet. Here we just record
+ // the whole pruning predicate.
+ // Note 2: you can disable filter predicate pushdown by setting
+ // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf.
+ if (columnPruningPred.length > 0 &&
+ sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
+ ParquetFilters.serializeFilterExpressions(columnPruningPred, conf)
+ }
+
+ sc.newAPIHadoopRDD(
+ conf,
+ classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat],
+ classOf[Void],
+ classOf[Row])
+ .map(_._2)
+ .filter(_ != null) // Parquet's record filters may produce null values
}
override def otherCopyArgs = sc :: Nil
@@ -184,10 +197,19 @@ case class InsertIntoParquetTable(
override def otherCopyArgs = sc :: Nil
- // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]]
- // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2?
- // .. then we could use the default one and could use [[MutablePair]]
- // instead of ``Tuple2``
+ /**
+ * Stores the given Row RDD as a Hadoop file.
+ *
+ * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]]
+ * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses
+ * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing
+ * directory and need to determine which was the largest written file index before starting to
+ * write.
+ *
+ * @param rdd The [[org.apache.spark.rdd.RDD]] to write
+ * @param path The directory to write to.
+ * @param conf A [[org.apache.hadoop.conf.Configuration]].
+ */
private def saveAsHadoopFile(
rdd: RDD[Row],
path: String,
@@ -244,8 +266,10 @@ case class InsertIntoParquetTable(
}
}
-// TODO: this will be able to append to directories it created itself, not necessarily
-// to imported ones
+/**
+ * TODO: this will be able to append to directories it created itself, not necessarily
+ * to imported ones.
+ */
private[parquet] class AppendingParquetOutputFormat(offset: Int)
extends parquet.hadoop.ParquetOutputFormat[Row] {
// override to accept existing directories as valid output directory
@@ -262,6 +286,30 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)
}
}
+/**
+ * We extend ParquetInputFormat in order to have more control over which
+ * RecordFilter we want to use.
+ */
+private[parquet] class FilteringParquetRowInputFormat
+ extends parquet.hadoop.ParquetInputFormat[Row] with Logging {
+ override def createRecordReader(
+ inputSplit: InputSplit,
+ taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
+ val readSupport: ReadSupport[Row] = new RowReadSupport()
+
+ val filterExpressions =
+ ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext))
+ if (filterExpressions.length > 0) {
+ logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}")
+ new ParquetRecordReader[Row](
+ readSupport,
+ ParquetFilters.createRecordFilter(filterExpressions))
+ } else {
+ new ParquetRecordReader[Row](readSupport)
+ }
+ }
+}
+
private[parquet] object FileSystemHelper {
def listFiles(pathStr: String, conf: Configuration): Seq[Path] = {
val origPath = new Path(pathStr)
@@ -278,7 +326,9 @@ private[parquet] object FileSystemHelper {
fs.listStatus(path).map(_.getPath)
}
- // finds the maximum taskid in the output file names at the given path
+ /**
+ * Finds the maximum taskid in the output file names at the given path.
+ */
def findMaxTaskId(pathStr: String, conf: Configuration): Int = {
val files = FileSystemHelper.listFiles(pathStr, conf)
// filename pattern is part-r-<int>.parquet
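For context, a hedged sketch of the read path that ParquetTableScan now configures; the calls and configuration keys are taken from the hunk above, the input path uses the ParquetTestData fixture, and the pruning predicates are left empty here:

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import parquet.hadoop.util.ContextUtil
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext

// Test fixtures assumed from the SQL test suite.
ParquetTestData.writeFile()
val relation = ParquetTestData.testData
val sc = TestSQLContext.sparkContext

val job = new Job()
val conf = ContextUtil.getConfiguration(job)
NewFileInputFormat.addInputPath(job, new Path(ParquetTestData.testDir.toURI))

// The requested schema (and, when present, the serialized filter predicates) travel
// through the Configuration to FilteringParquetRowInputFormat on the task side.
conf.set(
  RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA,
  ParquetTypesConverter.convertFromAttributes(relation.output).toString)

val rows = sc.newAPIHadoopRDD(
    conf,
    classOf[FilteringParquetRowInputFormat],
    classOf[Void],
    classOf[Row])
  .map(_._2)
  .filter(_ != null) // Parquet's record filters may produce null rows
```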
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index f37976f731..46c7172985 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -19,15 +19,34 @@ package org.apache.spark.sql.parquet
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapreduce.Job
+import parquet.example.data.{GroupWriter, Group}
+import parquet.example.data.simple.SimpleGroup
import parquet.hadoop.ParquetWriter
-import parquet.hadoop.util.ContextUtil
+import parquet.hadoop.api.WriteSupport
+import parquet.hadoop.api.WriteSupport.WriteContext
+import parquet.io.api.RecordConsumer
import parquet.schema.{MessageType, MessageTypeParser}
-import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.util.Utils
+// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
+// with an empty configuration (it is after all not intended to be used in this way?)
+// and members are private so we need to make our own in order to pass the schema
+// to the writer.
+private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] {
+ var groupWriter: GroupWriter = null
+ override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+ groupWriter = new GroupWriter(recordConsumer, schema)
+ }
+ override def init(configuration: Configuration): WriteContext = {
+ new WriteContext(schema, new java.util.HashMap[String, String]())
+ }
+ override def write(record: Group) {
+ groupWriter.write(record)
+ }
+}
+
private[sql] object ParquetTestData {
val testSchema =
@@ -43,7 +62,7 @@ private[sql] object ParquetTestData {
// field names for test assertion error messages
val testSchemaFieldNames = Seq(
"myboolean:Boolean",
- "mtint:Int",
+ "myint:Int",
"mystring:String",
"mylong:Long",
"myfloat:Float",
@@ -58,6 +77,18 @@ private[sql] object ParquetTestData {
|}
""".stripMargin
+ val testFilterSchema =
+ """
+ |message myrecord {
+ |required boolean myboolean;
+ |required int32 myint;
+ |required binary mystring;
+ |required int64 mylong;
+ |required float myfloat;
+ |required double mydouble;
+ |}
+ """.stripMargin
+
// field names for test assertion error messages
val subTestSchemaFieldNames = Seq(
"myboolean:Boolean",
@@ -65,36 +96,57 @@ private[sql] object ParquetTestData {
)
val testDir = Utils.createTempDir()
+ val testFilterDir = Utils.createTempDir()
lazy val testData = new ParquetRelation(testDir.toURI.toString)
def writeFile() = {
testDir.delete
val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet"))
- val job = new Job()
- val configuration: Configuration = ContextUtil.getConfiguration(job)
val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)
+ val writeSupport = new TestGroupWriteSupport(schema)
+ val writer = new ParquetWriter[Group](path, writeSupport)
- val writeSupport = new RowWriteSupport()
- writeSupport.setSchema(schema, configuration)
- val writer = new ParquetWriter(path, writeSupport)
for(i <- 0 until 15) {
- val data = new Array[Any](6)
+ val record = new SimpleGroup(schema)
if (i % 3 == 0) {
- data.update(0, true)
+ record.add(0, true)
} else {
- data.update(0, false)
+ record.add(0, false)
}
if (i % 5 == 0) {
- data.update(1, 5)
+ record.add(1, 5)
+ }
+ record.add(2, "abc")
+ record.add(3, i.toLong << 33)
+ record.add(4, 2.5F)
+ record.add(5, 4.5D)
+ writer.write(record)
+ }
+ writer.close()
+ }
+
+ def writeFilterFile(records: Int = 200) = {
+ // for microbenchmark use: records = 300000000
+ testFilterDir.delete
+ val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet"))
+ val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema)
+ val writeSupport = new TestGroupWriteSupport(schema)
+ val writer = new ParquetWriter[Group](path, writeSupport)
+
+ for(i <- 0 to records) {
+ val record = new SimpleGroup(schema)
+ if (i % 4 == 0) {
+ record.add(0, true)
} else {
- data.update(1, null) // optional
+ record.add(0, false)
}
- data.update(2, "abc")
- data.update(3, i.toLong << 33)
- data.update(4, 2.5F)
- data.update(5, 4.5D)
- writer.write(new GenericRow(data.toArray))
+ record.add(1, i)
+ record.add(2, i.toString)
+ record.add(3, i.toLong)
+ record.add(4, i.toFloat + 0.5f)
+ record.add(5, i.toDouble + 0.5d)
+ writer.write(record)
}
writer.close()
}
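As a usage sketch for the write-support helper added above, the following writes a single record against `testFilterSchema`; the output path is a placeholder, and the field indices follow the schema's declaration order:

```scala
import org.apache.hadoop.fs.Path
import parquet.example.data.Group
import parquet.example.data.simple.SimpleGroup
import parquet.hadoop.ParquetWriter
import parquet.schema.MessageTypeParser

// Reuses the file-private TestGroupWriteSupport defined in ParquetTestData above.
val schema = MessageTypeParser.parseMessageType(ParquetTestData.testFilterSchema)
val writer = new ParquetWriter[Group](
  new Path("/tmp/parquet-example/part-r-0.parquet"), // illustrative path
  new TestGroupWriteSupport(schema))

val record = new SimpleGroup(schema)
record.add(0, true)        // myboolean
record.add(1, 42)          // myint
record.add(2, "forty-two") // mystring
record.add(3, 42L)         // mylong
record.add(4, 42.5f)       // myfloat
record.add(5, 42.5d)       // mydouble
writer.write(record)
writer.close()
```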
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index ff1677eb8a..65f4c17aee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,25 +17,25 @@
package org.apache.spark.sql.parquet
-import java.io.File
-
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.mapreduce.Job
import parquet.hadoop.ParquetFileWriter
-import parquet.schema.MessageTypeParser
import parquet.hadoop.util.ContextUtil
+import parquet.schema.MessageTypeParser
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.util.getTempFilePath
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.TestData
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.Equals
+import org.apache.spark.sql.catalyst.types.IntegerType
import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType}
-import org.apache.spark.sql.{parquet, SchemaRDD}
// Implicits
import org.apache.spark.sql.test.TestSQLContext._
@@ -64,12 +64,16 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
override def beforeAll() {
ParquetTestData.writeFile()
+ ParquetTestData.writeFilterFile()
testRDD = parquetFile(ParquetTestData.testDir.toString)
testRDD.registerAsTable("testsource")
+ parquetFile(ParquetTestData.testFilterDir.toString)
+ .registerAsTable("testfiltersource")
}
override def afterAll() {
Utils.deleteRecursively(ParquetTestData.testDir)
+ Utils.deleteRecursively(ParquetTestData.testFilterDir)
// here we should also unregister the table??
}
@@ -120,7 +124,7 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
val scanner = new ParquetTableScan(
ParquetTestData.testData.output,
ParquetTestData.testData,
- None)(TestSQLContext.sparkContext)
+ Seq())(TestSQLContext.sparkContext)
val projected = scanner.pruneColumns(ParquetTypesConverter
.convertToAttributes(MessageTypeParser
.parseMessageType(ParquetTestData.subTestSchema)))
@@ -196,7 +200,6 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
assert(true)
}
-
test("insert (appending) to same table via Scala API") {
sql("INSERT INTO testsource SELECT * FROM testsource").collect()
val double_rdd = sql("SELECT * FROM testsource").collect()
@@ -239,5 +242,121 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
Utils.deleteRecursively(file)
assert(true)
}
+
+ test("create RecordFilter for simple predicates") {
+ val attribute1 = new AttributeReference("first", IntegerType, false)()
+ val predicate1 = new Equals(attribute1, new Literal(1, IntegerType))
+ val filter1 = ParquetFilters.createFilter(predicate1)
+ assert(filter1.isDefined)
+ assert(filter1.get.predicate == predicate1, "predicates do not match")
+ assert(filter1.get.isInstanceOf[ComparisonFilter])
+ val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter]
+ assert(cmpFilter1.columnName == "first", "column name incorrect")
+
+ val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType))
+ val filter2 = ParquetFilters.createFilter(predicate2)
+ assert(filter2.isDefined)
+ assert(filter2.get.predicate == predicate2, "predicates do not match")
+ assert(filter2.get.isInstanceOf[ComparisonFilter])
+ val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter]
+ assert(cmpFilter2.columnName == "first", "column name incorrect")
+
+ val predicate3 = new And(predicate1, predicate2)
+ val filter3 = ParquetFilters.createFilter(predicate3)
+ assert(filter3.isDefined)
+ assert(filter3.get.predicate == predicate3, "predicates do not match")
+ assert(filter3.get.isInstanceOf[AndFilter])
+
+ val predicate4 = new Or(predicate1, predicate2)
+ val filter4 = ParquetFilters.createFilter(predicate4)
+ assert(filter4.isDefined)
+ assert(filter4.get.predicate == predicate4, "predicates do not match")
+ assert(filter4.get.isInstanceOf[OrFilter])
+
+ val attribute2 = new AttributeReference("second", IntegerType, false)()
+ val predicate5 = new GreaterThan(attribute1, attribute2)
+ val badfilter = ParquetFilters.createFilter(predicate5)
+ assert(badfilter.isDefined === false)
+ }
+
+ test("test filter by predicate pushdown") {
+ for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) {
+ println(s"testing field $myval")
+ val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100")
+ assert(
+ query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val result1 = query1.collect()
+ assert(result1.size === 50)
+ assert(result1(0)(1) === 100)
+ assert(result1(49)(1) === 149)
+ val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200")
+ assert(
+ query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val result2 = query2.collect()
+ assert(result2.size === 50)
+ if (myval == "myint" || myval == "mylong") {
+ assert(result2(0)(1) === 151)
+ assert(result2(49)(1) === 200)
+ } else {
+ assert(result2(0)(1) === 150)
+ assert(result2(49)(1) === 199)
+ }
+ }
+ for(myval <- Seq("myint", "mylong")) {
+ val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10")
+ assert(
+ query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val result3 = query3.collect()
+ assert(result3.size === 20)
+ assert(result3(0)(1) === 0)
+ assert(result3(9)(1) === 9)
+ assert(result3(10)(1) === 191)
+ assert(result3(19)(1) === 200)
+ }
+ for(myval <- Seq("mydouble", "myfloat")) {
+ val result4 =
+ if (myval == "mydouble") {
+ val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0")
+ assert(
+ query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ query4.collect()
+ } else {
+ // CASTs are problematic. Here myfloat will be casted to a double and it seems there is
+ // currently no way to specify float constants in SqlParser?
+ sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect()
+ }
+ assert(result4.size === 20)
+ assert(result4(0)(1) === 0)
+ assert(result4(9)(1) === 9)
+ assert(result4(10)(1) === 191)
+ assert(result4(19)(1) === 200)
+ }
+ val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40")
+ assert(
+ query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val booleanResult = query5.collect()
+ assert(booleanResult.size === 10)
+ for(i <- 0 until 10) {
+ if (!booleanResult(i).getBoolean(0)) {
+ fail(s"Boolean value in result row $i not true")
+ }
+ if (booleanResult(i).getInt(1) != i * 4) {
+ fail(s"Int value in result row $i should be ${4*i}")
+ }
+ }
+ val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"")
+ assert(
+ query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val stringResult = query6.collect()
+ assert(stringResult.size === 1)
+ assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
+ assert(stringResult(0).getInt(1) === 100)
+ }
}