Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala      |  31
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala         | 436
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala |  90
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala        |  90
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala      | 135
5 files changed, 731 insertions, 51 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f763106da4..394a59700d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -140,12 +140,35 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil
case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
- case PhysicalOperation(projectList, filters, relation: ParquetRelation) =>
- // TODO: Should be pushing down filters as well.
+ case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => {
+ val remainingFilters =
+ if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
+ filters.filter {
+ // Note: filters cannot be pushed down to Parquet if they contain more complex
+ // expressions than simple "Attribute cmp Literal" comparisons. Here we remove
+ // all filters that have been pushed down. Note that a predicate such as
+ // "(A AND B) OR C" can result in "A OR C" being pushed down.
+ filter =>
+ val recordFilter = ParquetFilters.createFilter(filter)
+ if (!recordFilter.isDefined) {
+ // First case: the pushdown did not result in any record filter.
+ true
+ } else {
+ // Second case: a record filter was created; here we are conservative in
+ // the sense that even if "A" was pushed and we check for "A AND B" we
+ // still want to keep "A AND B" in the higher-level filter, not just "B".
+ !ParquetFilters.findExpression(recordFilter.get, filter).isDefined
+ }
+ }
+ } else {
+ filters
+ }
pruneFilterProject(
projectList,
- filters,
- ParquetTableScan(_, relation, None)(sparkContext)) :: Nil
+ remainingFilters,
+ ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+ }
+
case _ => Nil
}
}
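
Editor's note on the strategy change above: the new branch consults the "spark.sql.hints.parquetFilterPushdown" flag read from SparkConf (default true). The standalone sketch below is not part of the patch; the table name, the path and the object name are placeholders, and it only illustrates how a user of this Spark version would disable the pushdown before running a query whose simple "attribute cmp literal" comparisons would otherwise be handed to Parquet.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object PushdownFlagSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("parquet-pushdown-sketch")
      .setMaster("local")
      // Default is true; "false" keeps all filtering in the Spark-side Filter operator.
      .set("spark.sql.hints.parquetFilterPushdown", "false")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Placeholder path and table name. With the flag at its default, the simple
    // "myint < 100" comparison would be pushed into ParquetTableScan; with the flag
    // disabled it stays in the Filter operator above the scan.
    sqlContext.parquetFile("/tmp/myParquetTable").registerAsTable("myParquetTable")
    sqlContext.sql("SELECT * FROM myParquetTable WHERE myint < 100").collect()
  }
}
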
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
new file mode 100644
index 0000000000..052b0a9196
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -0,0 +1,436 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parquet
+
+import org.apache.hadoop.conf.Configuration
+
+import parquet.filter._
+import parquet.filter.ColumnPredicates._
+import parquet.column.ColumnReader
+
+import com.google.common.io.BaseEncoding
+
+import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+
+object ParquetFilters {
+ val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
+ // set this to false if pushdown should be disabled
+ val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown"
+
+ def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = {
+ val filters: Seq[CatalystFilter] = filterExpressions.collect {
+ case (expression: Expression) if createFilter(expression).isDefined =>
+ createFilter(expression).get
+ }
+ if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null
+ }
+
+ def createFilter(expression: Expression): Option[CatalystFilter] = {
+ def createEqualityFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case BooleanType =>
+ ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate)
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x == literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x == literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x == literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x == literal.value.asInstanceOf[Float],
+ predicate)
+ case StringType =>
+ ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate)
+ }
+ def createLessThanFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x < literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x < literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x < literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x < literal.value.asInstanceOf[Float],
+ predicate)
+ }
+ def createLessThanOrEqualFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x <= literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x <= literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x <= literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x <= literal.value.asInstanceOf[Float],
+ predicate)
+ }
+ // TODO: combine these two types somehow?
+ def createGreaterThanFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name,
+ (x: Int) => x > literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x > literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x > literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x > literal.value.asInstanceOf[Float],
+ predicate)
+ }
+ def createGreaterThanOrEqualFilter(
+ name: String,
+ literal: Literal,
+ predicate: CatalystPredicate) = literal.dataType match {
+ case IntegerType =>
+ ComparisonFilter.createIntFilter(
+ name, (x: Int) => x >= literal.value.asInstanceOf[Int],
+ predicate)
+ case LongType =>
+ ComparisonFilter.createLongFilter(
+ name,
+ (x: Long) => x >= literal.value.asInstanceOf[Long],
+ predicate)
+ case DoubleType =>
+ ComparisonFilter.createDoubleFilter(
+ name,
+ (x: Double) => x >= literal.value.asInstanceOf[Double],
+ predicate)
+ case FloatType =>
+ ComparisonFilter.createFloatFilter(
+ name,
+ (x: Float) => x >= literal.value.asInstanceOf[Float],
+ predicate)
+ }
+
+ /**
+ * TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until
+ * https://github.com/Parquet/parquet-mr/issues/371
+ * has been resolved.
+ */
+ expression match {
+ case p @ Or(left: Expression, right: Expression)
+ if createFilter(left).isDefined && createFilter(right).isDefined => {
+ // If either side of this Or-predicate is empty then this means
+ // it contains a more complex comparison than between attribute and literal
+ // (e.g., it contained a CAST). The only safe thing to do is then to disregard
+ // this disjunction, which could be contained in a conjunction. If it stands
+ // alone then it is also safe to drop it, since a Null return value of this
+ // function is interpreted as having no filters at all.
+ val leftFilter = createFilter(left).get
+ val rightFilter = createFilter(right).get
+ Some(new OrFilter(leftFilter, rightFilter))
+ }
+ case p @ And(left: Expression, right: Expression) => {
+ // This treats nested conjunctions; since either side of the conjunction
+ // may contain more complex filter expressions we may actually generate
+ // strictly weaker filter predicates in the process.
+ val leftFilter = createFilter(left)
+ val rightFilter = createFilter(right)
+ (leftFilter, rightFilter) match {
+ case (None, Some(filter)) => Some(filter)
+ case (Some(filter), None) => Some(filter)
+ case (_, _) =>
+ Some(new AndFilter(leftFilter.get, rightFilter.get))
+ }
+ }
+ case p @ Equals(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createEqualityFilter(right.name, left, p))
+ case p @ Equals(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createEqualityFilter(left.name, right, p))
+ case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createLessThanFilter(right.name, left, p))
+ case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createLessThanFilter(left.name, right, p))
+ case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createLessThanOrEqualFilter(right.name, left, p))
+ case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createLessThanOrEqualFilter(left.name, right, p))
+ case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createGreaterThanFilter(right.name, left, p))
+ case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createGreaterThanFilter(left.name, right, p))
+ case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable =>
+ Some(createGreaterThanOrEqualFilter(right.name, left, p))
+ case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable =>
+ Some(createGreaterThanOrEqualFilter(left.name, right, p))
+ case _ => None
+ }
+ }
+
+ /**
+ * Note: Inside the Hadoop API we only have access to `Configuration`, not to
+ * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey
+ * the actual filter predicate.
+ */
+ def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = {
+ if (filters.length > 0) {
+ val serialized: Array[Byte] = SparkSqlSerializer.serialize(filters)
+ val encoded: String = BaseEncoding.base64().encode(serialized)
+ conf.set(PARQUET_FILTER_DATA, encoded)
+ }
+ }
+
+ /**
+ * Note: Inside the Hadoop API we only have access to `Configuration`, not to
+ * [[org.apache.spark.SparkContext]], so we cannot use broadcasts to convey
+ * the actual filter predicate.
+ */
+ def deserializeFilterExpressions(conf: Configuration): Seq[Expression] = {
+ val data = conf.get(PARQUET_FILTER_DATA)
+ if (data != null) {
+ val decoded: Array[Byte] = BaseEncoding.base64().decode(data)
+ SparkSqlSerializer.deserialize(decoded)
+ } else {
+ Seq()
+ }
+ }
+
+ /**
+ * Try to find the given expression in the tree of filters in order to
+ * determine whether it is safe to remove it from the higher level filters. Note
+ * that strictly speaking we could stop the search whenever an expression is found
+ * that contains this expression as subexpression (e.g., when searching for "a"
+ * and "(a or c)" is found) but we don't care about optimizations here since the
+ * filter tree is assumed to be small.
+ *
+ * @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand
+ * and search
+ * @param expression The expression to look for
+ * @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that
+ * contains the expression.
+ */
+ def findExpression(
+ filter: CatalystFilter,
+ expression: Expression): Option[CatalystFilter] = filter match {
+ case f @ OrFilter(_, leftFilter, rightFilter, _) =>
+ if (f.predicate == expression) {
+ Some(f)
+ } else {
+ val left = findExpression(leftFilter, expression)
+ if (left.isDefined) left else findExpression(rightFilter, expression)
+ }
+ case f @ AndFilter(_, leftFilter, rightFilter, _) =>
+ if (f.predicate == expression) {
+ Some(f)
+ } else {
+ val left = findExpression(leftFilter, expression)
+ if (left.isDefined) left else findExpression(rightFilter, expression)
+ }
+ case f @ ComparisonFilter(_, _, predicate) =>
+ if (predicate == expression) Some(f) else None
+ case _ => None
+ }
+}
+
+abstract private[parquet] class CatalystFilter(
+ @transient val predicate: CatalystPredicate) extends UnboundRecordFilter
+
+private[parquet] case class ComparisonFilter(
+ val columnName: String,
+ private var filter: UnboundRecordFilter,
+ @transient override val predicate: CatalystPredicate)
+ extends CatalystFilter(predicate) {
+ override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+ filter.bind(readers)
+ }
+}
+
+private[parquet] case class OrFilter(
+ private var filter: UnboundRecordFilter,
+ @transient val left: CatalystFilter,
+ @transient val right: CatalystFilter,
+ @transient override val predicate: Or)
+ extends CatalystFilter(predicate) {
+ def this(l: CatalystFilter, r: CatalystFilter) =
+ this(
+ OrRecordFilter.or(l, r),
+ l,
+ r,
+ Or(l.predicate, r.predicate))
+
+ override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+ filter.bind(readers)
+ }
+}
+
+private[parquet] case class AndFilter(
+ private var filter: UnboundRecordFilter,
+ @transient val left: CatalystFilter,
+ @transient val right: CatalystFilter,
+ @transient override val predicate: And)
+ extends CatalystFilter(predicate) {
+ def this(l: CatalystFilter, r: CatalystFilter) =
+ this(
+ AndRecordFilter.and(l, r),
+ l,
+ r,
+ And(l.predicate, r.predicate))
+
+ override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = {
+ filter.bind(readers)
+ }
+}
+
+private[parquet] object ComparisonFilter {
+ def createBooleanFilter(
+ columnName: String,
+ value: Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToBoolean(
+ new BooleanPredicateFunction {
+ def functionToApply(input: Boolean): Boolean = input == value
+ }
+ )),
+ predicate)
+
+ def createStringFilter(
+ columnName: String,
+ value: String,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToString (
+ new ColumnPredicates.PredicateFunction[String] {
+ def functionToApply(input: String): Boolean = input == value
+ }
+ )),
+ predicate)
+
+ def createIntFilter(
+ columnName: String,
+ func: Int => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToInteger(
+ new IntegerPredicateFunction {
+ def functionToApply(input: Int) = func(input)
+ }
+ )),
+ predicate)
+
+ def createLongFilter(
+ columnName: String,
+ func: Long => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToLong(
+ new LongPredicateFunction {
+ def functionToApply(input: Long) = func(input)
+ }
+ )),
+ predicate)
+
+ def createDoubleFilter(
+ columnName: String,
+ func: Double => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToDouble(
+ new DoublePredicateFunction {
+ def functionToApply(input: Double) = func(input)
+ }
+ )),
+ predicate)
+
+ def createFloatFilter(
+ columnName: String,
+ func: Float => Boolean,
+ predicate: CatalystPredicate): CatalystFilter =
+ new ComparisonFilter(
+ columnName,
+ ColumnRecordFilter.column(
+ columnName,
+ ColumnPredicates.applyFunctionToFloat(
+ new FloatPredicateFunction {
+ def functionToApply(input: Float) = func(input)
+ }
+ )),
+ predicate)
+}
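
Editor's note: to make the conservative bookkeeping in SparkStrategies concrete, here is a minimal sketch that is not part of the patch. It assumes it is compiled inside package org.apache.spark.sql.parquet (the filter classes are private[parquet]); the object name is illustrative. A conjunction whose right half is not a plain "attribute cmp literal" comparison yields a record filter for the left half only, and findExpression then fails to locate the full conjunction, so the strategy keeps the whole conjunction in the higher-level Filter.

package org.apache.spark.sql.parquet

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

object PartialPushdownSketch {
  def main(args: Array[String]): Unit = {
    val myint = AttributeReference("myint", IntegerType, nullable = false)()
    val pushable = LessThan(myint, Literal(10, IntegerType))                   // simple comparison
    val notPushable = GreaterThan(myint, Add(myint, Literal(1, IntegerType)))  // attribute vs. expression

    val conjunction = And(pushable, notPushable)
    val recordFilter = ParquetFilters.createFilter(conjunction)

    // Only the pushable half survives, so the generated filter does not contain the
    // full conjunction; the strategy therefore keeps "pushable AND notPushable" in
    // the in-memory Filter operator rather than dropping it.
    assert(recordFilter.isDefined)
    assert(ParquetFilters.findExpression(recordFilter.get, conjunction).isEmpty)
    assert(ParquetFilters.findExpression(recordFilter.get, pushable).isDefined)
  }
}
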
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index f825ca3c02..65ba1246fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -27,26 +27,27 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter}
-import parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat}
+import parquet.hadoop.{ParquetRecordReader, ParquetInputFormat, ParquetOutputFormat}
+import parquet.hadoop.api.ReadSupport
import parquet.hadoop.util.ContextUtil
import parquet.io.InvalidRecordException
import parquet.schema.MessageType
-import org.apache.spark.{SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
/**
* Parquet table scan operator. Imports the file that backs the given
- * [[ParquetRelation]] as a RDD[Row].
+ * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``.
*/
case class ParquetTableScan(
// note: output cannot be transient, see
// https://issues.apache.org/jira/browse/SPARK-1367
output: Seq[Attribute],
relation: ParquetRelation,
- columnPruningPred: Option[Expression])(
+ columnPruningPred: Seq[Expression])(
@transient val sc: SparkContext)
extends LeafNode {
@@ -62,18 +63,30 @@ case class ParquetTableScan(
for (path <- fileList if !path.getName.startsWith("_")) {
NewFileInputFormat.addInputPath(job, path)
}
+
+ // Store Parquet schema in `Configuration`
conf.set(
RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA,
ParquetTypesConverter.convertFromAttributes(output).toString)
- // TODO: think about adding record filters
- /* Comments regarding record filters: it would be nice to push down as much filtering
- to Parquet as possible. However, currently it seems we cannot pass enough information
- to materialize an (arbitrary) Catalyst [[Predicate]] inside Parquet's
- ``FilteredRecordReader`` (via Configuration, for example). Simple
- filter-rows-by-column-values however should be supported.
- */
- sc.newAPIHadoopRDD(conf, classOf[ParquetInputFormat[Row]], classOf[Void], classOf[Row])
- .map(_._2)
+
+ // Store record filtering predicate in `Configuration`
+ // Note 1: the input format ignores all predicates that cannot be expressed
+ // as simple column predicate filters in Parquet. Here we just record
+ // the whole pruning predicate.
+ // Note 2: you can disable filter predicate pushdown by setting
+ // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf.
+ if (columnPruningPred.length > 0 &&
+ sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
+ ParquetFilters.serializeFilterExpressions(columnPruningPred, conf)
+ }
+
+ sc.newAPIHadoopRDD(
+ conf,
+ classOf[org.apache.spark.sql.parquet.FilteringParquetRowInputFormat],
+ classOf[Void],
+ classOf[Row])
+ .map(_._2)
+ .filter(_ != null) // Parquet's record filters may produce null values
}
override def otherCopyArgs = sc :: Nil
@@ -184,10 +197,19 @@ case class InsertIntoParquetTable(
override def otherCopyArgs = sc :: Nil
- // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]]
- // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2?
- // .. then we could use the default one and could use [[MutablePair]]
- // instead of ``Tuple2``
+ /**
+ * Stores the given Row RDD as a Hadoop file.
+ *
+ * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]]
+ * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses
+ * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing
+ * directory and need to determine which was the largest written file index before starting to
+ * write.
+ *
+ * @param rdd The [[org.apache.spark.rdd.RDD]] to write
+ * @param path The directory to write to.
+ * @param conf A [[org.apache.hadoop.conf.Configuration]].
+ */
private def saveAsHadoopFile(
rdd: RDD[Row],
path: String,
@@ -244,8 +266,10 @@ case class InsertIntoParquetTable(
}
}
-// TODO: this will be able to append to directories it created itself, not necessarily
-// to imported ones
+/**
+ * TODO: this will be able to append to directories it created itself, not necessarily
+ * to imported ones.
+ */
private[parquet] class AppendingParquetOutputFormat(offset: Int)
extends parquet.hadoop.ParquetOutputFormat[Row] {
// override to accept existing directories as valid output directory
@@ -262,6 +286,30 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)
}
}
+/**
+ * We extend ParquetInputFormat in order to have more control over which
+ * RecordFilter we want to use.
+ */
+private[parquet] class FilteringParquetRowInputFormat
+ extends parquet.hadoop.ParquetInputFormat[Row] with Logging {
+ override def createRecordReader(
+ inputSplit: InputSplit,
+ taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
+ val readSupport: ReadSupport[Row] = new RowReadSupport()
+
+ val filterExpressions =
+ ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext))
+ if (filterExpressions.length > 0) {
+ logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}")
+ new ParquetRecordReader[Row](
+ readSupport,
+ ParquetFilters.createRecordFilter(filterExpressions))
+ } else {
+ new ParquetRecordReader[Row](readSupport)
+ }
+ }
+}
+
private[parquet] object FileSystemHelper {
def listFiles(pathStr: String, conf: Configuration): Seq[Path] = {
val origPath = new Path(pathStr)
@@ -278,7 +326,9 @@ private[parquet] object FileSystemHelper {
fs.listStatus(path).map(_.getPath)
}
- // finds the maximum taskid in the output file names at the given path
+ /**
+ * Finds the maximum taskid in the output file names at the given path.
+ */
def findMaxTaskId(pathStr: String, conf: Configuration): Int = {
val files = FileSystemHelper.listFiles(pathStr, conf)
// filename pattern is part-r-<int>.parquet
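
Editor's note: because FilteringParquetRowInputFormat only sees a Hadoop Configuration, the predicates travel as a base64-encoded, serialized string under "org.apache.spark.sql.parquet.row.filter". The round-trip sketch below is not part of the patch; it assumes the same package as above and that SparkSqlSerializer falls back to a default SparkConf when no SparkEnv is running, and the object name is illustrative.

package org.apache.spark.sql.parquet

import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType

object FilterRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    val myint = AttributeReference("myint", IntegerType, nullable = false)()
    val predicate = LessThan(myint, Literal(100, IntegerType))

    // Driver side (ParquetTableScan.execute): encode the predicates into the job configuration.
    ParquetFilters.serializeFilterExpressions(Seq(predicate), conf)

    // Executor side (FilteringParquetRowInputFormat.createRecordReader): decode them again
    // and build the UnboundRecordFilter handed to ParquetRecordReader.
    val restored = ParquetFilters.deserializeFilterExpressions(conf)
    val recordFilter = ParquetFilters.createRecordFilter(restored)
    assert(restored.nonEmpty && recordFilter != null)
  }
}
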
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
index f37976f731..46c7172985 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala
@@ -19,15 +19,34 @@ package org.apache.spark.sql.parquet
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapreduce.Job
+import parquet.example.data.{GroupWriter, Group}
+import parquet.example.data.simple.SimpleGroup
import parquet.hadoop.ParquetWriter
-import parquet.hadoop.util.ContextUtil
+import parquet.hadoop.api.WriteSupport
+import parquet.hadoop.api.WriteSupport.WriteContext
+import parquet.io.api.RecordConsumer
import parquet.schema.{MessageType, MessageTypeParser}
-import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.util.Utils
+// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
+// with an empty configuration (it is after all not intended to be used in this way?)
+// and members are private so we need to make our own in order to pass the schema
+// to the writer.
+private class TestGroupWriteSupport(schema: MessageType) extends WriteSupport[Group] {
+ var groupWriter: GroupWriter = null
+ override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
+ groupWriter = new GroupWriter(recordConsumer, schema)
+ }
+ override def init(configuration: Configuration): WriteContext = {
+ new WriteContext(schema, new java.util.HashMap[String, String]())
+ }
+ override def write(record: Group) {
+ groupWriter.write(record)
+ }
+}
+
private[sql] object ParquetTestData {
val testSchema =
@@ -43,7 +62,7 @@ private[sql] object ParquetTestData {
// field names for test assertion error messages
val testSchemaFieldNames = Seq(
"myboolean:Boolean",
- "mtint:Int",
+ "myint:Int",
"mystring:String",
"mylong:Long",
"myfloat:Float",
@@ -58,6 +77,18 @@ private[sql] object ParquetTestData {
|}
""".stripMargin
+ val testFilterSchema =
+ """
+ |message myrecord {
+ |required boolean myboolean;
+ |required int32 myint;
+ |required binary mystring;
+ |required int64 mylong;
+ |required float myfloat;
+ |required double mydouble;
+ |}
+ """.stripMargin
+
// field names for test assertion error messages
val subTestSchemaFieldNames = Seq(
"myboolean:Boolean",
@@ -65,36 +96,57 @@ private[sql] object ParquetTestData {
)
val testDir = Utils.createTempDir()
+ val testFilterDir = Utils.createTempDir()
lazy val testData = new ParquetRelation(testDir.toURI.toString)
def writeFile() = {
testDir.delete
val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet"))
- val job = new Job()
- val configuration: Configuration = ContextUtil.getConfiguration(job)
val schema: MessageType = MessageTypeParser.parseMessageType(testSchema)
+ val writeSupport = new TestGroupWriteSupport(schema)
+ val writer = new ParquetWriter[Group](path, writeSupport)
- val writeSupport = new RowWriteSupport()
- writeSupport.setSchema(schema, configuration)
- val writer = new ParquetWriter(path, writeSupport)
for(i <- 0 until 15) {
- val data = new Array[Any](6)
+ val record = new SimpleGroup(schema)
if (i % 3 == 0) {
- data.update(0, true)
+ record.add(0, true)
} else {
- data.update(0, false)
+ record.add(0, false)
}
if (i % 5 == 0) {
- data.update(1, 5)
+ record.add(1, 5)
+ }
+ record.add(2, "abc")
+ record.add(3, i.toLong << 33)
+ record.add(4, 2.5F)
+ record.add(5, 4.5D)
+ writer.write(record)
+ }
+ writer.close()
+ }
+
+ def writeFilterFile(records: Int = 200) = {
+ // for microbenchmark use: records = 300000000
+ testFilterDir.delete
+ val path: Path = new Path(new Path(testFilterDir.toURI), new Path("part-r-0.parquet"))
+ val schema: MessageType = MessageTypeParser.parseMessageType(testFilterSchema)
+ val writeSupport = new TestGroupWriteSupport(schema)
+ val writer = new ParquetWriter[Group](path, writeSupport)
+
+ for(i <- 0 to records) {
+ val record = new SimpleGroup(schema)
+ if (i % 4 == 0) {
+ record.add(0, true)
} else {
- data.update(1, null) // optional
+ record.add(0, false)
}
- data.update(2, "abc")
- data.update(3, i.toLong << 33)
- data.update(4, 2.5F)
- data.update(5, 4.5D)
- writer.write(new GenericRow(data.toArray))
+ record.add(1, i)
+ record.add(2, i.toString)
+ record.add(3, i.toLong)
+ record.add(4, i.toFloat + 0.5f)
+ record.add(5, i.toDouble + 0.5d)
+ writer.write(record)
}
writer.close()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index ff1677eb8a..65f4c17aee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,25 +17,25 @@
package org.apache.spark.sql.parquet
-import java.io.File
-
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.hadoop.mapreduce.Job
import parquet.hadoop.ParquetFileWriter
-import parquet.schema.MessageTypeParser
import parquet.hadoop.util.ContextUtil
+import parquet.schema.MessageTypeParser
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.util.getTempFilePath
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.TestData
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.expressions.Equals
+import org.apache.spark.sql.catalyst.types.IntegerType
import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType}
-import org.apache.spark.sql.{parquet, SchemaRDD}
// Implicits
import org.apache.spark.sql.test.TestSQLContext._
@@ -64,12 +64,16 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
override def beforeAll() {
ParquetTestData.writeFile()
+ ParquetTestData.writeFilterFile()
testRDD = parquetFile(ParquetTestData.testDir.toString)
testRDD.registerAsTable("testsource")
+ parquetFile(ParquetTestData.testFilterDir.toString)
+ .registerAsTable("testfiltersource")
}
override def afterAll() {
Utils.deleteRecursively(ParquetTestData.testDir)
+ Utils.deleteRecursively(ParquetTestData.testFilterDir)
// here we should also unregister the table??
}
@@ -120,7 +124,7 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
val scanner = new ParquetTableScan(
ParquetTestData.testData.output,
ParquetTestData.testData,
- None)(TestSQLContext.sparkContext)
+ Seq())(TestSQLContext.sparkContext)
val projected = scanner.pruneColumns(ParquetTypesConverter
.convertToAttributes(MessageTypeParser
.parseMessageType(ParquetTestData.subTestSchema)))
@@ -196,7 +200,6 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
assert(true)
}
-
test("insert (appending) to same table via Scala API") {
sql("INSERT INTO testsource SELECT * FROM testsource").collect()
val double_rdd = sql("SELECT * FROM testsource").collect()
@@ -239,5 +242,121 @@ class ParquetQuerySuite extends QueryTest with FunSuite with BeforeAndAfterAll {
Utils.deleteRecursively(file)
assert(true)
}
+
+ test("create RecordFilter for simple predicates") {
+ val attribute1 = new AttributeReference("first", IntegerType, false)()
+ val predicate1 = new Equals(attribute1, new Literal(1, IntegerType))
+ val filter1 = ParquetFilters.createFilter(predicate1)
+ assert(filter1.isDefined)
+ assert(filter1.get.predicate == predicate1, "predicates do not match")
+ assert(filter1.get.isInstanceOf[ComparisonFilter])
+ val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter]
+ assert(cmpFilter1.columnName == "first", "column name incorrect")
+
+ val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType))
+ val filter2 = ParquetFilters.createFilter(predicate2)
+ assert(filter2.isDefined)
+ assert(filter2.get.predicate == predicate2, "predicates do not match")
+ assert(filter2.get.isInstanceOf[ComparisonFilter])
+ val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter]
+ assert(cmpFilter2.columnName == "first", "column name incorrect")
+
+ val predicate3 = new And(predicate1, predicate2)
+ val filter3 = ParquetFilters.createFilter(predicate3)
+ assert(filter3.isDefined)
+ assert(filter3.get.predicate == predicate3, "predicates do not match")
+ assert(filter3.get.isInstanceOf[AndFilter])
+
+ val predicate4 = new Or(predicate1, predicate2)
+ val filter4 = ParquetFilters.createFilter(predicate4)
+ assert(filter4.isDefined)
+ assert(filter4.get.predicate == predicate4, "predicates do not match")
+ assert(filter4.get.isInstanceOf[OrFilter])
+
+ val attribute2 = new AttributeReference("second", IntegerType, false)()
+ val predicate5 = new GreaterThan(attribute1, attribute2)
+ val badfilter = ParquetFilters.createFilter(predicate5)
+ assert(badfilter.isDefined === false)
+ }
+
+ test("test filter by predicate pushdown") {
+ for(myval <- Seq("myint", "mylong", "mydouble", "myfloat")) {
+ println(s"testing field $myval")
+ val query1 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100")
+ assert(
+ query1.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val result1 = query1.collect()
+ assert(result1.size === 50)
+ assert(result1(0)(1) === 100)
+ assert(result1(49)(1) === 149)
+ val query2 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200")
+ assert(
+ query2.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val result2 = query2.collect()
+ assert(result2.size === 50)
+ if (myval == "myint" || myval == "mylong") {
+ assert(result2(0)(1) === 151)
+ assert(result2(49)(1) === 200)
+ } else {
+ assert(result2(0)(1) === 150)
+ assert(result2(49)(1) === 199)
+ }
+ }
+ for(myval <- Seq("myint", "mylong")) {
+ val query3 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190 OR $myval < 10")
+ assert(
+ query3.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val result3 = query3.collect()
+ assert(result3.size === 20)
+ assert(result3(0)(1) === 0)
+ assert(result3(9)(1) === 9)
+ assert(result3(10)(1) === 191)
+ assert(result3(19)(1) === 200)
+ }
+ for(myval <- Seq("mydouble", "myfloat")) {
+ val result4 =
+ if (myval == "mydouble") {
+ val query4 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10.0")
+ assert(
+ query4.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ query4.collect()
+ } else {
+ // CASTs are problematic. Here myfloat will be casted to a double and it seems there is
+ // currently no way to specify float constants in SqlParser?
+ sql(s"SELECT * FROM testfiltersource WHERE $myval > 190.5 OR $myval < 10").collect()
+ }
+ assert(result4.size === 20)
+ assert(result4(0)(1) === 0)
+ assert(result4(9)(1) === 9)
+ assert(result4(10)(1) === 191)
+ assert(result4(19)(1) === 200)
+ }
+ val query5 = sql(s"SELECT * FROM testfiltersource WHERE myboolean = true AND myint < 40")
+ assert(
+ query5.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val booleanResult = query5.collect()
+ assert(booleanResult.size === 10)
+ for(i <- 0 until 10) {
+ if (!booleanResult(i).getBoolean(0)) {
+ fail(s"Boolean value in result row $i not true")
+ }
+ if (booleanResult(i).getInt(1) != i * 4) {
+ fail(s"Int value in result row $i should be ${4*i}")
+ }
+ }
+ val query6 = sql("SELECT * FROM testfiltersource WHERE mystring = \"100\"")
+ assert(
+ query6.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
+ "Top operator should be ParquetTableScan after pushdown")
+ val stringResult = query6.collect()
+ assert(stringResult.size === 1)
+ assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
+ assert(stringResult(0).getInt(1) === 100)
+ }
}