diff options
5 files changed, 308 insertions, 179 deletions
@@ -133,7 +133,7 @@ <!-- Version used for internal directory structure --> <hive.version.short>0.13.1</hive.version.short> <derby.version>10.10.1.1</derby.version> - <parquet.version>1.4.3</parquet.version> + <parquet.version>1.6.0rc3</parquet.version> <jblas.version>1.2.3</jblas.version> <jetty.version>8.1.14.v20131031</jetty.version> <chill.version>0.3.6</chill.version> diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index 7c83f1cad7..517a5cf002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -21,8 +21,12 @@ import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration -import parquet.filter._ -import parquet.filter.ColumnPredicates._ +import parquet.filter2.compat.FilterCompat +import parquet.filter2.compat.FilterCompat._ +import parquet.filter2.predicate.FilterPredicate +import parquet.filter2.predicate.FilterApi +import parquet.filter2.predicate.FilterApi._ +import parquet.io.api.Binary import parquet.column.ColumnReader import com.google.common.io.BaseEncoding @@ -38,67 +42,74 @@ private[sql] object ParquetFilters { // set this to false if pushdown should be disabled val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" - def createRecordFilter(filterExpressions: Seq[Expression]): UnboundRecordFilter = { + def createRecordFilter(filterExpressions: Seq[Expression]): Filter = { val filters: Seq[CatalystFilter] = filterExpressions.collect { case (expression: Expression) if createFilter(expression).isDefined => createFilter(expression).get } - if (filters.length > 0) filters.reduce(AndRecordFilter.and) else null + if (filters.length > 0) FilterCompat.get(filters.reduce(FilterApi.and)) else null } - def createFilter(expression: Expression): Option[CatalystFilter] = { + def createFilter(expression: Expression): Option[CatalystFilter] ={ def createEqualityFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case BooleanType => - ComparisonFilter.createBooleanFilter(name, literal.value.asInstanceOf[Boolean], predicate) + ComparisonFilter.createBooleanFilter( + name, + literal.value.asInstanceOf[Boolean], + predicate) case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x == literal.value.asInstanceOf[Int], + FilterApi.eq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x == literal.value.asInstanceOf[Long], + FilterApi.eq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x == literal.value.asInstanceOf[Double], + FilterApi.eq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x == literal.value.asInstanceOf[Float], + FilterApi.eq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) case StringType => - ComparisonFilter.createStringFilter(name, literal.value.asInstanceOf[String], predicate) + ComparisonFilter.createStringFilter( + name, + literal.value.asInstanceOf[String], + predicate) } + def createLessThanFilter( name: String, literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( - name, - (x: Int) => x < literal.value.asInstanceOf[Int], + new ComparisonFilter( + name, + FilterApi.lt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x < literal.value.asInstanceOf[Long], + FilterApi.lt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x < literal.value.asInstanceOf[Double], + FilterApi.lt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x < literal.value.asInstanceOf[Float], + FilterApi.lt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } def createLessThanOrEqualFilter( @@ -106,24 +117,24 @@ private[sql] object ParquetFilters { literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x <= literal.value.asInstanceOf[Int], + FilterApi.ltEq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x <= literal.value.asInstanceOf[Long], + FilterApi.ltEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x <= literal.value.asInstanceOf[Double], + FilterApi.ltEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x <= literal.value.asInstanceOf[Float], + FilterApi.ltEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } // TODO: combine these two types somehow? @@ -132,24 +143,24 @@ private[sql] object ParquetFilters { literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( + new ComparisonFilter( name, - (x: Int) => x > literal.value.asInstanceOf[Int], + FilterApi.gt(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x > literal.value.asInstanceOf[Long], + FilterApi.gt(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x > literal.value.asInstanceOf[Double], + FilterApi.gt(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x > literal.value.asInstanceOf[Float], + FilterApi.gt(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } def createGreaterThanOrEqualFilter( @@ -157,23 +168,24 @@ private[sql] object ParquetFilters { literal: Literal, predicate: CatalystPredicate) = literal.dataType match { case IntegerType => - ComparisonFilter.createIntFilter( - name, (x: Int) => x >= literal.value.asInstanceOf[Int], + new ComparisonFilter( + name, + FilterApi.gtEq(intColumn(name), literal.value.asInstanceOf[Integer]), predicate) case LongType => - ComparisonFilter.createLongFilter( + new ComparisonFilter( name, - (x: Long) => x >= literal.value.asInstanceOf[Long], + FilterApi.gtEq(longColumn(name), literal.value.asInstanceOf[java.lang.Long]), predicate) case DoubleType => - ComparisonFilter.createDoubleFilter( + new ComparisonFilter( name, - (x: Double) => x >= literal.value.asInstanceOf[Double], + FilterApi.gtEq(doubleColumn(name), literal.value.asInstanceOf[java.lang.Double]), predicate) case FloatType => - ComparisonFilter.createFloatFilter( + new ComparisonFilter( name, - (x: Float) => x >= literal.value.asInstanceOf[Float], + FilterApi.gtEq(floatColumn(name), literal.value.asInstanceOf[java.lang.Float]), predicate) } @@ -209,25 +221,25 @@ private[sql] object ParquetFilters { case _ => None } } - case p @ EqualTo(left: Literal, right: NamedExpression) if !right.nullable => + case p @ EqualTo(left: Literal, right: NamedExpression) => Some(createEqualityFilter(right.name, left, p)) - case p @ EqualTo(left: NamedExpression, right: Literal) if !left.nullable => + case p @ EqualTo(left: NamedExpression, right: Literal) => Some(createEqualityFilter(left.name, right, p)) - case p @ LessThan(left: Literal, right: NamedExpression) if !right.nullable => + case p @ LessThan(left: Literal, right: NamedExpression) => Some(createLessThanFilter(right.name, left, p)) - case p @ LessThan(left: NamedExpression, right: Literal) if !left.nullable => + case p @ LessThan(left: NamedExpression, right: Literal) => Some(createLessThanFilter(left.name, right, p)) - case p @ LessThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + case p @ LessThanOrEqual(left: Literal, right: NamedExpression) => Some(createLessThanOrEqualFilter(right.name, left, p)) - case p @ LessThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + case p @ LessThanOrEqual(left: NamedExpression, right: Literal) => Some(createLessThanOrEqualFilter(left.name, right, p)) - case p @ GreaterThan(left: Literal, right: NamedExpression) if !right.nullable => + case p @ GreaterThan(left: Literal, right: NamedExpression) => Some(createGreaterThanFilter(right.name, left, p)) - case p @ GreaterThan(left: NamedExpression, right: Literal) if !left.nullable => + case p @ GreaterThan(left: NamedExpression, right: Literal) => Some(createGreaterThanFilter(left.name, right, p)) - case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) if !right.nullable => + case p @ GreaterThanOrEqual(left: Literal, right: NamedExpression) => Some(createGreaterThanOrEqualFilter(right.name, left, p)) - case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) if !left.nullable => + case p @ GreaterThanOrEqual(left: NamedExpression, right: Literal) => Some(createGreaterThanOrEqualFilter(left.name, right, p)) case _ => None } @@ -300,52 +312,54 @@ private[sql] object ParquetFilters { } abstract private[parquet] class CatalystFilter( - @transient val predicate: CatalystPredicate) extends UnboundRecordFilter + @transient val predicate: CatalystPredicate) extends FilterPredicate private[parquet] case class ComparisonFilter( val columnName: String, - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient override val predicate: CatalystPredicate) extends CatalystFilter(predicate) { - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor) } } private[parquet] case class OrFilter( - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient val left: CatalystFilter, @transient val right: CatalystFilter, @transient override val predicate: Or) extends CatalystFilter(predicate) { def this(l: CatalystFilter, r: CatalystFilter) = this( - OrRecordFilter.or(l, r), + FilterApi.or(l, r), l, r, Or(l.predicate, r.predicate)) - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor); } + } private[parquet] case class AndFilter( - private var filter: UnboundRecordFilter, + private var filter: FilterPredicate, @transient val left: CatalystFilter, @transient val right: CatalystFilter, @transient override val predicate: And) extends CatalystFilter(predicate) { def this(l: CatalystFilter, r: CatalystFilter) = this( - AndRecordFilter.and(l, r), + FilterApi.and(l, r), l, r, And(l.predicate, r.predicate)) - override def bind(readers: java.lang.Iterable[ColumnReader]): RecordFilter = { - filter.bind(readers) + override def accept[R](visitor: FilterPredicate.Visitor[R]): R = { + filter.accept(visitor); } + } private[parquet] object ComparisonFilter { @@ -355,13 +369,7 @@ private[parquet] object ComparisonFilter { predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToBoolean( - new BooleanPredicateFunction { - def functionToApply(input: Boolean): Boolean = input == value - } - )), + FilterApi.eq(booleanColumn(columnName), value.asInstanceOf[java.lang.Boolean]), predicate) def createStringFilter( @@ -370,72 +378,6 @@ private[parquet] object ComparisonFilter { predicate: CatalystPredicate): CatalystFilter = new ComparisonFilter( columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToString ( - new ColumnPredicates.PredicateFunction[String] { - def functionToApply(input: String): Boolean = input == value - } - )), - predicate) - - def createIntFilter( - columnName: String, - func: Int => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToInteger( - new IntegerPredicateFunction { - def functionToApply(input: Int) = func(input) - } - )), - predicate) - - def createLongFilter( - columnName: String, - func: Long => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToLong( - new LongPredicateFunction { - def functionToApply(input: Long) = func(input) - } - )), - predicate) - - def createDoubleFilter( - columnName: String, - func: Double => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToDouble( - new DoublePredicateFunction { - def functionToApply(input: Double) = func(input) - } - )), - predicate) - - def createFloatFilter( - columnName: String, - func: Float => Boolean, - predicate: CatalystPredicate): CatalystFilter = - new ComparisonFilter( - columnName, - ColumnRecordFilter.column( - columnName, - ColumnPredicates.applyFunctionToFloat( - new FloatPredicateFunction { - def functionToApply(input: Float) = func(input) - } - )), + FilterApi.eq(binaryColumn(columnName), Binary.fromString(value)), predicate) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index ef995b3d1a..416bf56144 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -38,6 +38,7 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop._ import parquet.hadoop.api.{InitContext, ReadSupport} import parquet.hadoop.metadata.GlobalMetaData +import parquet.hadoop.api.ReadSupport.ReadContext import parquet.hadoop.util.ContextUtil import parquet.io.ParquetDecodingException import parquet.schema.MessageType @@ -77,6 +78,10 @@ case class ParquetTableScan( s"$normalOutput + $partOutput != $attributes, ${relation.output}") override def execute(): RDD[Row] = { + import parquet.filter2.compat.FilterCompat.FilterPredicateCompat + import parquet.filter2.compat.FilterCompat.Filter + import parquet.filter2.predicate.FilterPredicate + val sc = sqlContext.sparkContext val job = new Job(sc.hadoopConfiguration) ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) @@ -107,7 +112,13 @@ case class ParquetTableScan( // "spark.sql.hints.parquetFilterPushdown" to false inside SparkConf. if (columnPruningPred.length > 0 && sc.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) { - ParquetFilters.serializeFilterExpressions(columnPruningPred, conf) + + // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering + val filter: Filter = ParquetFilters.createRecordFilter(columnPruningPred) + if (filter != null){ + val filterPredicate = filter.asInstanceOf[FilterPredicateCompat].getFilterPredicate() + ParquetInputFormat.setFilterPredicate(conf, filterPredicate) + } } // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata @@ -363,15 +374,17 @@ private[parquet] class FilteringParquetRowInputFormat override def createRecordReader( inputSplit: InputSplit, taskAttemptContext: TaskAttemptContext): RecordReader[Void, Row] = { + + import parquet.filter2.compat.FilterCompat.NoOpFilter + import parquet.filter2.compat.FilterCompat.Filter + val readSupport: ReadSupport[Row] = new RowReadSupport() - val filterExpressions = - ParquetFilters.deserializeFilterExpressions(ContextUtil.getConfiguration(taskAttemptContext)) - if (filterExpressions.length > 0) { - logInfo(s"Pushing down predicates for RecordFilter: ${filterExpressions.mkString(", ")}") + val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) + if (!filter.isInstanceOf[NoOpFilter]) { new ParquetRecordReader[Row]( readSupport, - ParquetFilters.createRecordFilter(filterExpressions)) + filter) } else { new ParquetRecordReader[Row](readSupport) } @@ -424,10 +437,8 @@ private[parquet] class FilteringParquetRowInputFormat configuration: Configuration, footers: JList[Footer]): JList[ParquetInputSplit] = { - import FilteringParquetRowInputFormat.blockLocationCache - - val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) - + // Use task side strategy by default + val taskSideMetaData = configuration.getBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) val maxSplitSize: JLong = configuration.getLong("mapred.max.split.size", Long.MaxValue) val minSplitSize: JLong = Math.max(getFormatMinSplitSize(), configuration.getLong("mapred.min.split.size", 0L)) @@ -436,23 +447,67 @@ private[parquet] class FilteringParquetRowInputFormat s"maxSplitSize or minSplitSie should not be negative: maxSplitSize = $maxSplitSize;" + s" minSplitSize = $minSplitSize") } - val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + + // Uses strict type checking by default val getGlobalMetaData = classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]]) getGlobalMetaData.setAccessible(true) val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData] - // if parquet file is empty, return empty splits. - if (globalMetaData == null) { - return splits - } + if (globalMetaData == null) { + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + return splits + } + val readContext = getReadSupport(configuration).init( new InitContext(configuration, globalMetaData.getKeyValueMetaData(), globalMetaData.getSchema())) + + if (taskSideMetaData){ + logInfo("Using Task Side Metadata Split Strategy") + return getTaskSideSplits(configuration, + footers, + maxSplitSize, + minSplitSize, + readContext) + } else { + logInfo("Using Client Side Metadata Split Strategy") + return getClientSideSplits(configuration, + footers, + maxSplitSize, + minSplitSize, + readContext) + } + + } + + def getClientSideSplits( + configuration: Configuration, + footers: JList[Footer], + maxSplitSize: JLong, + minSplitSize: JLong, + readContext: ReadContext): JList[ParquetInputSplit] = { + + import FilteringParquetRowInputFormat.blockLocationCache + import parquet.filter2.compat.FilterCompat; + import parquet.filter2.compat.FilterCompat.Filter; + import parquet.filter2.compat.RowGroupFilter; + + val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, false) + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + val filter: Filter = ParquetInputFormat.getFilter(configuration) + var rowGroupsDropped: Long = 0 + var totalRowGroups: Long = 0 + + // Ugly hack, stuck with it until PR: + // https://github.com/apache/incubator-parquet-mr/pull/17 + // is resolved val generateSplits = - classOf[ParquetInputFormat[_]].getDeclaredMethods.find(_.getName == "generateSplits").get + Class.forName("parquet.hadoop.ClientSideMetadataSplitStrategy") + .getDeclaredMethods.find(_.getName == "generateSplits").getOrElse( + sys.error(s"Failed to reflectively invoke ClientSideMetadataSplitStrategy.generateSplits")) generateSplits.setAccessible(true) for (footer <- footers) { @@ -461,29 +516,85 @@ private[parquet] class FilteringParquetRowInputFormat val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) val parquetMetaData = footer.getParquetMetadata val blocks = parquetMetaData.getBlocks - var blockLocations: Array[BlockLocation] = null - if (!cacheMetadata) { - blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) - } else { - blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { - def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) - }) - } + totalRowGroups = totalRowGroups + blocks.size + val filteredBlocks = RowGroupFilter.filterRowGroups( + filter, + blocks, + parquetMetaData.getFileMetaData.getSchema) + rowGroupsDropped = rowGroupsDropped + (blocks.size - filteredBlocks.size) + + if (!filteredBlocks.isEmpty){ + var blockLocations: Array[BlockLocation] = null + if (!cacheMetadata) { + blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) + } else { + blockLocations = blockLocationCache.get(status, new Callable[Array[BlockLocation]] { + def call(): Array[BlockLocation] = fs.getFileBlockLocations(status, 0, status.getLen) + }) + } + splits.addAll( + generateSplits.invoke( + null, + filteredBlocks, + blockLocations, + status, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + } + } + + if (rowGroupsDropped > 0 && totalRowGroups > 0){ + val percentDropped = ((rowGroupsDropped/totalRowGroups.toDouble) * 100).toInt + logInfo(s"Dropping $rowGroupsDropped row groups that do not pass filter predicate " + + s"($percentDropped %) !") + } + else { + logInfo("There were no row groups that could be dropped due to filter predicates") + } + splits + + } + + def getTaskSideSplits( + configuration: Configuration, + footers: JList[Footer], + maxSplitSize: JLong, + minSplitSize: JLong, + readContext: ReadContext): JList[ParquetInputSplit] = { + + val splits = mutable.ArrayBuffer.empty[ParquetInputSplit] + + // Ugly hack, stuck with it until PR: + // https://github.com/apache/incubator-parquet-mr/pull/17 + // is resolved + val generateSplits = + Class.forName("parquet.hadoop.TaskSideMetadataSplitStrategy") + .getDeclaredMethods.find(_.getName == "generateTaskSideMDSplits").getOrElse( + sys.error( + s"Failed to reflectively invoke TaskSideMetadataSplitStrategy.generateTaskSideMDSplits")) + generateSplits.setAccessible(true) + + for (footer <- footers) { + val file = footer.getFile + val fs = file.getFileSystem(configuration) + val status = fileStatuses.getOrElse(file, fs.getFileStatus(file)) + val blockLocations = fs.getFileBlockLocations(status, 0, status.getLen) splits.addAll( generateSplits.invoke( - null, - blocks, - blockLocations, - status, - parquetMetaData.getFileMetaData, - readContext.getRequestedSchema.toString, - readContext.getReadSupportMetadata, - minSplitSize, - maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) + null, + blockLocations, + status, + readContext.getRequestedSchema.toString, + readContext.getReadSupportMetadata, + minSplitSize, + maxSplitSize).asInstanceOf[JList[ParquetInputSplit]]) } splits - } + } + } private[parquet] object FilteringParquetRowInputFormat { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 837ea7695d..c0918a40d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -92,6 +92,12 @@ private[sql] object ParquetTestData { required int64 mylong; required float myfloat; required double mydouble; + optional boolean myoptboolean; + optional int32 myoptint; + optional binary myoptstring (UTF8); + optional int64 myoptlong; + optional float myoptfloat; + optional double myoptdouble; } """ @@ -255,6 +261,19 @@ private[sql] object ParquetTestData { record.add(3, i.toLong) record.add(4, i.toFloat + 0.5f) record.add(5, i.toDouble + 0.5d) + if (i % 2 == 0) { + if (i % 3 == 0) { + record.add(6, true) + } else { + record.add(6, false) + } + record.add(7, i) + record.add(8, i.toString) + record.add(9, i.toLong) + record.add(10, i.toFloat + 0.5f) + record.add(11, i.toDouble + 0.5d) + } + writer.write(record) } writer.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 25e41ecf28..9979ab446d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -560,6 +560,63 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(stringResult.size === 1) assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") assert(stringResult(0).getInt(1) === 100) + + val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") + assert( + query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val optResult = query7.collect() + assert(optResult.size === 20) + for(i <- 0 until 20) { + if (optResult(i)(7) != i * 2) { + fail(s"optional Int value in result row $i should be ${2*4*i}") + } + } + for(myval <- Seq("myoptint", "myoptlong", "myoptdouble", "myoptfloat")) { + val query8 = sql(s"SELECT * FROM testfiltersource WHERE $myval < 150 AND $myval >= 100") + assert( + query8.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result8 = query8.collect() + assert(result8.size === 25) + assert(result8(0)(7) === 100) + assert(result8(24)(7) === 148) + val query9 = sql(s"SELECT * FROM testfiltersource WHERE $myval > 150 AND $myval <= 200") + assert( + query9.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result9 = query9.collect() + assert(result9.size === 25) + if (myval == "myoptint" || myval == "myoptlong") { + assert(result9(0)(7) === 152) + assert(result9(24)(7) === 200) + } else { + assert(result9(0)(7) === 150) + assert(result9(24)(7) === 198) + } + } + val query10 = sql("SELECT * FROM testfiltersource WHERE myoptstring = \"100\"") + assert( + query10.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result10 = query10.collect() + assert(result10.size === 1) + assert(result10(0).getString(8) == "100", "stringvalue incorrect") + assert(result10(0).getInt(7) === 100) + val query11 = sql(s"SELECT * FROM testfiltersource WHERE myoptboolean = true AND myoptint < 40") + assert( + query11.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], + "Top operator should be ParquetTableScan after pushdown") + val result11 = query11.collect() + assert(result11.size === 7) + for(i <- 0 until 6) { + if (!result11(i).getBoolean(6)) { + fail(s"optional Boolean value in result row $i not true") + } + if (result11(i).getInt(7) != i * 6) { + fail(s"optional Int value in result row $i should be ${6*i}") + } + } } test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") { |