path: root/sql
author     Cheng Lian <lian@databricks.com>  2014-11-17 16:55:12 -0800
committer  Michael Armbrust <michael@databricks.com>  2014-11-17 16:55:12 -0800
commit     36b0956a3eadc7343ed0d25c79a6ce0496eaaccd (patch)
tree       47fba8e9a00b21b20b77342a9a45f9d0f9969489 /sql
parent     ef7c464effa1510b24bd8e665e4df6c4839b0c87 (diff)
[SPARK-4453][SPARK-4213][SQL] Simplifies Parquet filter generation code
While reviewing PR #3083 and PR #3161, I noticed that the Parquet record filter generation code can be simplified significantly, following the clue stated in [SPARK-4453](https://issues.apache.org/jira/browse/SPARK-4453). This PR addresses both SPARK-4453 and SPARK-4213 with that simplification.

While generating the `ParquetTableScan` operator, we need to remove all Catalyst predicates that have already been pushed down to Parquet. Previously, we first generated the record filter and then called `findExpression` to traverse the generated filter and find all pushed-down predicates [[1](https://github.com/apache/spark/blob/64c6b9bad559c21f25cd9fbe37c8813cdab939f2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala#L213-L228)]. That approach required the `CatalystFilter` class hierarchy to bind each Catalyst predicate to its generated Parquet filter, which complicated the code base considerably.

The basic idea of this PR is that `findExpression` is unnecessary after filter generation: a predicate can be pushed down exactly when its corresponding Parquet filter can be generated successfully. SPARK-4213 is fixed by returning `None` for any unsupported predicate type.

Author: Cheng Lian <lian@databricks.com>

Closes #3317 from liancheng/simplify-parquet-filters and squashes the following commits:

d6a9499 [Cheng Lian] Fixes import styling issue
43760e8 [Cheng Lian] Simplifies Parquet filter generation logic
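To make the idea concrete, here is a minimal sketch (in Scala, but deliberately not the committed code): filter generation itself serves as the pushdown check, so a predicate stays in the higher-level filter exactly when generation returns `None`. `Pred`, `ToyFilter`, `toParquet`, and `split` are hypothetical stand-ins for Catalyst expressions, Parquet's `FilterPredicate`, `ParquetFilters.createFilter`, and the pruning logic in `SparkStrategies`.

```scala
// A minimal sketch of the Option-based approach; toy types stand in for the
// real Catalyst and Parquet classes.
object FilterPushdownSketch {
  sealed trait Pred
  case class Eq(column: String, value: Int) extends Pred
  case class Gt(column: String, value: Int) extends Pred
  case class Custom(description: String) extends Pred // something Parquet cannot evaluate

  case class ToyFilter(repr: String)

  // Returning None for anything unsupported is the key point: unsupported
  // predicates are never turned into bogus filters (the SPARK-4213 class of bug),
  // and the None itself tells us the predicate was not pushed down.
  def toParquet(p: Pred): Option[ToyFilter] = p match {
    case Eq(c, v) => Some(ToyFilter(s"$c = $v"))
    case Gt(c, v) => Some(ToyFilter(s"$c > $v"))
    case _        => None
  }

  // No findExpression round trip: generate once, then split on whether
  // generation succeeded.
  def split(predicates: Seq[Pred]): (Seq[ToyFilter], Seq[Pred]) = {
    val generated = predicates.map(p => p -> toParquet(p))
    val pushed = generated.collect { case (_, Some(f)) => f } // goes to Parquet
    val kept   = generated.collect { case (p, None)    => p } // stays in the higher-level filter
    (pushed, kept)
  }

  // split(Seq(Eq("a", 1), Custom("a + b > 3")))
  //   == (Seq(ToyFilter("a = 1")), Seq(Custom("a + b > 3")))
}
```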
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala  |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                  |  25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala                     | 693
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala             |  77
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala                  |  58
5 files changed, 161 insertions, 693 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index fc90a54a58..7634d392d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.util.Metadata
object NamedExpression {
private val curId = new java.util.concurrent.atomic.AtomicLong()
def newExprId = ExprId(curId.getAndIncrement())
+ def unapply(expr: NamedExpression): Option[(String, DataType)] = Some(expr.name, expr.dataType)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7ef1f9f2c5..1225d18857 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -209,22 +209,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
val prunePushedDownFilters =
if (sqlContext.parquetFilterPushDown) {
- (filters: Seq[Expression]) => {
- filters.filter { filter =>
- // Note: filters cannot be pushed down to Parquet if they contain more complex
- // expressions than simple "Attribute cmp Literal" comparisons. Here we remove
- // all filters that have been pushed down. Note that a predicate such as
- // "(A AND B) OR C" can result in "A OR C" being pushed down.
- val recordFilter = ParquetFilters.createFilter(filter)
- if (!recordFilter.isDefined) {
- // First case: the pushdown did not result in any record filter.
- true
- } else {
- // Second case: a record filter was created; here we are conservative in
- // the sense that even if "A" was pushed and we check for "A AND B" we
- // still want to keep "A AND B" in the higher-level filter, not just "B".
- !ParquetFilters.findExpression(recordFilter.get, filter).isDefined
- }
+ (predicates: Seq[Expression]) => {
+ // Note: filters cannot be pushed down to Parquet if they contain more complex
+ // expressions than simple "Attribute cmp Literal" comparisons. Here we remove all
+ // filters that have been pushed down. Note that a predicate such as "(A AND B) OR C"
+ // can result in "A OR C" being pushed down. Here we are conservative in the sense
+ // that even if "A" was pushed and we check for "A AND B" we still want to keep
+ // "A AND B" in the higher-level filter, not just "B".
+ predicates.map(p => p -> ParquetFilters.createFilter(p)).collect {
+ case (predicate, None) => predicate
}
}
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
index 9a3f6d388d..3a9e1499e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala
@@ -18,406 +18,152 @@
package org.apache.spark.sql.parquet
import java.nio.ByteBuffer
-import java.sql.{Date, Timestamp}
+import com.google.common.io.BaseEncoding
import org.apache.hadoop.conf.Configuration
-
-import parquet.common.schema.ColumnPath
import parquet.filter2.compat.FilterCompat
import parquet.filter2.compat.FilterCompat._
-import parquet.filter2.predicate.Operators.{Column, SupportsLtGt}
-import parquet.filter2.predicate.{FilterApi, FilterPredicate}
import parquet.filter2.predicate.FilterApi._
+import parquet.filter2.predicate.{FilterApi, FilterPredicate}
import parquet.io.api.Binary
-import parquet.column.ColumnReader
-
-import com.google.common.io.BaseEncoding
import org.apache.spark.SparkEnv
-import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.types.decimal.Decimal
-import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.spark.sql.parquet.ParquetColumns._
+import org.apache.spark.sql.catalyst.types._
private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
- def createRecordFilter(filterExpressions: Seq[Expression]): Filter = {
- val filters: Seq[CatalystFilter] = filterExpressions.collect {
- case (expression: Expression) if createFilter(expression).isDefined =>
- createFilter(expression).get
- }
- if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null
+ def createRecordFilter(filterExpressions: Seq[Expression]): Option[Filter] = {
+ filterExpressions.flatMap(createFilter).reduceOption(FilterApi.and).map(FilterCompat.get)
}
- def createFilter(expression: Expression): Option[CatalystFilter] = {
- def createEqualityFilter(
- name: String,
- literal: Literal,
- predicate: CatalystPredicate) = literal.dataType match {
+ def createFilter(predicate: Expression): Option[FilterPredicate] = {
+ val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case BooleanType =>
- ComparisonFilter.createBooleanEqualityFilter(
- name,
- literal.value.asInstanceOf[Boolean],
- predicate)
- case ByteType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
- predicate)
- case ShortType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
- predicate)
+ (n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
case IntegerType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(intColumn(name), literal.value.asInstanceOf[Integer]),
- predicate)
+ (n: String, v: Any) => FilterApi.eq(intColumn(n), v.asInstanceOf[Integer])
case LongType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
- predicate)
- case DoubleType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
- predicate)
+ (n: String, v: Any) => FilterApi.eq(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
- predicate)
+ (n: String, v: Any) => FilterApi.eq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+ case DoubleType =>
+ (n: String, v: Any) => FilterApi.eq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
- ComparisonFilter.createStringEqualityFilter(
- name,
- literal.value.asInstanceOf[String],
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.eq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
- ComparisonFilter.createBinaryEqualityFilter(
- name,
- literal.value.asInstanceOf[Array[Byte]],
- predicate)
- case DateType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
- predicate)
- case TimestampType =>
- new ComparisonFilter(
- name,
- FilterApi.eq(timestampColumn(name),
- new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
- predicate)
- case DecimalType.Unlimited =>
- new ComparisonFilter(
- name,
- FilterApi.eq(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.eq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
- def createLessThanFilter(
- name: String,
- literal: Literal,
- predicate: CatalystPredicate) = literal.dataType match {
- case ByteType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
- predicate)
- case ShortType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
- predicate)
+ val makeLt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]),
- predicate)
+ (n: String, v: Any) => FilterApi.lt(intColumn(n), v.asInstanceOf[Integer])
case LongType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
- predicate)
- case DoubleType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
- predicate)
+ (n: String, v: Any) => FilterApi.lt(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
- predicate)
+ (n: String, v: Any) => FilterApi.lt(floatColumn(n), v.asInstanceOf[java.lang.Float])
+ case DoubleType =>
+ (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
- ComparisonFilter.createStringLessThanFilter(
- name,
- literal.value.asInstanceOf[String],
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.lt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
- ComparisonFilter.createBinaryLessThanFilter(
- name,
- literal.value.asInstanceOf[Array[Byte]],
- predicate)
- case DateType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
- predicate)
- case TimestampType =>
- new ComparisonFilter(
- name,
- FilterApi.lt(timestampColumn(name),
- new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
- predicate)
- case DecimalType.Unlimited =>
- new ComparisonFilter(
- name,
- FilterApi.lt(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
- def createLessThanOrEqualFilter(
- name: String,
- literal: Literal,
- predicate: CatalystPredicate) = literal.dataType match {
- case ByteType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
- predicate)
- case ShortType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
- predicate)
+
+ val makeLtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(intColumn(name), literal.value.asInstanceOf[Integer]),
- predicate)
+ (n: String, v: Any) => FilterApi.ltEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
case LongType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
- predicate)
- case DoubleType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
- predicate)
+ (n: String, v: Any) => FilterApi.ltEq(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
- predicate)
+ (n: String, v: Any) => FilterApi.ltEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+ case DoubleType =>
+ (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
- ComparisonFilter.createStringLessThanOrEqualFilter(
- name,
- literal.value.asInstanceOf[String],
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
- ComparisonFilter.createBinaryLessThanOrEqualFilter(
- name,
- literal.value.asInstanceOf[Array[Byte]],
- predicate)
- case DateType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
- predicate)
- case TimestampType =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(timestampColumn(name),
- new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
- predicate)
- case DecimalType.Unlimited =>
- new ComparisonFilter(
- name,
- FilterApi.ltEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
- // TODO: combine these two types somehow?
- def createGreaterThanFilter(
- name: String,
- literal: Literal,
- predicate: CatalystPredicate) = literal.dataType match {
- case ByteType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
- predicate)
- case ShortType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
- predicate)
+
+ val makeGt: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(intColumn(name), literal.value.asInstanceOf[Integer]),
- predicate)
+ (n: String, v: Any) => FilterApi.gt(intColumn(n), v.asInstanceOf[java.lang.Integer])
case LongType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
- predicate)
- case DoubleType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
- predicate)
+ (n: String, v: Any) => FilterApi.gt(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
- predicate)
+ (n: String, v: Any) => FilterApi.gt(floatColumn(n), v.asInstanceOf[java.lang.Float])
+ case DoubleType =>
+ (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
- ComparisonFilter.createStringGreaterThanFilter(
- name,
- literal.value.asInstanceOf[String],
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.gt(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
- ComparisonFilter.createBinaryGreaterThanFilter(
- name,
- literal.value.asInstanceOf[Array[Byte]],
- predicate)
- case DateType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
- predicate)
- case TimestampType =>
- new ComparisonFilter(
- name,
- FilterApi.gt(timestampColumn(name),
- new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
- predicate)
- case DecimalType.Unlimited =>
- new ComparisonFilter(
- name,
- FilterApi.gt(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
- def createGreaterThanOrEqualFilter(
- name: String,
- literal: Literal,
- predicate: CatalystPredicate) = literal.dataType match {
- case ByteType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]),
- predicate)
- case ShortType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]),
- predicate)
+
+ val makeGtEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case IntegerType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(intColumn(name), literal.value.asInstanceOf[Integer]),
- predicate)
+ (n: String, v: Any) => FilterApi.gtEq(intColumn(n), v.asInstanceOf[java.lang.Integer])
case LongType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]),
- predicate)
- case DoubleType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]),
- predicate)
+ (n: String, v: Any) => FilterApi.gtEq(longColumn(n), v.asInstanceOf[java.lang.Long])
case FloatType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]),
- predicate)
+ (n: String, v: Any) => FilterApi.gtEq(floatColumn(n), v.asInstanceOf[java.lang.Float])
+ case DoubleType =>
+ (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
- ComparisonFilter.createStringGreaterThanOrEqualFilter(
- name,
- literal.value.asInstanceOf[String],
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), Binary.fromString(v.asInstanceOf[String]))
case BinaryType =>
- ComparisonFilter.createBinaryGreaterThanOrEqualFilter(
- name,
- literal.value.asInstanceOf[Array[Byte]],
- predicate)
- case DateType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])),
- predicate)
- case TimestampType =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(timestampColumn(name),
- new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])),
- predicate)
- case DecimalType.Unlimited =>
- new ComparisonFilter(
- name,
- FilterApi.gtEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]),
- predicate)
+ (n: String, v: Any) =>
+ FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}
- /**
- * TODO: we currently only filter on non-nullable (Parquet REQUIRED) attributes until
- * https://github.com/Parquet/parquet-mr/issues/371
- * has been resolved.
- */
- expression match {
- case p @ Or(left: Expression, right: Expression)
- if createFilter(left).isDefined && createFilter(right).isDefined => {
- // If either side of this Or-predicate is empty then this means
- // it contains a more complex comparison than between attribute and literal
- // (e.g., it contained a CAST). The only safe thing to do is then to disregard
- // this disjunction, which could be contained in a conjunction. If it stands
- // alone then it is also safe to drop it, since a Null return value of this
- // function is interpreted as having no filters at all.
- val leftFilter = createFilter(left).get
- val rightFilter = createFilter(right).get
- Some(new OrFilter(leftFilter, rightFilter))
- }
- case p @ And(left: Expression, right: Expression) => {
- // This treats nested conjunctions; since either side of the conjunction
- // may contain more complex filter expressions we may actually generate
- // strictly weaker filter predicates in the process.
- val leftFilter = createFilter(left)
- val rightFilter = createFilter(right)
- (leftFilter, rightFilter) match {
- case (None, Some(filter)) => Some(filter)
- case (Some(filter), None) => Some(filter)
- case (Some(leftF), Some(rightF)) =>
- Some(new AndFilter(leftF, rightF))
- case _ => None
- }
- }
- case p @ EqualTo(left: Literal, right: NamedExpression) if left.dataType != NullType =>
- Some(createEqualityFilter(right.name, left, p))
- case p @ EqualTo(left: NamedExpression, right: Literal) if right.dataType != NullType =>
- Some(createEqualityFilter(left.name, right, p))
- case p @ LessThan(left: Literal, right: NamedExpression) =>
- Some(createLessThanFilter(right.name, left, p))
- case p @ LessThan(left: NamedExpression, right: Literal) =>
- Some(createLessThanFilter(left.name, right, p))
- case p @ LessThanOrEqual(left: Literal, right: NamedExpression) =>
- Some(createLessThanOrEqualFilter(right.name, left, p))
- case p @ LessThanOrEqual(left: NamedExpression, right: Literal) =>
- Some(createLessThanOrEqualFilter(left.name, right, p))
- case p @ GreaterThan(left: Literal, right: NamedExpression) =>
- Some(createGreaterThanFilter(right.name, left, p))
- case p @ GreaterThan(left: NamedExpression, right: Literal) =>
- Some(createGreaterThanFilter(left.name, right, p))
- case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) =>
- Some(createGreaterThanOrEqualFilter(right.name, left, p))
- case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) =>
- Some(createGreaterThanOrEqualFilter(left.name, right, p))
+ predicate match {
+ case EqualTo(NamedExpression(name, _), Literal(value, dataType)) if dataType != NullType =>
+ makeEq.lift(dataType).map(_(name, value))
+ case EqualTo(Literal(value, dataType), NamedExpression(name, _)) if dataType != NullType =>
+ makeEq.lift(dataType).map(_(name, value))
+
+ case LessThan(NamedExpression(name, _), Literal(value, dataType)) =>
+ makeLt.lift(dataType).map(_(name, value))
+ case LessThan(Literal(value, dataType), NamedExpression(name, _)) =>
+ makeLt.lift(dataType).map(_(name, value))
+
+ case LessThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
+ makeLtEq.lift(dataType).map(_(name, value))
+ case LessThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
+ makeLtEq.lift(dataType).map(_(name, value))
+
+ case GreaterThan(NamedExpression(name, _), Literal(value, dataType)) =>
+ makeGt.lift(dataType).map(_(name, value))
+ case GreaterThan(Literal(value, dataType), NamedExpression(name, _)) =>
+ makeGt.lift(dataType).map(_(name, value))
+
+ case GreaterThanOrEqual(NamedExpression(name, _), Literal(value, dataType)) =>
+ makeGtEq.lift(dataType).map(_(name, value))
+ case GreaterThanOrEqual(Literal(value, dataType), NamedExpression(name, _)) =>
+ makeGtEq.lift(dataType).map(_(name, value))
+
+ case And(lhs, rhs) =>
+ (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and)
+
+ case Or(lhs, rhs) =>
+ for {
+ lhsFilter <- createFilter(lhs)
+ rhsFilter <- createFilter(rhs)
+ } yield FilterApi.or(lhsFilter, rhsFilter)
+
+ case Not(pred) =>
+ createFilter(pred).map(FilterApi.not)
+
case _ => None
}
}
@@ -428,7 +174,7 @@ private[sql] object ParquetFilters {
* the actual filter predicate.
*/
def serializeFilterExpressions(filters: Seq[Expression], conf: Configuration): Unit = {
- if (filters.length > 0) {
+ if (filters.nonEmpty) {
val serialized: Array[Byte] =
SparkEnv.get.closureSerializer.newInstance().serialize(filters).array()
val encoded: String = BaseEncoding.base64().encode(serialized)
@@ -450,245 +196,4 @@ private[sql] object ParquetFilters {
Seq()
}
}
-
- /**
- * Try to find the given expression in the tree of filters in order to
- * determine whether it is safe to remove it from the higher level filters. Note
- * that strictly speaking we could stop the search whenever an expression is found
- * that contains this expression as subexpression (e.g., when searching for "a"
- * and "(a or c)" is found) but we don't care about optimizations here since the
- * filter tree is assumed to be small.
- *
- * @param filter The [[org.apache.spark.sql.parquet.CatalystFilter]] to expand
- * and search
- * @param expression The expression to look for
- * @return An optional [[org.apache.spark.sql.parquet.CatalystFilter]] that
- * contains the expression.
- */
- def findExpression(
- filter: CatalystFilter,
- expression: Expression): Option[CatalystFilter] = filter match {
- case f @ OrFilter(_, leftFilter, rightFilter, _) =>
- if (f.predicate == expression) {
- Some(f)
- } else {
- val left = findExpression(leftFilter, expression)
- if (left.isDefined) left else findExpression(rightFilter, expression)
- }
- case f @ AndFilter(_, leftFilter, rightFilter, _) =>
- if (f.predicate == expression) {
- Some(f)
- } else {
- val left = findExpression(leftFilter, expression)
- if (left.isDefined) left else findExpression(rightFilter, expression)
- }
- case f @ ComparisonFilter(_, _, predicate) =>
- if (predicate == expression) Some(f) else None
- case _ => None
- }
-}
-
-abstract private[parquet] class CatalystFilter(
- @transient val predicate: CatalystPredicate) extends FilterPredicate
-
-private[parquet] case class ComparisonFilter(
- val columnName: String,
- private var filter: FilterPredicate,
- @transient override val predicate: CatalystPredicate)
- extends CatalystFilter(predicate) {
- override def accept[R](visitor: FilterPredicate.Visitor[R]): R = {
- filter.accept(visitor)
- }
-}
-
-private[parquet] case class OrFilter(
- private var filter: FilterPredicate,
- @transient val left: CatalystFilter,
- @transient val right: CatalystFilter,
- @transient override val predicate: Or)
- extends CatalystFilter(predicate) {
- def this(l: CatalystFilter, r: CatalystFilter) =
- this(
- FilterApi.or(l, r),
- l,
- r,
- Or(l.predicate, r.predicate))
-
- override def accept[R](visitor: FilterPredicate.Visitor[R]): R = {
- filter.accept(visitor);
- }
-
-}
-
-private[parquet] case class AndFilter(
- private var filter: FilterPredicate,
- @transient val left: CatalystFilter,
- @transient val right: CatalystFilter,
- @transient override val predicate: And)
- extends CatalystFilter(predicate) {
- def this(l: CatalystFilter, r: CatalystFilter) =
- this(
- FilterApi.and(l, r),
- l,
- r,
- And(l.predicate, r.predicate))
-
- override def accept[R](visitor: FilterPredicate.Visitor[R]): R = {
- filter.accept(visitor);
- }
-
-}
-
-private[parquet] object ComparisonFilter {
- def createBooleanEqualityFilter(
- columnName: String,
- value: Boolean,
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]),
- predicate)
-
- def createStringEqualityFilter(
- columnName: String,
- value: String,
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)),
- predicate)
-
- def createStringLessThanFilter(
- columnName: String,
- value: String,
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.lt(binaryColumn(columnName), Binary.fromString(value)),
- predicate)
-
- def createStringLessThanOrEqualFilter(
- columnName: String,
- value: String,
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.ltEq(binaryColumn(columnName), Binary.fromString(value)),
- predicate)
-
- def createStringGreaterThanFilter(
- columnName: String,
- value: String,
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.gt(binaryColumn(columnName), Binary.fromString(value)),
- predicate)
-
- def createStringGreaterThanOrEqualFilter(
- columnName: String,
- value: String,
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.gtEq(binaryColumn(columnName), Binary.fromString(value)),
- predicate)
-
- def createBinaryEqualityFilter(
- columnName: String,
- value: Array[Byte],
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.eq(binaryColumn(columnName), Binary.fromByteArray(value)),
- predicate)
-
- def createBinaryLessThanFilter(
- columnName: String,
- value: Array[Byte],
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.lt(binaryColumn(columnName), Binary.fromByteArray(value)),
- predicate)
-
- def createBinaryLessThanOrEqualFilter(
- columnName: String,
- value: Array[Byte],
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.ltEq(binaryColumn(columnName), Binary.fromByteArray(value)),
- predicate)
-
- def createBinaryGreaterThanFilter(
- columnName: String,
- value: Array[Byte],
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.gt(binaryColumn(columnName), Binary.fromByteArray(value)),
- predicate)
-
- def createBinaryGreaterThanOrEqualFilter(
- columnName: String,
- value: Array[Byte],
- predicate: CatalystPredicate): CatalystFilter =
- new ComparisonFilter(
- columnName,
- FilterApi.gtEq(binaryColumn(columnName), Binary.fromByteArray(value)),
- predicate)
-}
-
-private[spark] object ParquetColumns {
-
- def byteColumn(columnPath: String): ByteColumn = {
- new ByteColumn(ColumnPath.fromDotString(columnPath))
- }
-
- final class ByteColumn(columnPath: ColumnPath)
- extends Column[java.lang.Byte](columnPath, classOf[java.lang.Byte]) with SupportsLtGt
-
- def shortColumn(columnPath: String): ShortColumn = {
- new ShortColumn(ColumnPath.fromDotString(columnPath))
- }
-
- final class ShortColumn(columnPath: ColumnPath)
- extends Column[java.lang.Short](columnPath, classOf[java.lang.Short]) with SupportsLtGt
-
-
- def dateColumn(columnPath: String): DateColumn = {
- new DateColumn(ColumnPath.fromDotString(columnPath))
- }
-
- final class DateColumn(columnPath: ColumnPath)
- extends Column[WrappedDate](columnPath, classOf[WrappedDate]) with SupportsLtGt
-
- def timestampColumn(columnPath: String): TimestampColumn = {
- new TimestampColumn(ColumnPath.fromDotString(columnPath))
- }
-
- final class TimestampColumn(columnPath: ColumnPath)
- extends Column[WrappedTimestamp](columnPath, classOf[WrappedTimestamp]) with SupportsLtGt
-
- def decimalColumn(columnPath: String): DecimalColumn = {
- new DecimalColumn(ColumnPath.fromDotString(columnPath))
- }
-
- final class DecimalColumn(columnPath: ColumnPath)
- extends Column[Decimal](columnPath, classOf[Decimal]) with SupportsLtGt
-
- final class WrappedDate(val date: Date) extends Comparable[WrappedDate] {
-
- override def compareTo(other: WrappedDate): Int = {
- date.compareTo(other.date)
- }
- }
-
- final class WrappedTimestamp(val timestamp: Timestamp) extends Comparable[WrappedTimestamp] {
-
- override def compareTo(other: WrappedTimestamp): Int = {
- timestamp.compareTo(other.timestamp)
- }
- }
}
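The rewritten `createFilter` above maps each supported comparison to a Parquet `FilterApi` call through a `PartialFunction` keyed on the literal's data type, then uses `lift` so that unsupported types yield `None` instead of a `MatchError`. A small self-contained sketch of that pattern, with toy types standing in for Catalyst's `DataType` and Parquet's `FilterPredicate`:

```scala
// Sketch of the PartialFunction + lift pattern; toy types only.
object PartialFunctionLiftSketch {
  sealed trait ToyType
  case object IntType extends ToyType
  case object StrType extends ToyType
  case object DecType extends ToyType // pretend this one is unsupported

  case class ToyFilter(repr: String)

  // Defined only for the supported types; everything else is simply "not pushable".
  val makeEq: PartialFunction[ToyType, (String, Any) => ToyFilter] = {
    case IntType => (n, v) => ToyFilter(s"$n = ${v.asInstanceOf[Int]}")
    case StrType => (n, v) => ToyFilter(s"$n = '${v.asInstanceOf[String]}'")
  }

  // lift turns the partial function into ToyType => Option[(String, Any) => ToyFilter],
  // so an unsupported type produces None rather than throwing.
  def createEq(name: String, value: Any, tpe: ToyType): Option[ToyFilter] =
    makeEq.lift(tpe).map(_(name, value))

  // createEq("a", 1, IntType)   == Some(ToyFilter("a = 1"))
  // createEq("d", 1.0, DecType) == None
}
```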
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index f6bed5016f..5d0643a64a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -23,8 +23,6 @@ import java.text.SimpleDateFormat
import java.util.concurrent.{Callable, TimeUnit}
import java.util.{ArrayList, Collections, Date, List => JList}
-import org.apache.spark.annotation.DeveloperApi
-
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.util.Try
@@ -34,22 +32,20 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path}
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
-import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
-
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat}
import parquet.hadoop._
+import parquet.hadoop.api.ReadSupport.ReadContext
import parquet.hadoop.api.{InitContext, ReadSupport}
import parquet.hadoop.metadata.GlobalMetaData
-import parquet.hadoop.api.ReadSupport.ReadContext
import parquet.hadoop.util.ContextUtil
import parquet.io.ParquetDecodingException
import parquet.schema.MessageType
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
import org.apache.spark.{Logging, SerializableWritable, TaskContext}
@@ -82,8 +78,6 @@ case class ParquetTableScan(
override def execute(): RDD[Row] = {
import parquet.filter2.compat.FilterCompat.FilterPredicateCompat
- import parquet.filter2.compat.FilterCompat.Filter
- import parquet.filter2.predicate.FilterPredicate
val sc = sqlContext.sparkContext
val job = new Job(sc.hadoopConfiguration)
@@ -111,14 +105,11 @@ case class ParquetTableScan(
// Note 1: the input format ignores all predicates that cannot be expressed
// as simple column predicate filters in Parquet. Here we just record
// the whole pruning predicate.
- if (columnPruningPred.length > 0) {
+ ParquetFilters
+ .createRecordFilter(columnPruningPred)
+ .map(_.asInstanceOf[FilterPredicateCompat].getFilterPredicate)
// Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering
- val filter: Filter = ParquetFilters.createRecordFilter(columnPruningPred)
- if (filter != null){
- val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate
- ParquetInputFormat.setFilterPredicate(conf, filterPredicate)
- }
- }
+ .foreach(ParquetInputFormat.setFilterPredicate(conf, _))
// Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata
conf.set(
@@ -317,7 +308,7 @@ case class InsertIntoParquetTable(
}
writer.close(hadoopContext)
committer.commitTask(hadoopContext)
- return 1
+ 1
}
val jobFormat = new AppendingParquetOutputFormat(taskIdOffset)
/* apparently we need a TaskAttemptID to construct an OutputCommitter;
@@ -375,9 +366,8 @@ private[parquet] class FilteringParquetRowInputFormat
override def createRecordReader(
inputSplit: InputSplit,
taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = {
-
+
import parquet.filter2.compat.FilterCompat.NoOpFilter
- import parquet.filter2.compat.FilterCompat.Filter
val readSupport: ReadSupport[Row] = new RowReadSupport()
@@ -392,7 +382,7 @@ private[parquet] class FilteringParquetRowInputFormat
}
override def getFooters(jobContext: JobContext): JList[Footer] = {
- import FilteringParquetRowInputFormat.footerCache
+ import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.footerCache
if (footers eq null) {
val conf = ContextUtil.getConfiguration(jobContext)
@@ -442,13 +432,13 @@ private[parquet] class FilteringParquetRowInputFormat
val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true)
val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue)
val minSplitSize: JLong =
- Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L))
+ Math.max(getFormatMinSplitSize, configuration.getLong("mapred.min.split.size", 0L))
if (maxSplitSize < 0 || minSplitSize < 0) {
throw new ParquetDecodingException(
s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" +
s" minSplitSize = $minSplitSize")
}
-
+
// Uses strict type checking by default
val getGlobalMetaData =
classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]])
@@ -458,29 +448,29 @@ private[parquet] class FilteringParquetRowInputFormat
if (globalMetaData == null) {
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
return splits
- }
-
+ }
+
val readContext = getReadSupport(configuration).init(
new InitContext(configuration,
- globalMetaData.getKeyValueMetaData(),
- globalMetaData.getSchema()))
-
+ globalMetaData.getKeyValueMetaData,
+ globalMetaData.getSchema))
+
if (taskSideMetaData){
logInfo("Using Task Side Metadata Split Strategy")
- return getTaskSideSplits(configuration,
+ getTaskSideSplits(configuration,
footers,
maxSplitSize,
minSplitSize,
readContext)
} else {
logInfo("Using Client Side Metadata Split Strategy")
- return getClientSideSplits(configuration,
+ getClientSideSplits(configuration,
footers,
maxSplitSize,
minSplitSize,
readContext)
}
-
+
}
def getClientSideSplits(
@@ -489,12 +479,11 @@ private[parquet] class FilteringParquetRowInputFormat
maxSplitSize: JLong,
minSplitSize: JLong,
readContext: ReadContext): JList[ParquetInputSplit] = {
-
- import FilteringParquetRowInputFormat.blockLocationCache
- import parquet.filter2.compat.FilterCompat;
- import parquet.filter2.compat.FilterCompat.Filter;
- import parquet.filter2.compat.RowGroupFilter;
-
+
+ import parquet.filter2.compat.FilterCompat.Filter
+ import parquet.filter2.compat.RowGroupFilter
+ import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache
+
val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true)
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
@@ -503,7 +492,7 @@ private[parquet] class FilteringParquetRowInputFormat
var totalRowGroups: Long = 0
// Ugly hack, stuck with it until PR:
- // https://github.com/apache/incubator-parquet-mr/pull/17
+ // https://github.com/apache/incubator-parquet-mr/pull/17
// is resolved
val generateSplits =
Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy")
@@ -523,7 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat
blocks,
parquetMetaData.getFileMetaData.getSchema)
rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size)
-
+
if (!filteredBlocks.isEmpty){
var blockLocations: Array[BlockLocation] = null
if (!cacheMetadata) {
@@ -566,7 +555,7 @@ private[parquet] class FilteringParquetRowInputFormat
readContext: ReadContext): JList[ParquetInputSplit] = {
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
-
+
// Ugly hack, stuck with it until PR:
// https://github.com/apache/incubator-parquet-mr/pull/17
// is resolved
@@ -576,7 +565,7 @@ private[parquet] class FilteringParquetRowInputFormat
sys.error(
s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits"))
generateSplits.setAccessible(true)
-
+
for (footer <- footers) {
val file = footer.getFile
val fs = file.getFileSystem(configuration)
@@ -594,7 +583,7 @@ private[parquet] class FilteringParquetRowInputFormat
}
splits
- }
+ }
}
@@ -636,11 +625,9 @@ private[parquet] object FileSystemHelper {
files.map(_.getName).map {
case nameP(taskid) => taskid.toInt
case hiddenFileP() => 0
- case other: String => {
+ case other: String =>
sys.error("ERROR: attempting to append to set of Parquet files and found file" +
s"that does not match name pattern: $other")
- 0
- }
case _ => 0
}.reduceLeft((a, b) => if (a < b) b else a)
}
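In `ParquetTableOperations.scala`, `createRecordFilter` now returns an `Option[Filter]`, so the scan operator can thread it into the Hadoop configuration with `map`/`foreach` instead of null checks. A short sketch of that wiring, with hypothetical stand-ins for the configuration and filter types:

```scala
// Sketch of the Option-based wiring; ToyConf and ToyRecordFilter are hypothetical
// stand-ins for the Hadoop Configuration and the Parquet filter types.
object OptionWiringSketch {
  case class ToyRecordFilter(predicate: String)
  class ToyConf { var filter: Option[String] = None }

  // reduceOption mirrors combining per-predicate filters with FilterApi.and:
  // an empty input means "no filter", i.e. None rather than null.
  def createRecordFilter(predicates: Seq[String]): Option[ToyRecordFilter] =
    predicates.reduceOption(_ + " AND " + _).map(ToyRecordFilter(_))

  def configureScan(conf: ToyConf, predicates: Seq[String]): Unit =
    createRecordFilter(predicates)
      .map(_.predicate)                    // unwrap, like getFilterPredicate
      .foreach(p => conf.filter = Some(p)) // only set when a filter actually exists

  // configureScan leaves conf.filter untouched when nothing could be pushed down.
}
```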
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 80a3e0b4c9..d31a9d8418 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.parquet
+import _root_.parquet.filter2.predicate.{FilterPredicate, Operators}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.mapreduce.Job
import org.scalatest.{BeforeAndAfterAll, FunSuiteLike}
import parquet.hadoop.ParquetFileWriter
import parquet.hadoop.util.ContextUtil
+
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.IntegerType
@@ -447,44 +449,24 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(true)
}
- test("create RecordFilter for simple predicates") {
- val attribute1 = new AttributeReference("first", IntegerType, false)()
- val predicate1 = new EqualTo(attribute1, new Literal(1, IntegerType))
- val filter1 = ParquetFilters.createFilter(predicate1)
- assert(filter1.isDefined)
- assert(filter1.get.predicate == predicate1, "predicates do not match")
- assert(filter1.get.isInstanceOf[ComparisonFilter])
- val cmpFilter1 = filter1.get.asInstanceOf[ComparisonFilter]
- assert(cmpFilter1.columnName == "first", "column name incorrect")
-
- val predicate2 = new LessThan(attribute1, new Literal(4, IntegerType))
- val filter2 = ParquetFilters.createFilter(predicate2)
- assert(filter2.isDefined)
- assert(filter2.get.predicate == predicate2, "predicates do not match")
- assert(filter2.get.isInstanceOf[ComparisonFilter])
- val cmpFilter2 = filter2.get.asInstanceOf[ComparisonFilter]
- assert(cmpFilter2.columnName == "first", "column name incorrect")
-
- val predicate3 = new And(predicate1, predicate2)
- val filter3 = ParquetFilters.createFilter(predicate3)
- assert(filter3.isDefined)
- assert(filter3.get.predicate == predicate3, "predicates do not match")
- assert(filter3.get.isInstanceOf[AndFilter])
-
- val predicate4 = new Or(predicate1, predicate2)
- val filter4 = ParquetFilters.createFilter(predicate4)
- assert(filter4.isDefined)
- assert(filter4.get.predicate == predicate4, "predicates do not match")
- assert(filter4.get.isInstanceOf[OrFilter])
-
- val attribute2 = new AttributeReference("second", IntegerType, false)()
- val predicate5 = new GreaterThan(attribute1, attribute2)
- val badfilter = ParquetFilters.createFilter(predicate5)
- assert(badfilter.isDefined === false)
-
- val predicate6 = And(GreaterThan(attribute1, attribute2), GreaterThan(attribute1, attribute2))
- val badfilter2 = ParquetFilters.createFilter(predicate6)
- assert(badfilter2.isDefined === false)
+ test("make RecordFilter for simple predicates") {
+ def checkFilter[T <: FilterPredicate](predicate: Expression, defined: Boolean = true): Unit = {
+ val filter = ParquetFilters.createFilter(predicate)
+ if (defined) {
+ assert(filter.isDefined)
+ assert(filter.get.isInstanceOf[T])
+ } else {
+ assert(filter.isEmpty)
+ }
+ }
+
+ checkFilter[Operators.Eq[Integer]]('a.int === 1)
+ checkFilter[Operators.Lt[Integer]]('a.int < 4)
+ checkFilter[Operators.And]('a.int === 1 && 'a.int < 4)
+ checkFilter[Operators.Or]('a.int === 1 || 'a.int < 4)
+
+ checkFilter('a.int > 'b.int, defined = false)
+ checkFilter(('a.int > 'b.int) && ('a.int > 'b.int), defined = false)
}
test("test filter by predicate pushdown") {