From 14c54f1876fcf91b5c10e80be2df5421c7328557 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Fri, 7 Nov 2014 11:56:40 -0800 Subject: [SPARK-4213][SQL] ParquetFilters - No support for LT, LTE, GT, GTE operators Following description is quoted from JIRA: When I issue a hql query against a HiveContext where my predicate uses a column of string type with one of LT, LTE, GT, or GTE operator, I get the following error: scala.MatchError: StringType (of class org.apache.spark.sql.catalyst.types.StringType$) Looking at the code in org.apache.spark.sql.parquet.ParquetFilters, StringType is absent from the corresponding functions for creating these filters. To reproduce, in a Hive 0.13.1 shell, I created the following table (at a specified DB): create table sparkbug ( id int, event string ) stored as parquet; Insert some sample data: insert into table sparkbug select 1, '2011-06-18' from limit 1; insert into table sparkbug select 2, '2012-01-01' from limit 1; Launch a spark shell and create a HiveContext to the metastore where the table above is located. import org.apache.spark.sql._ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveContext val hc = new HiveContext(sc) hc.setConf("spark.sql.shuffle.partitions", "10") hc.setConf("spark.sql.hive.convertMetastoreParquet", "true") hc.setConf("spark.sql.parquet.compression.codec", "snappy") import hc._ hc.hql("select * from .sparkbug where event >= '2011-12-01'") A scala.MatchError will appear in the output. Author: Kousuke Saruta Closes #3083 from sarutak/SPARK-4213 and squashes the following commits: 4ab6e56 [Kousuke Saruta] WIP b6890c6 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into SPARK-4213 9a1fae7 [Kousuke Saruta] Fixed ParquetFilters so that compare Strings --- .../apache/spark/sql/parquet/ParquetFilters.scala | 335 ++++++++++++++++++++- .../spark/sql/parquet/ParquetQuerySuite.scala | 40 +++ 2 files changed, 364 insertions(+), 11 deletions(-) (limited to 'sql/core') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 517a5cf002..1e67799e83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.parquet import java.nio.ByteBuffer +import java.sql.{Date, Timestamp} import org.apache.hadoop.conf.Configuration +import parquet.common.schema.ColumnPath import parquet.filter2.compat.FilterCompat import parquet.filter2.compat.FilterCompat._ -import parquet.filter2.predicate.FilterPredicate -import parquet.filter2.predicate.FilterApi +import parquet.filter2.predicate.Operators.{Column, SupportsLtGt} +import parquet.filter2.predicate.{FilterApi, FilterPredicate} import parquet.filter2.predicate.FilterApi._ import parquet.io.api.Binary import parquet.column.ColumnReader @@ -33,9 +35,11 @@ import com.google.common.io.BaseEncoding import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.parquet.ParquetColumns._ private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" @@ -50,15 +54,25 @@ private[sql] object ParquetFilters { if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null } - def createFilter(expression: Expression): Option[CatalystFilter] ={ + def createFilter(expression: Expression): Option[CatalystFilter] = { def createEqualityFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case BooleanType => - ComparisonFilter.createBooleanFilter( + ComparisonFilter.createBooleanEqualityFilter( name, - literal.value.asInstanceOf[Boolean], + literal.value.asInstanceOf[Boolean], + predicate) + case ByteType => + new ComparisonFilter( + name, + FilterApi.eq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.eq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), predicate) case IntegerType => new ComparisonFilter( @@ -81,18 +95,49 @@ private[sql] object ParquetFilters { FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) case StringType => - ComparisonFilter.createStringFilter( + ComparisonFilter.createStringEqualityFilter( name, literal.value.asInstanceOf[String], predicate) + case BinaryType => + ComparisonFilter.createBinaryEqualityFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.eq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.eq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.eq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createLessThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.lt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.lt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => - new ComparisonFilter( + new ComparisonFilter( name, FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) @@ -111,11 +156,47 @@ private[sql] object ParquetFilters { name, FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringLessThanFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.lt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.lt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.lt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createLessThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.ltEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.ltEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -136,12 +217,48 @@ private[sql] object ParquetFilters { name, FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryLessThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.ltEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.ltEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.ltEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } // TODO: combine these two types somehow? def createGreaterThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gt(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gt(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -162,11 +279,47 @@ private[sql] object ParquetFilters { name, FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringGreaterThanFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gt(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gt(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gt(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } def createGreaterThanOrEqualFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { + case ByteType => + new ComparisonFilter( + name, + FilterApi.gtEq(byteColumn(name), literal.value.asInstanceOf[java.lang.Byte]), + predicate) + case ShortType => + new ComparisonFilter( + name, + FilterApi.gtEq(shortColumn(name), literal.value.asInstanceOf[java.lang.Short]), + predicate) case IntegerType => new ComparisonFilter( name, @@ -187,6 +340,32 @@ private[sql] object ParquetFilters { name, FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) + case StringType => + ComparisonFilter.createStringGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[String], + predicate) + case BinaryType => + ComparisonFilter.createBinaryGreaterThanOrEqualFilter( + name, + literal.value.asInstanceOf[Array[Byte]], + predicate) + case DateType => + new ComparisonFilter( + name, + FilterApi.gtEq(dateColumn(name), new WrappedDate(literal.value.asInstanceOf[Date])), + predicate) + case TimestampType => + new ComparisonFilter( + name, + FilterApi.gtEq(timestampColumn(name), + new WrappedTimestamp(literal.value.asInstanceOf[Timestamp])), + predicate) + case DecimalType.Unlimited => + new ComparisonFilter( + name, + FilterApi.gtEq(decimalColumn(name), literal.value.asInstanceOf[Decimal]), + predicate) } /** @@ -221,9 +400,9 @@ private[sql] object ParquetFilters { case _ => None } } - case p @ EqualTo(left: Literal, right: NamedExpression) => + case p @ EqualTo(left: Literal, right: NamedExpression) if left.dataType != NullType => Some(createEqualityFilter(right.name, left, p)) - case p @ EqualTo(left: NamedExpression, right: Literal) => + case p @ EqualTo(left: NamedExpression, right: Literal) if right.dataType != NullType => Some(createEqualityFilter(left.name, right, p)) case p @ LessThan(left: Literal, right: NamedExpression) => Some(createLessThanFilter(right.name, left, p)) @@ -363,7 +542,7 @@ private[parquet] case class AndFilter( } private[parquet] object ComparisonFilter { - def createBooleanFilter( + def createBooleanEqualityFilter( columnName: String, value: Boolean, predicate: CatalystPredicate): CatalystFilter = @@ -372,7 +551,7 @@ private[parquet] object ComparisonFilter { FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]), predicate) - def createStringFilter( + def createStringEqualityFilter( columnName: String, value: String, predicate: CatalystPredicate): CatalystFilter = @@ -380,4 +559,138 @@ private[parquet] object ComparisonFilter { columnName, FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)), predicate) + + def createStringLessThanFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.lt(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringLessThanOrEqualFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.ltEq(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringGreaterThanFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gt(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createStringGreaterThanOrEqualFilter( + columnName: String, + value: String, + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gtEq(binaryColumn(columnName), Binary.fromString(value)), + predicate) + + def createBinaryEqualityFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.eq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.lt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryLessThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.ltEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gt(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) + + def createBinaryGreaterThanOrEqualFilter( + columnName: String, + value: Array[Byte], + predicate: CatalystPredicate): CatalystFilter = + new ComparisonFilter( + columnName, + FilterApi.gtEq(binaryColumn(columnName), Binary.fromByteArray(value)), + predicate) +} + +private[spark] object ParquetColumns { + + def byteColumn(columnPath: String): ByteColumn = { + new ByteColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ByteColumn(columnPath: ColumnPath) + extends Column[java.lang.Byte](columnPath, classOf[java.lang.Byte]) with SupportsLtGt + + def shortColumn(columnPath: String): ShortColumn = { + new ShortColumn(ColumnPath.fromDotString(columnPath)) + } + + final class ShortColumn(columnPath: ColumnPath) + extends Column[java.lang.Short](columnPath, classOf[java.lang.Short]) with SupportsLtGt + + + def dateColumn(columnPath: String): DateColumn = { + new DateColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DateColumn(columnPath: ColumnPath) + extends Column[WrappedDate](columnPath, classOf[WrappedDate]) with SupportsLtGt + + def timestampColumn(columnPath: String): TimestampColumn = { + new TimestampColumn(ColumnPath.fromDotString(columnPath)) + } + + final class TimestampColumn(columnPath: ColumnPath) + extends Column[WrappedTimestamp](columnPath, classOf[WrappedTimestamp]) with SupportsLtGt + + def decimalColumn(columnPath: String): DecimalColumn = { + new DecimalColumn(ColumnPath.fromDotString(columnPath)) + } + + final class DecimalColumn(columnPath: ColumnPath) + extends Column[Decimal](columnPath, classOf[Decimal]) with SupportsLtGt + + final class WrappedDate(val date: Date) extends Comparable[WrappedDate] { + + override def compareTo(other: WrappedDate): Int = { + date.compareTo(other.date) + } + } + + final class WrappedTimestamp(val timestamp: Timestamp) extends Comparable[WrappedTimestamp] { + + override def compareTo(other: WrappedTimestamp): Int = { + timestamp.compareTo(other.timestamp) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 08d9da27f1..3cccafe92d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -619,6 +619,46 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA fail(s"optional Int value in result row $i should be ${6*i}") } } + + val query12 = sql("SELECT * FROM testfiltersource WHERE mystring >= \"50\"") + assert( + query12.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result12 = query12.collect() + assert(result12.size === 54) + assert(result12(0).getString(2) == "6") + assert(result12(4).getString(2) == "50") + assert(result12(53).getString(2) == "99") + + val query13 = sql("SELECT * FROM testfiltersource WHERE mystring > \"50\"") + assert( + query13.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result13 = query13.collect() + assert(result13.size === 53) + assert(result13(0).getString(2) == "6") + assert(result13(4).getString(2) == "51") + assert(result13(52).getString(2) == "99") + + val query14 = sql("SELECT * FROM testfiltersource WHERE mystring <= \"50\"") + assert( + query14.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result14 = query14.collect() + assert(result14.size === 148) + assert(result14(0).getString(2) == "0") + assert(result14(46).getString(2) == "50") + assert(result14(147).getString(2) == "200") + + val query15 = sql("SELECT * FROM testfiltersource WHERE mystring < \"50\"") + assert( + query15.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result15 = query15.collect() + assert(result15.size === 147) + assert(result15(0).getString(2) == "0") + assert(result15(46).getString(2) == "100") + assert(result15(146).getString(2) == "200") } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { -- cgit v1.2.3