From 17eec0a71ba8713c559d641e3f43a1be726b037c Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Mon, 14 Mar 2016 19:21:12 -0700
Subject: [SPARK-13664][SQL] Add a strategy for planning partitioned and bucketed scans of files

This PR adds a new strategy, `FileSourceStrategy`, that can be used for planning scans of collections of files that might be partitioned or bucketed.

Compared with the existing planning logic in `DataSourceStrategy`, this version has the following desirable properties:
- It removes the need to have `RDD`, `broadcastedHadoopConf` and other distributed concerns in the public API of `org.apache.spark.sql.sources.FileFormat`.
- Partition column appending is delegated to the format to avoid an extra copy / devectorization when appending partition columns.
- It minimizes the amount of data that is shipped to each executor (i.e. it does not send the whole list of files to every worker in the form of a Hadoop conf).
- It natively supports bucketing files into partitions, and thus does not require coalescing / creating a `UnionRDD` with the correct partitioning.
- Small files are automatically coalesced into fewer tasks using an approximate bin-packing algorithm (see the sketch after this message).

Currently only a testing source is planned / tested using this strategy. In follow-up PRs we will port the existing formats to this API.

A stub for `FileScanRDD` is also added, but most methods remain unimplemented.

Other minor cleanups:
- Partition pruning is pushed into `FileCatalog` so both the new and old code paths can use this logic. This will also allow future implementations to use indexes or other tricks (e.g. a MySQL metastore).
- The partitions from the `FileCatalog` now propagate information about file sizes all the way up to the planner so we can intelligently spread files out.
- `Array` -> `Seq` in some internal APIs to avoid unnecessary `toArray` calls.
- Rename `Partition` to `PartitionDirectory` to differentiate partitions used earlier in pruning from those where we have already enumerated the files and their sizes.

Author: Michael Armbrust

Closes #11646 from marmbrus/fileStrategy.
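[Editor's note] To make the bin-packing behaviour concrete before diving into the diff, below is a minimal, self-contained Scala sketch of the "split large files, then First Fit Decreasing" assignment that `FileSourceStrategy` implements further down. `SplitFile`, `Bin` and `packFiles` are simplified stand-ins invented for this illustration only; the real code works with `PartitionedFile` / `FilePartition` and reads the threshold from `spark.sql.files.maxPartitionBytes`.

    import scala.collection.mutable.ArrayBuffer

    object BinPackingSketch {
      // Simplified stand-ins for PartitionedFile / FilePartition, used only in this sketch.
      case class SplitFile(path: String, start: Long, length: Long)
      case class Bin(index: Int, files: Seq[SplitFile])

      /** Split files at `maxSplitBytes`, then greedily pack the splits, largest first. */
      def packFiles(files: Seq[(String, Long)], maxSplitBytes: Long): Seq[Bin] = {
        // 1. Cut any file larger than the threshold into pieces of at most maxSplitBytes.
        val splits = files.flatMap { case (path, len) =>
          (0L until len by maxSplitBytes).map { offset =>
            SplitFile(path, offset, math.min(maxSplitBytes, len - offset))
          }
        }.sortBy(-_.length) // 2. Sort splits by decreasing size (First Fit Decreasing).

        // 3. Add each split to the current bin; open a new bin when it would overflow.
        val bins = ArrayBuffer.empty[Bin]
        val current = ArrayBuffer.empty[SplitFile]
        var currentSize = 0L
        def closeBin(): Unit = {
          if (current.nonEmpty) bins += Bin(bins.size, current.toList) // copy the buffer
          current.clear()
          currentSize = 0L
        }
        splits.foreach { split =>
          if (currentSize + split.length > maxSplitBytes) closeBin()
          current += split
          currentSize += split.length
        }
        closeBin()
        bins.toList
      }

      def main(args: Array[String]): Unit = {
        // A 15-byte file and a 4-byte file with a 10-byte threshold pack into two bins:
        // Bin 0 = [(file1, 0, 10)], Bin 1 = [(file1, 10, 5), (file2, 0, 4)].
        packFiles(Seq("file1" -> 15L, "file2" -> 4L), maxSplitBytes = 10L).foreach(println)
      }
    }

The worked example in `main` mirrors the "Unpartitioned table, large file that gets split" case in the test suite included in this patch.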
--- .../apache/spark/sql/execution/ExistingRDD.scala | 2 +- .../apache/spark/sql/execution/SparkPlanner.scala | 3 +- .../sql/execution/datasources/DataSource.scala | 32 +- .../execution/datasources/DataSourceStrategy.scala | 46 +-- .../sql/execution/datasources/FileScanRDD.scala | 57 ++++ .../execution/datasources/FileSourceStrategy.scala | 202 ++++++++++++ .../execution/datasources/PartitioningUtils.scala | 18 +- .../execution/datasources/csv/CSVRelation.scala | 2 +- .../execution/datasources/csv/DefaultSource.scala | 2 +- .../execution/datasources/json/JSONRelation.scala | 4 +- .../datasources/parquet/ParquetRelation.scala | 2 +- .../execution/datasources/text/DefaultSource.scala | 2 +- .../org/apache/spark/sql/internal/SQLConf.scala | 9 +- .../org/apache/spark/sql/sources/interfaces.scala | 112 ++++++- .../datasources/FileSourceStrategySuite.scala | 345 +++++++++++++++++++++ .../parquet/ParquetPartitionDiscoverySuite.scala | 2 +- 16 files changed, 763 insertions(+), 77 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala (limited to 'sql/core/src') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d363cb000d..e97c6be7f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -151,7 +151,7 @@ private[sql] case class DataSourceScan( override val outputPartitioning = { val bucketSpec = relation match { // TODO: this should be closer to bucket planning. 
- case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled() => r.bucketSpec + case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index d1569a4ec2..292d366e72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkContext import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { val sparkContext: SparkContext = sqlContext.sparkContext @@ -29,6 +29,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { def strategies: Seq[Strategy] = sqlContext.experimental.extraStrategies ++ ( + FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: SpecialLimits :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 887f5469b5..e65a771202 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -143,7 +143,7 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths) + val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sqlContext, @@ -208,7 +208,20 @@ case class DataSource( SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths) + // If they gave a schema, then we try and figure out the types of the partition columns + // from that schema. + val partitionSchema = userSpecifiedSchema.map { schema => + StructType( + partitionColumns.map { c => + // TODO: Case sensitivity. + schema + .find(_.name.toLowerCase() == c.toLowerCase()) + .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) + }) + } + + val fileCatalog: FileCatalog = + new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sqlContext, @@ -220,22 +233,11 @@ case class DataSource( "It must be specified manually") } - // If they gave a schema, then we try and figure out the types of the partition columns - // from that schema. - val partitionSchema = userSpecifiedSchema.map { schema => - StructType( - partitionColumns.map { c => - // TODO: Case sensitivity. 
- schema - .find(_.name.toLowerCase() == c.toLowerCase()) - .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) - }) - }.getOrElse(fileCatalog.partitionSpec(None).partitionColumns) HadoopFsRelation( sqlContext, fileCatalog, - partitionSchema = partitionSchema, + partitionSchema = fileCatalog.partitionSpec().partitionColumns, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, @@ -296,7 +298,7 @@ case class DataSource( resolveRelation() .asInstanceOf[HadoopFsRelation] .location - .partitionSpec(None) + .partitionSpec() .partitionColumns .fieldNames .toSet) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 1adf3b6676..7f6671552e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val partitionAndNormalColumnFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet - val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray + val selectedPartitions = t.location.listFiles(partitionFilters) logInfo { val total = t.partitionSpec.partitions.length @@ -180,7 +180,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) t.bucketSpec match { - case Some(spec) if t.sqlContext.conf.bucketingEnabled() => + case Some(spec) if t.sqlContext.conf.bucketingEnabled => val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = { (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { val bucketed = @@ -200,7 +200,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { requiredColumns.map(_.name).toArray, filters, None, - bucketFiles.toArray, + bucketFiles, confBroadcast, t.options).coalesce(1) } @@ -233,7 +233,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { a.map(_.name).toArray, f, None, - t.location.allFiles().toArray, + t.location.allFiles(), confBroadcast, t.options)) :: Nil } @@ -255,7 +255,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters: Seq[Expression], buckets: Option[BitSet], partitionColumns: StructType, - partitions: Array[Partition], + partitions: Seq[Partition], options: Map[String, String]): SparkPlan = { val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] @@ -272,14 +272,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { relation.bucketSpec match { - case Some(spec) if relation.sqlContext.conf.bucketingEnabled() => + case Some(spec) if relation.sqlContext.conf.bucketingEnabled => val requiredDataColumns = requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) // Builds RDD[Row]s for each selected partition. 
val perPartitionRows: Seq[(Int, RDD[InternalRow])] = partitions.flatMap { - case Partition(partitionValues, dir) => - val files = relation.location.getStatus(dir) + case Partition(partitionValues, files) => val bucketed = files.groupBy { f => BucketingUtils .getBucketId(f.getPath.getName) @@ -327,14 +326,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Builds RDD[Row]s for each selected partition. val perPartitionRows = partitions.map { - case Partition(partitionValues, dir) => + case Partition(partitionValues, files) => val dataRows = relation.fileFormat.buildInternalScan( relation.sqlContext, relation.dataSchema, requiredDataColumns.map(_.name).toArray, filters, buckets, - relation.location.getStatus(dir), + files, confBroadcast, options) @@ -525,33 +524,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) } - protected def prunePartitions( - predicates: Seq[Expression], - partitionSpec: PartitionSpec): Seq[Partition] = { - val PartitionSpec(partitionColumns, partitions) = partitionSpec - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - if (partitionPruningPredicates.nonEmpty) { - val predicate = - partitionPruningPredicates - .reduceOption(expressions.And) - .getOrElse(Literal(true)) - - val boundPredicate = InterpretedPredicate.create(predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - partitions.filter { case Partition(values, _) => boundPredicate(values) } - } else { - partitions - } - } - // Based on Public API. protected def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala new file mode 100644 index 0000000000..e2cbbc34d9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow + +/** + * A single file that should be read, along with partition column values that + * need to be prepended to each row. The reading should start at the first + * valid record found after `offset`. 
+ */ +case class PartitionedFile( + partitionValues: InternalRow, + filePath: String, + start: Long, + length: Long) + +/** + * A collection of files that should be read as a single task possibly from multiple partitioned + * directories. + * + * IMPLEMENT ME: This is just a placeholder for a future implementation. + * TODO: This currently does not take locality information about the files into account. + */ +case class FilePartition(val index: Int, files: Seq[PartitionedFile]) extends Partition + +class FileScanRDD( + @transient val sqlContext: SQLContext, + readFunction: (PartitionedFile) => Iterator[InternalRow], + @transient val filePartitions: Seq[FilePartition]) + extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + throw new NotImplementedError("Not Implemented Yet") + } + + override protected def getPartitions: Array[Partition] = Array.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala new file mode 100644 index 0000000000..ef95d5d289 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable.ArrayBuffer + +import org.apache.hadoop.fs.Path + +import org.apache.spark.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan} +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ + +/** + * A strategy for planning scans over collections of files that might be partitioned or bucketed + * by user specified columns. + * + * At a high level planning occurs in several phases: + * - Split filters by when they need to be evaluated. + * - Prune the schema of the data requested based on any projections present. Today this pruning + * is only done on top level columns, but formats should support pruning of nested columns as + * well. + * - Construct a reader function by passing filters and the schema into the FileFormat. + * - Using an partition pruning predicates, enumerate the list of files that should be read. + * - Split the files into tasks and construct a FileScanRDD. + * - Add any projection or filters that must be evaluated after the scan. 
+ * + * Files are assigned into tasks using the following algorithm: + * - If the table is bucketed, group files by bucket id into the correct number of partitions. + * - If the table is not bucketed or bucketing is turned off: + * - If any file is larger than the threshold, split it into pieces based on that threshold + * - Sort the files by decreasing file size. + * - Assign the ordered files to buckets using the following algorithm. If the current partition + * is under the threshold with the addition of the next file, add it. If not, open a new bucket + * and add it. Proceed to the next file. + */ +private[sql] object FileSourceStrategy extends Strategy with Logging { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projects, filters, l@LogicalRelation(files: HadoopFsRelation, _, _)) + if files.fileFormat.toString == "TestFileFormat" => + // Filters on this relation fall into four categories based on where we can use them to avoid + // reading unneeded data: + // - partition keys only - used to prune directories to read + // - bucket keys only - optionally used to prune files to read + // - keys stored in the data only - optionally used to skip groups of data in files + // - filters that need to be evaluated again after the scan + val filterSet = ExpressionSet(filters) + + val partitionColumns = + AttributeSet(l.resolve(files.partitionSchema, files.sqlContext.analyzer.resolver)) + val partitionKeyFilters = + ExpressionSet(filters.filter(_.references.subsetOf(partitionColumns))) + logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") + + val bucketColumns = + AttributeSet( + files.bucketSpec + .map(_.bucketColumnNames) + .getOrElse(Nil) + .map(l.resolveQuoted(_, files.sqlContext.conf.resolver) + .getOrElse(sys.error("")))) + + // Partition keys are not available in the statistics of the files. + val dataFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) + + // Predicates with both partition keys and attributes need to be evaluated after the scan. 
+ val afterScanFilters = filterSet -- partitionKeyFilters + logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}") + + val selectedPartitions = files.location.listFiles(partitionKeyFilters.toSeq) + + val filterAttributes = AttributeSet(afterScanFilters) + val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects + val requiredAttributes = AttributeSet(requiredExpressions).map(_.name).toSet + + val prunedDataSchema = + StructType( + files.dataSchema.filter(f => requiredAttributes.contains(f.name))) + logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}") + + val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter) + logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}") + + val readFile = files.fileFormat.buildReader( + sqlContext = files.sqlContext, + partitionSchema = files.partitionSchema, + dataSchema = prunedDataSchema, + filters = pushedDownFilters, + options = files.options) + + val plannedPartitions = files.bucketSpec match { + case Some(bucketing) if files.sqlContext.conf.bucketingEnabled => + logInfo(s"Planning with ${bucketing.numBuckets} buckets") + val bucketed = + selectedPartitions + .flatMap { p => + p.files.map(f => PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen)) + }.groupBy { f => + BucketingUtils + .getBucketId(new Path(f.filePath).getName) + .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}")) + } + + (0 until bucketing.numBuckets).map { bucketId => + FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil)) + } + + case _ => + val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes + logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes") + + val splitFiles = selectedPartitions.flatMap { partition => + partition.files.flatMap { file => + assert(file.getLen != 0) + (0L to file.getLen by maxSplitBytes).map { offset => + val remaining = file.getLen - offset + val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining + PartitionedFile(partition.values, file.getPath.toUri.toString, offset, size) + } + } + }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse) + + val partitions = new ArrayBuffer[FilePartition] + val currentFiles = new ArrayBuffer[PartitionedFile] + var currentSize = 0L + + /** Add the given file to the current partition. */ + def addFile(file: PartitionedFile): Unit = { + currentSize += file.length + currentFiles.append(file) + } + + /** Close the current partition and move to the next. */ + def closePartition(): Unit = { + if (currentFiles.nonEmpty) { + val newPartition = + FilePartition( + partitions.size, + currentFiles.toArray.toSeq) // Copy to a new Array. + partitions.append(newPartition) + } + currentFiles.clear() + currentSize = 0 + } + + // Assign files to partitions using "First Fit Decreasing" (FFD) + // TODO: consider adding a slop factor here? 
+ splitFiles.foreach { file => + if (currentSize + file.length > maxSplitBytes) { + closePartition() + addFile(file) + } else { + addFile(file) + } + } + closePartition() + partitions + } + + val scan = + DataSourceScan( + l.output, + new FileScanRDD( + files.sqlContext, + readFile, + plannedPartitions), + files, + Map("format" -> files.fileFormat.toString)) + + val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And) + val withFilter = afterScanFilter.map(execution.Filter(_, scan)).getOrElse(scan) + val withProjections = if (projects.forall(_.isInstanceOf[AttributeReference])) { + withFilter + } else { + execution.Project(projects, withFilter) + } + + withProjections :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 18a460fc85..3ac2ff494f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,17 +32,23 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ -object Partition { - def apply(values: InternalRow, path: String): Partition = +object PartitionDirectory { + def apply(values: InternalRow, path: String): PartitionDirectory = apply(values, new Path(path)) } -private[sql] case class Partition(values: InternalRow, path: Path) +/** + * Holds a directory in a partitioned collection of files as well as as the partition values + * in the form of a Row. Before scanning, the files at `path` need to be enumerated. + */ +private[sql] case class PartitionDirectory(values: InternalRow, path: Path) -private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition]) +private[sql] case class PartitionSpec( + partitionColumns: StructType, + partitions: Seq[PartitionDirectory]) private[sql] object PartitionSpec { - val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition]) + val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory]) } private[sql] object PartitioningUtils { @@ -133,7 +139,7 @@ private[sql] object PartitioningUtils { // Finally, we create `Partition`s based on paths and resolved partition values. 
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map { case (PartitionValues(_, literals), (path, _)) => - Partition(InternalRow.fromSeq(literals.map(_.value)), path) + PartitionDirectory(InternalRow.fromSeq(literals.map(_.value)), path) } PartitionSpec(StructType(fields), partitions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 0e6b9855c7..c96a508cf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -52,7 +52,7 @@ object CSVRelation extends Logging { tokenizedRDD: RDD[Array[String]], schema: StructType, requiredColumns: Array[String], - inputs: Array[FileStatus], + inputs: Seq[FileStatus], sqlContext: SQLContext, params: CSVOptions): RDD[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 42c07c8a23..a5f94262ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -103,7 +103,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: Filter before calling buildInternalScan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 05b44d1a2a..3fa5ebf1bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -95,7 +95,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { // TODO: Filter files for all formats before calling buildInternalScan. 
@@ -115,7 +115,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - private def createBaseRdd(sqlContext: SQLContext, inputPaths: Array[FileStatus]): RDD[String] = { + private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = { val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) val conf = job.getConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index f1060074d6..342034ca0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -274,7 +274,7 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister with requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - allFiles: Array[FileStatus], + allFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 2869a6a1ac..6af403dec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -94,7 +94,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] = { verifySchema(dataSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 59429d254e..cbdc37a2a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -504,6 +504,11 @@ object SQLConf { " method", isPublic = false) + val FILES_MAX_PARTITION_BYTES = longConf("spark.sql.files.maxPartitionBytes", + defaultValue = Some(128 * 1024 * 1024), // parquet.block.size + doc = "The maximum number of bytes to pack into a single partition when reading files.", + isPublic = true) + val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", defaultValue = Some(true), doc = "When true, the planner will try to find out duplicated exchanges and re-use them", @@ -538,6 +543,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin /** ************************ Spark SQL Params/Hints ******************* */ + def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) + def useCompression: Boolean = getConf(COMPRESS_CACHED) def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) @@ -605,7 +612,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def parallelPartitionDiscoveryThreshold: Int = getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD) - def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED) + 
def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED) // Do not use a value larger than 4000 as the default value of this property. // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 601f944fb6..95ffc33011 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -30,11 +30,11 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.FileRelation import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.collection.BitSet @@ -409,7 +409,7 @@ case class HadoopFsRelation( def partitionSchemaOption: Option[StructType] = if (partitionSchema.isEmpty) None else Some(partitionSchema) - def partitionSpec: PartitionSpec = location.partitionSpec(partitionSchemaOption) + def partitionSpec: PartitionSpec = location.partitionSpec() def refresh(): Unit = location.refresh() @@ -454,11 +454,41 @@ trait FileFormat { requiredColumns: Array[String], filters: Array[Filter], bucketSet: Option[BitSet], - inputFiles: Array[FileStatus], + inputFiles: Seq[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] + + /** + * Returns a function that can be used to read a single file in as an Iterator of InternalRow. + * + * @param partitionSchema The schema of the partition column row that will be present in each + * PartitionedFile. These columns should be prepended to the rows that + * are produced by the iterator. + * @param dataSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. + * @param filters A set of filters than can optionally be used to reduce the number of rows output + * @param options A set of string -> string configuration options. + * @return + */ + def buildReader( + sqlContext: SQLContext, + partitionSchema: StructType, + dataSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + // TODO: Remove this default implementation when the other formats have been ported + // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. + throw new UnsupportedOperationException(s"buildReader is not supported for $this") + } } +/** + * A collection of data files from a partitioned relation, along with the partition values in the + * form of an [[InternalRow]]. + */ +case class Partition(values: InternalRow, files: Seq[FileStatus]) + /** * An interface for objects capable of enumerating the files that comprise a relation as well * as the partitioning characteristics of those files. 
@@ -466,7 +496,18 @@ trait FileFormat { trait FileCatalog { def paths: Seq[Path] - def partitionSpec(schema: Option[StructType]): PartitionSpec + def partitionSpec(): PartitionSpec + + /** + * Returns all valid files grouped into partitions when the data is partitioned. If the data is + * unpartitioned, this will return a single partition with not partition values. + * + * @param filters the filters used to prune which partitions are returned. These filters must + * only refer to partition columns and this method will only return files + * where these predicates are guaranteed to evaluate to `true`. Thus, these + * filters will not need to be evaluated again on the returned data. + */ + def listFiles(filters: Seq[Expression]): Seq[Partition] def allFiles(): Seq[FileStatus] @@ -478,11 +519,17 @@ trait FileCatalog { /** * A file catalog that caches metadata gathered by scanning all the files present in `paths` * recursively. + * + * @param parameters as set of options to control discovery + * @param paths a list of paths to scan + * @param partitionSchema an optional partition schema that will be use to provide types for the + * discovered partitions */ class HDFSFileCatalog( val sqlContext: SQLContext, val parameters: Map[String, String], - val paths: Seq[Path]) + val paths: Seq[Path], + val partitionSchema: Option[StructType]) extends FileCatalog with Logging { private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) @@ -491,9 +538,9 @@ class HDFSFileCatalog( var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] var cachedPartitionSpec: PartitionSpec = _ - def partitionSpec(schema: Option[StructType]): PartitionSpec = { + def partitionSpec(): PartitionSpec = { if (cachedPartitionSpec == null) { - cachedPartitionSpec = inferPartitioning(schema) + cachedPartitionSpec = inferPartitioning(partitionSchema) } cachedPartitionSpec @@ -501,6 +548,53 @@ class HDFSFileCatalog( refresh() + override def listFiles(filters: Seq[Expression]): Seq[Partition] = { + if (partitionSpec().partitionColumns.isEmpty) { + Partition(InternalRow.empty, allFiles()) :: Nil + } else { + prunePartitions(filters, partitionSpec()).map { + case PartitionDirectory(values, path) => Partition(values, getStatus(path)) + } + } + } + + protected def prunePartitions( + predicates: Seq[Expression], + partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { + val PartitionSpec(partitionColumns, partitions) = partitionSpec + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (partitionPruningPredicates.nonEmpty) { + val predicate = + partitionPruningPredicates + .reduceOption(expressions.And) + .getOrElse(Literal(true)) + + val boundPredicate = InterpretedPredicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + val selected = partitions.filter { + case PartitionDirectory(values, _) => boundPredicate(values) + } + logInfo { + val total = partitions.length + val selectedSize = selected.length + val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 + s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." 
+ } + + selected + } else { + partitions + } + } + def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path) @@ -560,7 +654,7 @@ class HDFSFileCatalog( PartitionSpec(userProvidedSchema, spec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) - case None => + case _ => PartitioningUtils.parsePartitions( leafDirs, PartitioningUtils.DEFAULT_PARTITION_NAME, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala new file mode 100644 index 0000000000..2f8129c5da --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.{File, FilenameFilter} + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.mapreduce.Job + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} +import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.execution.{DataSourceScan, PhysicalRDD} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.collection.BitSet + +class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { + import testImplicits._ + + test("unpartitioned table, single partition") { + val table = + createTable( + files = Seq( + "file1" -> 1, + "file2" -> 1, + "file3" -> 1, + "file4" -> 1, + "file5" -> 1, + "file6" -> 1, + "file7" -> 1, + "file8" -> 1, + "file9" -> 1, + "file10" -> 1)) + + checkScan(table.select('c1)) { partitions => + // 10 one byte files should fit in a single partition with 10 files. + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 10, "when checking partition 1") + // 1 byte files are too small to split so we should read the whole thing. 
+ assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 1) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + + test("unpartitioned table, multiple partitions") { + val table = + createTable( + files = Seq( + "file1" -> 5, + "file2" -> 5, + "file3" -> 5)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") { + checkScan(table.select('c1)) { partitions => + // 5 byte files should be laid out [(5, 5), (5)] + assert(partitions.size == 2, "when checking partitions") + assert(partitions(0).files.size == 2, "when checking partition 1") + assert(partitions(1).files.size == 1, "when checking partition 2") + + // 5 byte files are too small to split so we should read the whole thing. + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 5) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + + test("Unpartitioned table, large file that gets split") { + val table = + createTable( + files = Seq( + "file1" -> 15, + "file2" -> 4)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") { + checkScan(table.select('c1)) { partitions => + // Files should be laid out [(0-5), (5-10, 4)] + assert(partitions.size == 2, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") + assert(partitions(1).files.size == 2, "when checking partition 2") + + // Start by reading 10 bytes of the first file + assert(partitions.head.files.head.start == 0) + assert(partitions.head.files.head.length == 10) + + // Second partition reads the remaining 5 + assert(partitions(1).files.head.start == 10) + assert(partitions(1).files.head.length == 5) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + + test("partitioned table") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + // Only one file should be read. + checkScan(table.where("p1 = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when files in partition 1") + } + // We don't need to reevaluate filters that are only on partitions. + checkDataFilters(Set.empty) + + // Only one file should be read. + checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when checking files in partition 1") + assert(partitions.head.files.head.partitionValues.getInt(0) == 1, + "when checking partition values") + } + // Only the filters that do not contain the partition column should be pushed down + checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + } + + test("partitioned table - after scan filters") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1") + // Filter on data only are advisory so we have to reevaluate. + assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1")) + // Need to evalaute filters that are not pushed down. + assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2")) + // Don't reevaluate partition only filters. 
+ assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1"))) + } + + test("bucketed table") { + val table = + createTable( + files = Seq( + "p1=1/file1_0000" -> 1, + "p1=1/file2_0000" -> 1, + "p1=1/file3_0002" -> 1, + "p1=2/file4_0002" -> 1, + "p1=2/file5_0000" -> 1, + "p1=2/file6_0000" -> 1, + "p1=2/file7_0000" -> 1), + buckets = 3) + + // No partition pruning + checkScan(table) { partitions => + assert(partitions.size == 3) + assert(partitions(0).files.size == 5) + assert(partitions(1).files.size == 0) + assert(partitions(2).files.size == 2) + } + + // With partition pruning + checkScan(table.where("p1=2")) { partitions => + assert(partitions.size == 3) + assert(partitions(0).files.size == 3) + assert(partitions(1).files.size == 0) + assert(partitions(2).files.size == 1) + } + } + + // Helpers for checking the arguments passed to the FileFormat. + + protected val checkPartitionSchema = + checkArgument("partition schema", _.partitionSchema, _: StructType) + protected val checkDataSchema = + checkArgument("data schema", _.dataSchema, _: StructType) + protected val checkDataFilters = + checkArgument("data filters", _.filters.toSet, _: Set[Filter]) + + /** Helper for building checks on the arguments passed to the reader. */ + protected def checkArgument[T](name: String, arg: LastArguments.type => T, expected: T): Unit = { + if (arg(LastArguments) != expected) { + fail( + s""" + |Wrong $name + |expected: $expected + |actual: ${arg(LastArguments)} + """.stripMargin) + } + } + + /** Returns a resolved expression for `str` in the context of `df`. */ + def resolve(df: DataFrame, str: String): Expression = { + df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head + } + + /** Returns a set with all the filters present in the physical plan. */ + def getPhysicalFilters(df: DataFrame): ExpressionSet = { + ExpressionSet( + df.queryExecution.executedPlan.collect { + case execution.Filter(f, _) => splitConjunctivePredicates(f) + }.flatten) + } + + /** Plans the query and calls the provided validation function with the planned partitioning. */ + def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = { + val fileScan = df.queryExecution.executedPlan.collect { + case DataSourceScan(_, scan: FileScanRDD, _, _) => scan + }.headOption.getOrElse { + fail(s"No FileScan in query\n${df.queryExecution}") + } + + func(fileScan.filePartitions) + } + + /** + * Constructs a new table given a list of file names and sizes expressed in bytes. The table + * is written out in a temporary directory and any nested directories in the files names + * are automatically created. + * + * When `buckets` is > 0 the returned [[DataFrame]] will have metadata specifying that number of + * buckets. However, it is the responsibility of the caller to assign files to each bucket + * by appending the bucket id to the file names. 
+ */ + def createTable( + files: Seq[(String, Int)], + buckets: Int = 0): DataFrame = { + val tempDir = Utils.createTempDir() + files.foreach { + case (name, size) => + val file = new File(tempDir, name) + assert(file.getParentFile.exists() || file.getParentFile.mkdirs()) + util.stringToFile(file, "*" * size) + } + + val df = sqlContext.read + .format(classOf[TestFileFormat].getName) + .load(tempDir.getCanonicalPath) + + if (buckets > 0) { + val bucketed = df.queryExecution.analyzed transform { + case l @ LogicalRelation(r: HadoopFsRelation, _, _) => + l.copy(relation = + r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))) + } + Dataset.newDataFrame(sqlContext, bucketed) + } else { + df + } + } +} + +/** Holds the last arguments passed to [[TestFileFormat]]. */ +object LastArguments { + var partitionSchema: StructType = _ + var dataSchema: StructType = _ + var filters: Seq[Filter] = _ + var options: Map[String, String] = _ +} + +/** A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing. */ +class TestFileFormat extends FileFormat { + + override def toString: String = "TestFileFormat" + + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given should return None. In these cases + * Spark will require that user specify the schema manually. + */ + override def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] = + Some( + StructType(Nil) + .add("c1", IntegerType) + .add("c2", IntegerType)) + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + */ + override def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory = { + throw new NotImplementedError("JUST FOR TESTING") + } + + override def buildInternalScan( + sqlContext: SQLContext, + dataSchema: StructType, + requiredColumns: Array[String], + filters: Array[Filter], + bucketSet: Option[BitSet], + inputFiles: Seq[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration], + options: Map[String, String]): RDD[InternalRow] = { + throw new NotImplementedError("JUST FOR TESTING") + } + + override def buildReader( + sqlContext: SQLContext, + partitionSchema: StructType, + dataSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + + // Record the arguments so they can be checked in the test case. 
+ LastArguments.partitionSchema = partitionSchema + LastArguments.dataSchema = dataSchema + LastArguments.filters = filters + LastArguments.options = options + + (file: PartitionedFile) => { Iterator.empty } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 026191528e..f875b54cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SharedSQLContext -- cgit v1.2.3
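[Editor's note] The `TestFileFormat` above only records its arguments, so for context here is a deliberately simplified, framework-free Scala sketch of the shape of the new `buildReader` contract: the format returns a per-file read function and is itself responsible for attaching the partition values carried by each `PartitionedFile` to every row it produces, which is how the strategy avoids a separate partition-column-appending pass. `SimpleRow`, `SimplePartitionedFile`, `readDataRows` and `buildSimpleReader` are hypothetical names used only for this illustration; the real API works with `InternalRow` and the Spark SQL internal types shown in the diff.

    object BuildReaderSketch {
      // Hypothetical, simplified row and file descriptors (the real API uses InternalRow).
      type SimpleRow = Seq[Any]
      case class SimplePartitionedFile(
          partitionValues: SimpleRow, path: String, start: Long, length: Long)

      /** Stand-in decoder: a real format would parse bytes from `path` between start/length. */
      def readDataRows(file: SimplePartitionedFile): Iterator[SimpleRow] =
        Iterator(Seq(1, "a"), Seq(2, "b")) // dummy data rows

      /**
       * buildReader-style hook: given per-scan configuration, return a function that reads a
       * single file and prepends that file's partition values to each data row it yields.
       */
      def buildSimpleReader(
          options: Map[String, String]): SimplePartitionedFile => Iterator[SimpleRow] = {
        file => readDataRows(file).map(dataRow => file.partitionValues ++ dataRow)
      }

      def main(args: Array[String]): Unit = {
        val read = buildSimpleReader(Map.empty)
        val file = SimplePartitionedFile(Seq(1), "p1=1/file1", start = 0L, length = 10L)
        read(file).foreach(println) // List(1, 1, a) and List(1, 2, b): partition value first
      }
    }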