From 36b0956a3eadc7343ed0d25c79a6ce0496eaaccd Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Mon, 17 Nov 2014 16:55:12 -0800
Subject: [SPARK-4453][SPARK-4213][SQL] Simplifies Parquet filter generation code

While reviewing PRs #3083 and #3161, I noticed that the Parquet record filter
generation code can be simplified significantly, following the clue stated in
[SPARK-4453](https://issues.apache.org/jira/browse/SPARK-4453). This PR addresses
both SPARK-4453 and SPARK-4213 with that simplification.

While generating the `ParquetTableScan` operator, we need to remove all Catalyst
predicates that have already been pushed down to Parquet. Originally, we first
generate the record filter and then call `findExpression` to traverse the
generated filter and find all pushed-down predicates
[[1](https://github.com/apache/spark/blob/64c6b9bad559c21f25cd9fbe37c8813cdab939f2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala#L213-L228)].
This approach requires the `CatalystFilter` class hierarchy to bind each Catalyst
predicate to its generated Parquet filter, which complicates the code base
considerably.

The basic idea of this PR is that `findExpression` is unnecessary after filter
generation: a predicate can be pushed down exactly when its corresponding Parquet
filter can be generated successfully. SPARK-4213 is fixed by returning `None` for
any unsupported predicate type.

[Review on Reviewable](https://reviewable.io/reviews/apache/spark/3317)

Author: Cheng Lian

Closes #3317 from liancheng/simplify-parquet-filters and squashes the following commits:

d6a9499 [Cheng Lian] Fixes import styling issue
43760e8 [Cheng Lian] Simplifies Parquet filter generation logic
---
 .../spark/sql/execution/SparkStrategies.scala      |  25 +-
 .../apache/spark/sql/parquet/ParquetFilters.scala  | 693 +++------------------
 .../spark/sql/parquet/ParquetTableOperations.scala |  77 +--
 .../spark/sql/parquet/ParquetQuerySuite.scala      |  58 +-
 4 files changed, 160 insertions(+), 693 deletions(-)

(limited to 'sql/core')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7ef1f9f2c5..1225d18857 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -209,22 +209,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
         val prunePushedDownFilters =
           if (sqlContext.parquetFilterPushDown) {
-            (filters: Seq[Expression]) => {
-              filters.filter { filter =>
-                // Note: filters cannot be pushed down to Parquet if they contain more complex
-                //       expressions than simple "Attribute cmp Literal" comparisons. Here we remove
-                //       all filters that have been pushed down. Note that a predicate such as
-                //       "(A AND B) OR C" can result in "A OR C" being pushed down.
-                val recordFilter = ParquetFilters.createFilter(filter)
-                if (!recordFilter.isDefined) {
-                  // First case: the pushdown did not result in any record filter.
-                  true
-                } else {
-                  // Second case: a record filter was created; here we are conservative in
-                  //              the sense that even if "A" was pushed and we check for "A AND B" we
-                  //              still want to keep "A AND B" in the higher-level filter, not just "B".
- !ParquetFilters.findExpression(recordFilter.get, filter).isDefined - } + (predicates: Seq[Expression]) => { + // Note: filters cannot be pushed down to Parquet if they contain more complex + // expressions than simple "Attribute cmp Literal" comparisons. Here we remove all + // filters that have been pushed down. Note that a predicate such as "(A AND B) OR C" + // can result in "A OR C" being pushed down. Here we are conservative in the sense + // that even if "A" was pushed and we check for "A AND B" we still want to keep + // "A AND B" in the higher-level filter, not just "B". + predicates.map(p => p -> ParquetFilters.createFilter(p)).collect { + case (predicate, None) => predicate } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 9a3f6d388d..3a9e1499e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -18,406 +18,152 @@ package org.apache.spark.sql.parquet import java.nio.ByteBuffer -import java.sql.{Date, Timestamp} +import com.google.common.io.BaseEncoding import org.apache.hadoop.conf.Configuration - -import parquet.common.schema.ColumnPath import parquet.filter2.compat.FilterCompat import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.Operators.{Column, SupportsLtGt} -import parquet.filter2.predicate.{FilterApi, FilterPredicate} import parquet.filter2.predicate.FilterApi._ +import parquet.filter2.predicate.{FilterApi, FilterPredicate} import parquet.io.api.Binary -import parquet.column.ColumnReader - -import com.google.common.io.BaseEncoding import org.apache.spark.SparkEnv -import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.types.decimal.Decimal -import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.parquet.ParquetColumns._ +import org.apache.spark.sql.catalyst.types._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" - def createRecordFilter(filterExpressions: Seq[Expression]): Filter = { - val filters: Seq[CatalystFilter] = filterExpressions.collect { - case (expression: Expression) if createFilter(expression).isDefined => - createFilter(expression).get - } - if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null + def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = { + filterExpressions.flatMap(createFilter).reduceOption(FilterApi.and).map(FilterCompat.get) } - def createFilter(expression: Expression): Option[CatalystFilter] = { - def createEqualityFilter( - name: String, - literal: Literal, - predicate: CatalystPredicate) = literal.dataType match { + def createFilter(predicate: Expression): Option[FilterPredicate] = { + val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case BooleanType => - ComparisonFilter.createBooleanEqualityFilter( - name, - literal.value.asInstanceOf[Boolean], - predicate) - case ByteType => - new ComparisonFilter( - name, - FilterApi.eq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), - predicate) - case ShortType => - new ComparisonFilter( - name, - FilterApi.eq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), - predicate) + 
(n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean]) case IntegerType => - new ComparisonFilter( - name, - FilterApi.eq(intColumn(name), literal.value.asInstanceOf[Integer]), - predicate) + (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer]) case LongType => - new ComparisonFilter( - name, - FilterApi.eq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), - predicate) - case DoubleType => - new ComparisonFilter( - name, - FilterApi.eq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), - predicate) + (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long]) case FloatType => - new ComparisonFilter( - name, - FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), - predicate) + (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => - ComparisonFilter.createStringEqualityFilter( - name, - literal.value.asInstanceOf[String], - predicate) + (n: String, v: Any) => + FilterApi.eq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) case BinaryType => - ComparisonFilter.createBinaryEqualityFilter( - name, - literal.value.asInstanceOf[Array[Byte]], - predicate) - case DateType => - new ComparisonFilter( - name, - FilterApi.eq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), - predicate) - case TimestampType => - new ComparisonFilter( - name, - FilterApi.eq(timestampColumn(name), - new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), - predicate) - case DecimalType.Unlimited => - new ComparisonFilter( - name, - FilterApi.eq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), - predicate) + (n: String, v: Any) => + FilterApi.eq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } - def createLessThanFilter( - name: String, - literal: Literal, - predicate: CatalystPredicate) = literal.dataType match { - case ByteType => - new ComparisonFilter( - name, - FilterApi.lt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), - predicate) - case ShortType => - new ComparisonFilter( - name, - FilterApi.lt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), - predicate) + val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case IntegerType => - new ComparisonFilter( - name, - FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]), - predicate) + (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer]) case LongType => - new ComparisonFilter( - name, - FilterApi.lt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), - predicate) - case DoubleType => - new ComparisonFilter( - name, - FilterApi.lt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), - predicate) + (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long]) case FloatType => - new ComparisonFilter( - name, - FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), - predicate) + (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => - ComparisonFilter.createStringLessThanFilter( - name, - literal.value.asInstanceOf[String], - predicate) + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), 
Binary.fromString(v.asInstanceOf[String])) case BinaryType => - ComparisonFilter.createBinaryLessThanFilter( - name, - literal.value.asInstanceOf[Array[Byte]], - predicate) - case DateType => - new ComparisonFilter( - name, - FilterApi.lt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), - predicate) - case TimestampType => - new ComparisonFilter( - name, - FilterApi.lt(timestampColumn(name), - new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), - predicate) - case DecimalType.Unlimited => - new ComparisonFilter( - name, - FilterApi.lt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), - predicate) + (n: String, v: Any) => + FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } - def createLessThanOrEqualFilter( - name: String, - literal: Literal, - predicate: CatalystPredicate) = literal.dataType match { - case ByteType => - new ComparisonFilter( - name, - FilterApi.ltEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), - predicate) - case ShortType => - new ComparisonFilter( - name, - FilterApi.ltEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), - predicate) + + val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case IntegerType => - new ComparisonFilter( - name, - FilterApi.ltEq(intColumn(name), literal.value.asInstanceOf[Integer]), - predicate) + (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) case LongType => - new ComparisonFilter( - name, - FilterApi.ltEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), - predicate) - case DoubleType => - new ComparisonFilter( - name, - FilterApi.ltEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), - predicate) + (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case FloatType => - new ComparisonFilter( - name, - FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), - predicate) + (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => - ComparisonFilter.createStringLessThanOrEqualFilter( - name, - literal.value.asInstanceOf[String], - predicate) + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) case BinaryType => - ComparisonFilter.createBinaryLessThanOrEqualFilter( - name, - literal.value.asInstanceOf[Array[Byte]], - predicate) - case DateType => - new ComparisonFilter( - name, - FilterApi.ltEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), - predicate) - case TimestampType => - new ComparisonFilter( - name, - FilterApi.ltEq(timestampColumn(name), - new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), - predicate) - case DecimalType.Unlimited => - new ComparisonFilter( - name, - FilterApi.ltEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), - predicate) + (n: String, v: Any) => + FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } - // TODO: combine these two types somehow? 
- def createGreaterThanFilter( - name: String, - literal: Literal, - predicate: CatalystPredicate) = literal.dataType match { - case ByteType => - new ComparisonFilter( - name, - FilterApi.gt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), - predicate) - case ShortType => - new ComparisonFilter( - name, - FilterApi.gt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), - predicate) + + val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case IntegerType => - new ComparisonFilter( - name, - FilterApi.gt(intColumn(name), literal.value.asInstanceOf[Integer]), - predicate) + (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer]) case LongType => - new ComparisonFilter( - name, - FilterApi.gt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), - predicate) - case DoubleType => - new ComparisonFilter( - name, - FilterApi.gt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), - predicate) + (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long]) case FloatType => - new ComparisonFilter( - name, - FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), - predicate) + (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => - ComparisonFilter.createStringGreaterThanFilter( - name, - literal.value.asInstanceOf[String], - predicate) + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) case BinaryType => - ComparisonFilter.createBinaryGreaterThanFilter( - name, - literal.value.asInstanceOf[Array[Byte]], - predicate) - case DateType => - new ComparisonFilter( - name, - FilterApi.gt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), - predicate) - case TimestampType => - new ComparisonFilter( - name, - FilterApi.gt(timestampColumn(name), - new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), - predicate) - case DecimalType.Unlimited => - new ComparisonFilter( - name, - FilterApi.gt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), - predicate) + (n: String, v: Any) => + FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } - def createGreaterThanOrEqualFilter( - name: String, - literal: Literal, - predicate: CatalystPredicate) = literal.dataType match { - case ByteType => - new ComparisonFilter( - name, - FilterApi.gtEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), - predicate) - case ShortType => - new ComparisonFilter( - name, - FilterApi.gtEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), - predicate) + + val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = { case IntegerType => - new ComparisonFilter( - name, - FilterApi.gtEq(intColumn(name), literal.value.asInstanceOf[Integer]), - predicate) + (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer]) case LongType => - new ComparisonFilter( - name, - FilterApi.gtEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), - predicate) - case DoubleType => - new ComparisonFilter( - name, - FilterApi.gtEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), - predicate) + (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long]) case FloatType => - new ComparisonFilter( - name, - 
FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), - predicate) + (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float]) + case DoubleType => + (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double]) case StringType => - ComparisonFilter.createStringGreaterThanOrEqualFilter( - name, - literal.value.asInstanceOf[String], - predicate) + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String])) case BinaryType => - ComparisonFilter.createBinaryGreaterThanOrEqualFilter( - name, - literal.value.asInstanceOf[Array[Byte]], - predicate) - case DateType => - new ComparisonFilter( - name, - FilterApi.gtEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), - predicate) - case TimestampType => - new ComparisonFilter( - name, - FilterApi.gtEq(timestampColumn(name), - new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), - predicate) - case DecimalType.Unlimited => - new ComparisonFilter( - name, - FilterApi.gtEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), - predicate) + (n: String, v: Any) => + FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]])) } - /** - * TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until - * https://github.com/Parquet/parquet-mr/issues/371 - * has been resolved. - */ - expression match { - case p @ Or(left: Expression, right: Expression) - if createFilter(left).isDefined && createFilter(right).isDefined => { - // If either side of this Or-predicate is empty then this means - // it contains a more complex comparison than between attribute and literal - // (e.g., it contained a CAST). The only safe thing to do is then to disregard - // this disjunction, which could be contained in a conjunction. If it stands - // alone then it is also safe to drop it, since a Null return value of this - // function is interpreted as having no filters at all. - val leftFilter = createFilter(left).get - val rightFilter = createFilter(right).get - Some(new OrFilter(leftFilter, rightFilter)) - } - case p @ And(left: Expression, right: Expression) => { - // This treats nested conjunctions; since either side of the conjunction - // may contain more complex filter expressions we may actually generate - // strictly weaker filter predicates in the process. 
- val leftFilter = createFilter(left) - val rightFilter = createFilter(right) - (leftFilter, rightFilter) match { - case (None, Some(filter)) => Some(filter) - case (Some(filter), None) => Some(filter) - case (Some(leftF), Some(rightF)) => - Some(new AndFilter(leftF, rightF)) - case _ => None - } - } - case p @ EqualTo(left: Literal, right: NamedExpression) if left.dataType != NullType => - Some(createEqualityFilter(right.name, left, p)) - case p @ EqualTo(left: NamedExpression, right: Literal) if right.dataType != NullType => - Some(createEqualityFilter(left.name, right, p)) - case p @ LessThan(left: Literal, right: NamedExpression) => - Some(createLessThanFilter(right.name, left, p)) - case p @ LessThan(left: NamedExpression, right: Literal) => - Some(createLessThanFilter(left.name, right, p)) - case p @ LessThanOrEqual(left: Literal, right: NamedExpression) => - Some(createLessThanOrEqualFilter(right.name, left, p)) - case p @ LessThanOrEqual(left: NamedExpression, right: Literal) => - Some(createLessThanOrEqualFilter(left.name, right, p)) - case p @ GreaterThan(left: Literal, right: NamedExpression) => - Some(createGreaterThanFilter(right.name, left, p)) - case p @ GreaterThan(left: NamedExpression, right: Literal) => - Some(createGreaterThanFilter(left.name, right, p)) - case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) => - Some(createGreaterThanOrEqualFilter(right.name, left, p)) - case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) => - Some(createGreaterThanOrEqualFilter(left.name, right, p)) + predicate match { + case EqualTo(NamedExpression(name, _), Literal(value, dataType)) if dataType != NullType => + makeEq.lift(dataType).map(_(name, value)) + case EqualTo(Literal(value, dataType), NamedExpression(name, _)) if dataType != NullType => + makeEq.lift(dataType).map(_(name, value)) + + case LessThan(NamedExpression(name, _), Literal(value, dataType)) => + makeLt.lift(dataType).map(_(name, value)) + case LessThan(Literal(value, dataType), NamedExpression(name, _)) => + makeLt.lift(dataType).map(_(name, value)) + + case LessThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) => + makeLtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) => + makeLtEq.lift(dataType).map(_(name, value)) + + case GreaterThan(NamedExpression(name, _), Literal(value, dataType)) => + makeGt.lift(dataType).map(_(name, value)) + case GreaterThan(Literal(value, dataType), NamedExpression(name, _)) => + makeGt.lift(dataType).map(_(name, value)) + + case GreaterThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) => + makeGtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) => + makeGtEq.lift(dataType).map(_(name, value)) + + case And(lhs, rhs) => + (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) + + case Or(lhs, rhs) => + for { + lhsFilter <- createFilter(lhs) + rhsFilter <- createFilter(rhs) + } yield FilterApi.or(lhsFilter, rhsFilter) + + case Not(pred) => + createFilter(pred).map(FilterApi.not) + case _ => None } } @@ -428,7 +174,7 @@ private[sql] object ParquetFilters { * the actual filter predicate. 
*/ def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = { - if (filters.length > 0) { + if (filters.nonEmpty) { val serialized: Array[Byte] = SparkEnv.get.closureSerializer.newInstance().serialize(filters).array() val encoded: String = BaseEncoding.base64().encode(serialized) @@ -450,245 +196,4 @@ private[sql] object ParquetFilters { Seq() } } - - /** - * Try to find the given expression in the tree of filters in order to - * determine whether it is safe to remove it from the higher level filters. Note - * that strictly speaking we could stop the search whenever an expression is found - * that contains this expression as subexpression (e.g., when searching for "a" - * and "(a or c)" is found) but we don't care about optimizations here since the - * filter tree is assumed to be small. - * - * @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand - * and search - * @param expression The expression to look for - * @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that - * contains the expression. - */ - def findExpression( - filter: CatalystFilter, - expression: Expression): Option[CatalystFilter] = filter match { - case f @ OrFilter(_, leftFilter, rightFilter, _) => - if (f.predicate == expression) { - Some(f) - } else { - val left = findExpression(leftFilter, expression) - if (left.isDefined) left else findExpression(rightFilter, expression) - } - case f @ AndFilter(_, leftFilter, rightFilter, _) => - if (f.predicate == expression) { - Some(f) - } else { - val left = findExpression(leftFilter, expression) - if (left.isDefined) left else findExpression(rightFilter, expression) - } - case f @ ComparisonFilter(_, _, predicate) => - if (predicate == expression) Some(f) else None - case _ => None - } -} - -abstract private[parquet] class CatalystFilter( - @transient val predicate: CatalystPredicate) extends FilterPredicate - -private[parquet] case class ComparisonFilter( - val columnName: String, - private var filter: FilterPredicate, - @transient override val predicate: CatalystPredicate) - extends CatalystFilter(predicate) { - override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { - filter.accept(visitor) - } -} - -private[parquet] case class OrFilter( - private var filter: FilterPredicate, - @transient val left: CatalystFilter, - @transient val right: CatalystFilter, - @transient override val predicate: Or) - extends CatalystFilter(predicate) { - def this(l: CatalystFilter, r: CatalystFilter) = - this( - FilterApi.or(l, r), - l, - r, - Or(l.predicate, r.predicate)) - - override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { - filter.accept(visitor); - } - -} - -private[parquet] case class AndFilter( - private var filter: FilterPredicate, - @transient val left: CatalystFilter, - @transient val right: CatalystFilter, - @transient override val predicate: And) - extends CatalystFilter(predicate) { - def this(l: CatalystFilter, r: CatalystFilter) = - this( - FilterApi.and(l, r), - l, - r, - And(l.predicate, r.predicate)) - - override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { - filter.accept(visitor); - } - -} - -private[parquet] object ComparisonFilter { - def createBooleanEqualityFilter( - columnName: String, - value: Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]), - predicate) - - def createStringEqualityFilter( - columnName: String, - value: String, - 
predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)), - predicate) - - def createStringLessThanFilter( - columnName: String, - value: String, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.lt(binaryColumn(columnName), Binary.fromString(value)), - predicate) - - def createStringLessThanOrEqualFilter( - columnName: String, - value: String, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.ltEq(binaryColumn(columnName), Binary.fromString(value)), - predicate) - - def createStringGreaterThanFilter( - columnName: String, - value: String, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.gt(binaryColumn(columnName), Binary.fromString(value)), - predicate) - - def createStringGreaterThanOrEqualFilter( - columnName: String, - value: String, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.gtEq(binaryColumn(columnName), Binary.fromString(value)), - predicate) - - def createBinaryEqualityFilter( - columnName: String, - value: Array[Byte], - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.eq(binaryColumn(columnName), Binary.fromByteArray(value)), - predicate) - - def createBinaryLessThanFilter( - columnName: String, - value: Array[Byte], - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.lt(binaryColumn(columnName), Binary.fromByteArray(value)), - predicate) - - def createBinaryLessThanOrEqualFilter( - columnName: String, - value: Array[Byte], - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.ltEq(binaryColumn(columnName), Binary.fromByteArray(value)), - predicate) - - def createBinaryGreaterThanFilter( - columnName: String, - value: Array[Byte], - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.gt(binaryColumn(columnName), Binary.fromByteArray(value)), - predicate) - - def createBinaryGreaterThanOrEqualFilter( - columnName: String, - value: Array[Byte], - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - FilterApi.gtEq(binaryColumn(columnName), Binary.fromByteArray(value)), - predicate) -} - -private[spark] object ParquetColumns { - - def byteColumn(columnPath: String): ByteColumn = { - new ByteColumn(ColumnPath.fromDotString(columnPath)) - } - - final class ByteColumn(columnPath: ColumnPath) - extends Column[java.lang.Byte](columnPath, classOf[java.lang.Byte]) with SupportsLtGt - - def shortColumn(columnPath: String): ShortColumn = { - new ShortColumn(ColumnPath.fromDotString(columnPath)) - } - - final class ShortColumn(columnPath: ColumnPath) - extends Column[java.lang.Short](columnPath, classOf[java.lang.Short]) with SupportsLtGt - - - def dateColumn(columnPath: String): DateColumn = { - new DateColumn(ColumnPath.fromDotString(columnPath)) - } - - final class DateColumn(columnPath: ColumnPath) - extends Column[WrappedDate](columnPath, classOf[WrappedDate]) with SupportsLtGt - - def timestampColumn(columnPath: String): TimestampColumn = { - new TimestampColumn(ColumnPath.fromDotString(columnPath)) - } - - final class TimestampColumn(columnPath: ColumnPath) - extends Column[WrappedTimestamp](columnPath, classOf[WrappedTimestamp]) with SupportsLtGt - - def 
decimalColumn(columnPath: String): DecimalColumn = { - new DecimalColumn(ColumnPath.fromDotString(columnPath)) - } - - final class DecimalColumn(columnPath: ColumnPath) - extends Column[Decimal](columnPath, classOf[Decimal]) with SupportsLtGt - - final class WrappedDate(val date: Date) extends Comparable[WrappedDate] { - - override def compareTo(other: WrappedDate): Int = { - date.compareTo(other.date) - } - } - - final class WrappedTimestamp(val timestamp: Timestamp) extends Comparable[WrappedTimestamp] { - - override def compareTo(other: WrappedTimestamp): Int = { - timestamp.compareTo(other.timestamp) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index f6bed5016f..5d0643a64a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -23,8 +23,6 @@ import java.text.SimpleDateFormat import java.util.concurrent.{Callable, TimeUnit} import java.util.{ArrayList, Collections, Date, List => JList} -import org.apache.spark.annotation.DeveloperApi - import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Try @@ -34,22 +32,20 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter - +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} import parquet.hadoop._ +import parquet.hadoop.api.ReadSupport.ReadContext import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData -import parquet.hadoop.api.ReadSupport.ReadContext import parquet.hadoop.util.ContextUtil import parquet.io.ParquetDecodingException import parquet.schema.MessageType +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} import org.apache.spark.{Logging, SerializableWritable, TaskContext} @@ -82,8 +78,6 @@ case class ParquetTableScan( override def execute(): RDD[Row] = { import parquet.filter2.compat.FilterCompat.FilterPredicateCompat - import parquet.filter2.compat.FilterCompat.Filter - import parquet.filter2.predicate.FilterPredicate val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) @@ -111,14 +105,11 @@ case class ParquetTableScan( // Note 1: the input format ignores all predicates that cannot be expressed // as simple column predicate filters in Parquet. Here we just record // the whole pruning predicate. 
- if (columnPruningPred.length > 0) { + ParquetFilters + .createRecordFilter(columnPruningPred) + .map(_.asInstanceOf[FilterPredicateCompat].getFilterPredicate) // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering - val filter: Filter = ParquetFilters.createRecordFilter(columnPruningPred) - if (filter != null){ - val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate - ParquetInputFormat.setFilterPredicate(conf, filterPredicate) - } - } + .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata conf.set( @@ -317,7 +308,7 @@ case class InsertIntoParquetTable( } writer.close(hadoopContext) committer.commitTask(hadoopContext) - return 1 + 1 } val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) /* apparently we need a TaskAttemptID to construct an OutputCommitter; @@ -375,9 +366,8 @@ private[parquet] class FilteringParquetRowInputFormat override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { - + import parquet.filter2.compat.FilterCompat.NoOpFilter - import parquet.filter2.compat.FilterCompat.Filter val readSupport: ReadSupport[Row] = new RowReadSupport() @@ -392,7 +382,7 @@ private[parquet] class FilteringParquetRowInputFormat } override def getFooters(jobContext: JobContext): JList[Footer] = { - import FilteringParquetRowInputFormat.footerCache + import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.footerCache if (footers eq null) { val conf = ContextUtil.getConfiguration(jobContext) @@ -442,13 +432,13 @@ private[parquet] class FilteringParquetRowInputFormat val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = - Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) + Math.max(getFormatMinSplitSize, configuration.getLong("mapred.min.split.size", 0L)) if (maxSplitSize < 0 || minSplitSize < 0) { throw new ParquetDecodingException( s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + s" minSplitSize = $minSplitSize") } - + // Uses strict type checking by default val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) @@ -458,29 +448,29 @@ private[parquet] class FilteringParquetRowInputFormat if (globalMetaData == null) { val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] return splits - } - + } + val readContext = getReadSupport(configuration).init( new InitContext(configuration, - globalMetaData.getKeyValueMetaData(), - globalMetaData.getSchema())) - + globalMetaData.getKeyValueMetaData, + globalMetaData.getSchema)) + if (taskSideMetaData){ logInfo("Using Task Side Metadata Split Strategy") - return getTaskSideSplits(configuration, + getTaskSideSplits(configuration, footers, maxSplitSize, minSplitSize, readContext) } else { logInfo("Using Client Side Metadata Split Strategy") - return getClientSideSplits(configuration, + getClientSideSplits(configuration, footers, maxSplitSize, minSplitSize, readContext) } - + } def getClientSideSplits( @@ -489,12 +479,11 @@ private[parquet] class FilteringParquetRowInputFormat maxSplitSize: JLong, minSplitSize: JLong, readContext: ReadContext): JList[ParquetInputSplit] = { - - import FilteringParquetRowInputFormat.blockLocationCache - import 
parquet.filter2.compat.FilterCompat; - import parquet.filter2.compat.FilterCompat.Filter; - import parquet.filter2.compat.RowGroupFilter; - + + import parquet.filter2.compat.FilterCompat.Filter + import parquet.filter2.compat.RowGroupFilter + import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true) val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] @@ -503,7 +492,7 @@ private[parquet] class FilteringParquetRowInputFormat var totalRowGroups: Long = 0 // Ugly hack, stuck with it until PR: - // https://github.com/apache/incubator-parquet-mr/pull/17 + // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved val generateSplits = Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") @@ -523,7 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat blocks, parquetMetaData.getFileMetaData.getSchema) rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size) - + if (!filteredBlocks.isEmpty){ var blockLocations: Array[BlockLocation] = null if (!cacheMetadata) { @@ -566,7 +555,7 @@ private[parquet] class FilteringParquetRowInputFormat readContext: ReadContext): JList[ParquetInputSplit] = { val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] - + // Ugly hack, stuck with it until PR: // https://github.com/apache/incubator-parquet-mr/pull/17 // is resolved @@ -576,7 +565,7 @@ private[parquet] class FilteringParquetRowInputFormat sys.error( s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) generateSplits.setAccessible(true) - + for (footer <- footers) { val file = footer.getFile val fs = file.getFileSystem(configuration) @@ -594,7 +583,7 @@ private[parquet] class FilteringParquetRowInputFormat } splits - } + } } @@ -636,11 +625,9 @@ private[parquet] object FileSystemHelper { files.map(_.getName).map { case nameP(taskid) => taskid.toInt case hiddenFileP() => 0 - case other: String => { + case other: String => sys.error("ERROR: attempting to append to set of Parquet files and found file" + s"that does not match name pattern: $other") - 0 - } case _ => 0 }.reduceLeft((a, b) => if (a < b) b else a) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 80a3e0b4c9..d31a9d8418 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.parquet +import _root_.parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.Job import org.scalatest.{BeforeAndAfterAll, FunSuiteLike} import parquet.hadoop.ParquetFileWriter import parquet.hadoop.util.ContextUtil + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.IntegerType @@ -447,44 +449,24 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(true) } - test("create RecordFilter for simple predicates") { - val attribute1 = new AttributeReference("first", IntegerType, false)() - val predicate1 = new EqualTo(attribute1, new Literal(1, IntegerType)) - val filter1 = ParquetFilters.createFilter(predicate1) - assert(filter1.isDefined) - assert(filter1.get.predicate == predicate1, "predicates do not 
match") - assert(filter1.get.isInstanceOf[ComparisonFilter]) - val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter] - assert(cmpFilter1.columnName == "first", "column name incorrect") - - val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType)) - val filter2 = ParquetFilters.createFilter(predicate2) - assert(filter2.isDefined) - assert(filter2.get.predicate == predicate2, "predicates do not match") - assert(filter2.get.isInstanceOf[ComparisonFilter]) - val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter] - assert(cmpFilter2.columnName == "first", "column name incorrect") - - val predicate3 = new And(predicate1, predicate2) - val filter3 = ParquetFilters.createFilter(predicate3) - assert(filter3.isDefined) - assert(filter3.get.predicate == predicate3, "predicates do not match") - assert(filter3.get.isInstanceOf[AndFilter]) - - val predicate4 = new Or(predicate1, predicate2) - val filter4 = ParquetFilters.createFilter(predicate4) - assert(filter4.isDefined) - assert(filter4.get.predicate == predicate4, "predicates do not match") - assert(filter4.get.isInstanceOf[OrFilter]) - - val attribute2 = new AttributeReference("second", IntegerType, false)() - val predicate5 = new GreaterThan(attribute1, attribute2) - val badfilter = ParquetFilters.createFilter(predicate5) - assert(badfilter.isDefined === false) - - val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2)) - val badfilter2 = ParquetFilters.createFilter(predicate6) - assert(badfilter2.isDefined === false) + test("make RecordFilter for simple predicates") { + def checkFilter[T <: FilterPredicate](predicate: Expression, defined: Boolean = true): Unit = { + val filter = ParquetFilters.createFilter(predicate) + if (defined) { + assert(filter.isDefined) + assert(filter.get.isInstanceOf[T]) + } else { + assert(filter.isEmpty) + } + } + + checkFilter[Operators.Eq[Integer]]('a.int === 1) + checkFilter[Operators.Lt[Integer]]('a.int < 4) + checkFilter[Operators.And]('a.int === 1 && 'a.int < 4) + checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4) + + checkFilter('a.int > 'b.int, defined = false) + checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false) } test("test filter by predicate pushdown") { -- cgit v1.2.3