path: root/sql/core/src
author      Michael Armbrust <michael@databricks.com>  2016-03-14 19:21:12 -0700
committer   Michael Armbrust <michael@databricks.com>  2016-03-14 19:21:12 -0700
commit      17eec0a71ba8713c559d641e3f43a1be726b037c (patch)
tree        6f2a6c5a7aef585ef58bb2d6fba4f63bc58f167a /sql/core/src
parent      992142b87ed5b507493e4f9fac3f72ba14fafbbc (diff)
[SPARK-13664][SQL] Add a strategy for planning partitioned and bucketed scans of files
This PR adds a new strategy, `FileSourceStrategy`, that can be used for planning scans of collections of files that might be partitioned or bucketed. Compared with the existing planning logic in `DataSourceStrategy` this version has the following desirable properties: - It removes the need to have `RDD`, `broadcastedHadoopConf` and other distributed concerns in the public API of `org.apache.spark.sql.sources.FileFormat` - Partition column appending is delegated to the format to avoid an extra copy / devectorization when appending partition columns - It minimizes the amount of data that is shipped to each executor (i.e. it does not send the whole list of files to every worker in the form of a hadoop conf) - it natively supports bucketing files into partitions, and thus does not require coalescing / creating a `UnionRDD` with the correct partitioning. - Small files are automatically coalesced into fewer tasks using an approximate bin-packing algorithm. Currently only a testing source is planned / tested using this strategy. In follow-up PRs we will port the existing formats to this API. A stub for `FileScanRDD` is also added, but most methods remain unimplemented. Other minor cleanups: - partition pruning is pushed into `FileCatalog` so both the new and old code paths can use this logic. This will also allow future implementations to use indexes or other tricks (i.e. a MySQL metastore) - The partitions from the `FileCatalog` now propagate information about file sizes all the way up to the planner so we can intelligently spread files out. - `Array` -> `Seq` in some internal APIs to avoid unnecessary `toArray` calls - Rename `Partition` to `PartitionDirectory` to differentiate partitions used earlier in pruning from those where we have already enumerated the files and their sizes. Author: Michael Armbrust <michael@databricks.com> Closes #11646 from marmbrus/fileStrategy.
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala | 32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala | 46
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala | 57
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala | 202
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala | 112
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala | 345
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala | 2
16 files changed, 763 insertions, 77 deletions
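One more hedged illustration before the diff itself: the FileScanRDD added below is only a stub (its compute method throws NotImplementedError), but the surrounding docs describe the intended contract, namely that a task receives a FilePartition and chains the iterators produced by the format's reader function over each PartitionedFile, with partition values prepended to each row. A plausible, simplified sketch of that loop, using stand-in types rather than Spark's InternalRow / RDD machinery:

object FileScanSketch {
  // Stand-ins for the real PartitionedFile / FilePartition defined in FileScanRDD.scala below.
  case class PartitionedFile(partitionValues: Seq[Any], filePath: String, start: Long, length: Long)
  case class FilePartition(index: Int, files: Seq[PartitionedFile])

  type Row = Seq[Any]

  /** One task: read every file chunk in its partition and concatenate the resulting iterators. */
  def computeTask(
      split: FilePartition,
      readFunction: PartitionedFile => Iterator[Row]): Iterator[Row] =
    split.files.iterator.flatMap(readFunction)

  def main(args: Array[String]): Unit = {
    val part = FilePartition(0, Seq(
      PartitionedFile(Seq(1), "p1=1/file1", 0, 10),
      PartitionedFile(Seq(2), "p1=2/file2", 0, 10)))
    // Toy reader: one row per chunk, with the partition values prepended to the "data" columns.
    val reader = (f: PartitionedFile) => Iterator(f.partitionValues ++ Seq(f.filePath, f.length))
    computeTask(part, reader).foreach(println)
  }
}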
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index d363cb000d..e97c6be7f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -151,7 +151,7 @@ private[sql] case class DataSourceScan(
override val outputPartitioning = {
val bucketSpec = relation match {
// TODO: this should be closer to bucket planning.
- case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled() => r.bucketSpec
+ case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec
case _ => None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index d1569a4ec2..292d366e72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy}
class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
val sparkContext: SparkContext = sqlContext.sparkContext
@@ -29,6 +29,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
def strategies: Seq[Strategy] =
sqlContext.experimental.extraStrategies ++ (
+ FileSourceStrategy ::
DataSourceStrategy ::
DDLStrategy ::
SpecialLimits ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 887f5469b5..e65a771202 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -143,7 +143,7 @@ case class DataSource(
SparkHadoopUtil.get.globPathIfNecessary(qualified)
}.toArray
- val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths)
+ val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None)
val dataSchema = userSpecifiedSchema.orElse {
format.inferSchema(
sqlContext,
@@ -208,7 +208,20 @@ case class DataSource(
SparkHadoopUtil.get.globPathIfNecessary(qualified)
}.toArray
- val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths)
+ // If they gave a schema, then we try and figure out the types of the partition columns
+ // from that schema.
+ val partitionSchema = userSpecifiedSchema.map { schema =>
+ StructType(
+ partitionColumns.map { c =>
+ // TODO: Case sensitivity.
+ schema
+ .find(_.name.toLowerCase() == c.toLowerCase())
+ .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'"))
+ })
+ }
+
+ val fileCatalog: FileCatalog =
+ new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema)
val dataSchema = userSpecifiedSchema.orElse {
format.inferSchema(
sqlContext,
@@ -220,22 +233,11 @@ case class DataSource(
"It must be specified manually")
}
- // If they gave a schema, then we try and figure out the types of the partition columns
- // from that schema.
- val partitionSchema = userSpecifiedSchema.map { schema =>
- StructType(
- partitionColumns.map { c =>
- // TODO: Case sensitivity.
- schema
- .find(_.name.toLowerCase() == c.toLowerCase())
- .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'"))
- })
- }.getOrElse(fileCatalog.partitionSpec(None).partitionColumns)
HadoopFsRelation(
sqlContext,
fileCatalog,
- partitionSchema = partitionSchema,
+ partitionSchema = fileCatalog.partitionSpec().partitionColumns,
dataSchema = dataSchema.asNullable,
bucketSpec = bucketSpec,
format,
@@ -296,7 +298,7 @@ case class DataSource(
resolveRelation()
.asInstanceOf[HadoopFsRelation]
.location
- .partitionSpec(None)
+ .partitionSpec()
.partitionColumns
.fieldNames
.toSet)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 1adf3b6676..7f6671552e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -126,7 +126,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
val partitionAndNormalColumnFilters =
filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet
- val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray
+ val selectedPartitions = t.location.listFiles(partitionFilters)
logInfo {
val total = t.partitionSpec.partitions.length
@@ -180,7 +180,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf))
t.bucketSpec match {
- case Some(spec) if t.sqlContext.conf.bucketingEnabled() =>
+ case Some(spec) if t.sqlContext.conf.bucketingEnabled =>
val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = {
(requiredColumns: Seq[Attribute], filters: Array[Filter]) => {
val bucketed =
@@ -200,7 +200,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
requiredColumns.map(_.name).toArray,
filters,
None,
- bucketFiles.toArray,
+ bucketFiles,
confBroadcast,
t.options).coalesce(1)
}
@@ -233,7 +233,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
a.map(_.name).toArray,
f,
None,
- t.location.allFiles().toArray,
+ t.location.allFiles(),
confBroadcast,
t.options)) :: Nil
}
@@ -255,7 +255,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
filters: Seq[Expression],
buckets: Option[BitSet],
partitionColumns: StructType,
- partitions: Array[Partition],
+ partitions: Seq[Partition],
options: Map[String, String]): SparkPlan = {
val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation]
@@ -272,14 +272,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(requiredColumns: Seq[Attribute], filters: Array[Filter]) => {
relation.bucketSpec match {
- case Some(spec) if relation.sqlContext.conf.bucketingEnabled() =>
+ case Some(spec) if relation.sqlContext.conf.bucketingEnabled =>
val requiredDataColumns =
requiredColumns.filterNot(c => partitionColumnNames.contains(c.name))
// Builds RDD[Row]s for each selected partition.
val perPartitionRows: Seq[(Int, RDD[InternalRow])] = partitions.flatMap {
- case Partition(partitionValues, dir) =>
- val files = relation.location.getStatus(dir)
+ case Partition(partitionValues, files) =>
val bucketed = files.groupBy { f =>
BucketingUtils
.getBucketId(f.getPath.getName)
@@ -327,14 +326,14 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Builds RDD[Row]s for each selected partition.
val perPartitionRows = partitions.map {
- case Partition(partitionValues, dir) =>
+ case Partition(partitionValues, files) =>
val dataRows = relation.fileFormat.buildInternalScan(
relation.sqlContext,
relation.dataSchema,
requiredDataColumns.map(_.name).toArray,
filters,
buckets,
- relation.location.getStatus(dir),
+ files,
confBroadcast,
options)
@@ -525,33 +524,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets)
}
- protected def prunePartitions(
- predicates: Seq[Expression],
- partitionSpec: PartitionSpec): Seq[Partition] = {
- val PartitionSpec(partitionColumns, partitions) = partitionSpec
- val partitionColumnNames = partitionColumns.map(_.name).toSet
- val partitionPruningPredicates = predicates.filter {
- _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
- }
-
- if (partitionPruningPredicates.nonEmpty) {
- val predicate =
- partitionPruningPredicates
- .reduceOption(expressions.And)
- .getOrElse(Literal(true))
-
- val boundPredicate = InterpretedPredicate.create(predicate.transform {
- case a: AttributeReference =>
- val index = partitionColumns.indexWhere(a.name == _.name)
- BoundReference(index, partitionColumns(index).dataType, nullable = true)
- })
-
- partitions.filter { case Partition(values, _) => boundPredicate(values) }
- } else {
- partitions
- }
- }
-
// Based on Public API.
protected def pruneFilterProject(
relation: LogicalRelation,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
new file mode 100644
index 0000000000..e2cbbc34d9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * A single file that should be read, along with partition column values that
+ * need to be prepended to each row. The reading should start at the first
+ * valid record found after `offset`.
+ */
+case class PartitionedFile(
+ partitionValues: InternalRow,
+ filePath: String,
+ start: Long,
+ length: Long)
+
+/**
+ * A collection of files that should be read as a single task possibly from multiple partitioned
+ * directories.
+ *
+ * IMPLEMENT ME: This is just a placeholder for a future implementation.
+ * TODO: This currently does not take locality information about the files into account.
+ */
+case class FilePartition(val index: Int, files: Seq[PartitionedFile]) extends Partition
+
+class FileScanRDD(
+ @transient val sqlContext: SQLContext,
+ readFunction: (PartitionedFile) => Iterator[InternalRow],
+ @transient val filePartitions: Seq[FilePartition])
+ extends RDD[InternalRow](sqlContext.sparkContext, Nil) {
+
+
+ override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+ throw new NotImplementedError("Not Implemented Yet")
+ }
+
+ override protected def getPartitions: Array[Partition] = Array.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
new file mode 100644
index 0000000000..ef95d5d289
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.Logging
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan}
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types._
+
+/**
+ * A strategy for planning scans over collections of files that might be partitioned or bucketed
+ * by user specified columns.
+ *
+ * At a high level planning occurs in several phases:
+ * - Split filters by when they need to be evaluated.
+ * - Prune the schema of the data requested based on any projections present. Today this pruning
+ * is only done on top level columns, but formats should support pruning of nested columns as
+ * well.
+ * - Construct a reader function by passing filters and the schema into the FileFormat.
+ * - Using the partition pruning predicates, enumerate the list of files that should be read.
+ * - Split the files into tasks and construct a FileScanRDD.
+ * - Add any projection or filters that must be evaluated after the scan.
+ *
+ * Files are assigned into tasks using the following algorithm:
+ * - If the table is bucketed, group files by bucket id into the correct number of partitions.
+ * - If the table is not bucketed or bucketing is turned off:
+ * - If any file is larger than the threshold, split it into pieces based on that threshold
+ * - Sort the files by decreasing file size.
+ * - Assign the ordered files to buckets using the following algorithm. If the current partition
+ * is under the threshold with the addition of the next file, add it. If not, open a new bucket
+ * and add it. Proceed to the next file.
+ */
+private[sql] object FileSourceStrategy extends Strategy with Logging {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case PhysicalOperation(projects, filters, l@LogicalRelation(files: HadoopFsRelation, _, _))
+ if files.fileFormat.toString == "TestFileFormat" =>
+ // Filters on this relation fall into four categories based on where we can use them to avoid
+ // reading unneeded data:
+ // - partition keys only - used to prune directories to read
+ // - bucket keys only - optionally used to prune files to read
+ // - keys stored in the data only - optionally used to skip groups of data in files
+ // - filters that need to be evaluated again after the scan
+ val filterSet = ExpressionSet(filters)
+
+ val partitionColumns =
+ AttributeSet(l.resolve(files.partitionSchema, files.sqlContext.analyzer.resolver))
+ val partitionKeyFilters =
+ ExpressionSet(filters.filter(_.references.subsetOf(partitionColumns)))
+ logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}")
+
+ val bucketColumns =
+ AttributeSet(
+ files.bucketSpec
+ .map(_.bucketColumnNames)
+ .getOrElse(Nil)
+ .map(l.resolveQuoted(_, files.sqlContext.conf.resolver)
+ .getOrElse(sys.error(""))))
+
+ // Partition keys are not available in the statistics of the files.
+ val dataFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty)
+
+ // Predicates with both partition keys and attributes need to be evaluated after the scan.
+ val afterScanFilters = filterSet -- partitionKeyFilters
+ logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}")
+
+ val selectedPartitions = files.location.listFiles(partitionKeyFilters.toSeq)
+
+ val filterAttributes = AttributeSet(afterScanFilters)
+ val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects
+ val requiredAttributes = AttributeSet(requiredExpressions).map(_.name).toSet
+
+ val prunedDataSchema =
+ StructType(
+ files.dataSchema.filter(f => requiredAttributes.contains(f.name)))
+ logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}")
+
+ val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
+ logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")
+
+ val readFile = files.fileFormat.buildReader(
+ sqlContext = files.sqlContext,
+ partitionSchema = files.partitionSchema,
+ dataSchema = prunedDataSchema,
+ filters = pushedDownFilters,
+ options = files.options)
+
+ val plannedPartitions = files.bucketSpec match {
+ case Some(bucketing) if files.sqlContext.conf.bucketingEnabled =>
+ logInfo(s"Planning with ${bucketing.numBuckets} buckets")
+ val bucketed =
+ selectedPartitions
+ .flatMap { p =>
+ p.files.map(f => PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen))
+ }.groupBy { f =>
+ BucketingUtils
+ .getBucketId(new Path(f.filePath).getName)
+ .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
+ }
+
+ (0 until bucketing.numBuckets).map { bucketId =>
+ FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
+ }
+
+ case _ =>
+ val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes
+ logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes")
+
+ val splitFiles = selectedPartitions.flatMap { partition =>
+ partition.files.flatMap { file =>
+ assert(file.getLen != 0)
+ (0L to file.getLen by maxSplitBytes).map { offset =>
+ val remaining = file.getLen - offset
+ val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
+ PartitionedFile(partition.values, file.getPath.toUri.toString, offset, size)
+ }
+ }
+ }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
+
+ val partitions = new ArrayBuffer[FilePartition]
+ val currentFiles = new ArrayBuffer[PartitionedFile]
+ var currentSize = 0L
+
+ /** Add the given file to the current partition. */
+ def addFile(file: PartitionedFile): Unit = {
+ currentSize += file.length
+ currentFiles.append(file)
+ }
+
+ /** Close the current partition and move to the next. */
+ def closePartition(): Unit = {
+ if (currentFiles.nonEmpty) {
+ val newPartition =
+ FilePartition(
+ partitions.size,
+ currentFiles.toArray.toSeq) // Copy to a new Array.
+ partitions.append(newPartition)
+ }
+ currentFiles.clear()
+ currentSize = 0
+ }
+
+ // Assign files to partitions using "First Fit Decreasing" (FFD)
+ // TODO: consider adding a slop factor here?
+ splitFiles.foreach { file =>
+ if (currentSize + file.length > maxSplitBytes) {
+ closePartition()
+ addFile(file)
+ } else {
+ addFile(file)
+ }
+ }
+ closePartition()
+ partitions
+ }
+
+ val scan =
+ DataSourceScan(
+ l.output,
+ new FileScanRDD(
+ files.sqlContext,
+ readFile,
+ plannedPartitions),
+ files,
+ Map("format" -> files.fileFormat.toString))
+
+ val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And)
+ val withFilter = afterScanFilter.map(execution.Filter(_, scan)).getOrElse(scan)
+ val withProjections = if (projects.forall(_.isInstanceOf[AttributeReference])) {
+ withFilter
+ } else {
+ execution.Project(projects, withFilter)
+ }
+
+ withProjections :: Nil
+
+ case _ => Nil
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 18a460fc85..3ac2ff494f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -32,17 +32,23 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types._
-object Partition {
- def apply(values: InternalRow, path: String): Partition =
+object PartitionDirectory {
+ def apply(values: InternalRow, path: String): PartitionDirectory =
apply(values, new Path(path))
}
-private[sql] case class Partition(values: InternalRow, path: Path)
+/**
+ * Holds a directory in a partitioned collection of files as well as the partition values
+ * in the form of a Row. Before scanning, the files at `path` need to be enumerated.
+ */
+private[sql] case class PartitionDirectory(values: InternalRow, path: Path)
-private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition])
+private[sql] case class PartitionSpec(
+ partitionColumns: StructType,
+ partitions: Seq[PartitionDirectory])
private[sql] object PartitionSpec {
- val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[Partition])
+ val emptySpec = PartitionSpec(StructType(Seq.empty[StructField]), Seq.empty[PartitionDirectory])
}
private[sql] object PartitioningUtils {
@@ -133,7 +139,7 @@ private[sql] object PartitioningUtils {
// Finally, we create `Partition`s based on paths and resolved partition values.
val partitions = resolvedPartitionValues.zip(pathsWithPartitionValues).map {
case (PartitionValues(_, literals), (path, _)) =>
- Partition(InternalRow.fromSeq(literals.map(_.value)), path)
+ PartitionDirectory(InternalRow.fromSeq(literals.map(_.value)), path)
}
PartitionSpec(StructType(fields), partitions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 0e6b9855c7..c96a508cf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -52,7 +52,7 @@ object CSVRelation extends Logging {
tokenizedRDD: RDD[Array[String]],
schema: StructType,
requiredColumns: Array[String],
- inputs: Array[FileStatus],
+ inputs: Seq[FileStatus],
sqlContext: SQLContext,
params: CSVOptions): RDD[Row] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
index 42c07c8a23..a5f94262ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala
@@ -103,7 +103,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
- inputFiles: Array[FileStatus],
+ inputFiles: Seq[FileStatus],
broadcastedConf: Broadcast[SerializableConfiguration],
options: Map[String, String]): RDD[InternalRow] = {
// TODO: Filter before calling buildInternalScan.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 05b44d1a2a..3fa5ebf1bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -95,7 +95,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
- inputFiles: Array[FileStatus],
+ inputFiles: Seq[FileStatus],
broadcastedConf: Broadcast[SerializableConfiguration],
options: Map[String, String]): RDD[InternalRow] = {
// TODO: Filter files for all formats before calling buildInternalScan.
@@ -115,7 +115,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
}
}
- private def createBaseRdd(sqlContext: SQLContext, inputPaths: Array[FileStatus]): RDD[String] = {
+ private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = {
val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration)
val conf = job.getConfiguration
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
index f1060074d6..342034ca0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala
@@ -274,7 +274,7 @@ private[sql] class DefaultSource extends FileFormat with DataSourceRegister with
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
- allFiles: Array[FileStatus],
+ allFiles: Seq[FileStatus],
broadcastedConf: Broadcast[SerializableConfiguration],
options: Map[String, String]): RDD[InternalRow] = {
val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 2869a6a1ac..6af403dec5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -94,7 +94,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
- inputFiles: Array[FileStatus],
+ inputFiles: Seq[FileStatus],
broadcastedConf: Broadcast[SerializableConfiguration],
options: Map[String, String]): RDD[InternalRow] = {
verifySchema(dataSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 59429d254e..cbdc37a2a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -504,6 +504,11 @@ object SQLConf {
" method",
isPublic = false)
+ val FILES_MAX_PARTITION_BYTES = longConf("spark.sql.files.maxPartitionBytes",
+ defaultValue = Some(128 * 1024 * 1024), // parquet.block.size
+ doc = "The maximum number of bytes to pack into a single partition when reading files.",
+ isPublic = true)
+
val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse",
defaultValue = Some(true),
doc = "When true, the planner will try to find out duplicated exchanges and re-use them",
@@ -538,6 +543,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
/** ************************ Spark SQL Params/Hints ******************* */
+ def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES)
+
def useCompression: Boolean = getConf(COMPRESS_CACHED)
def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
@@ -605,7 +612,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
def parallelPartitionDiscoveryThreshold: Int =
getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD)
- def bucketingEnabled(): Boolean = getConf(SQLConf.BUCKETING_ENABLED)
+ def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)
// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 601f944fb6..95ffc33011 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -30,11 +30,11 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source}
+import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.collection.BitSet
@@ -409,7 +409,7 @@ case class HadoopFsRelation(
def partitionSchemaOption: Option[StructType] =
if (partitionSchema.isEmpty) None else Some(partitionSchema)
- def partitionSpec: PartitionSpec = location.partitionSpec(partitionSchemaOption)
+ def partitionSpec: PartitionSpec = location.partitionSpec()
def refresh(): Unit = location.refresh()
@@ -454,19 +454,60 @@ trait FileFormat {
requiredColumns: Array[String],
filters: Array[Filter],
bucketSet: Option[BitSet],
- inputFiles: Array[FileStatus],
+ inputFiles: Seq[FileStatus],
broadcastedConf: Broadcast[SerializableConfiguration],
options: Map[String, String]): RDD[InternalRow]
+
+ /**
+ * Returns a function that can be used to read a single file in as an Iterator of InternalRow.
+ *
+ * @param partitionSchema The schema of the partition column row that will be present in each
+ * PartitionedFile. These columns should be prepended to the rows that
+ * are produced by the iterator.
+ * @param dataSchema The schema of the data that should be output for each row. This may be a
+ * subset of the columns that are present in the file if column pruning has
+ * occurred.
+ * @param filters A set of filters that can optionally be used to reduce the number of rows output
+ * @param options A set of string -> string configuration options.
+ * @return
+ */
+ def buildReader(
+ sqlContext: SQLContext,
+ partitionSchema: StructType,
+ dataSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
+ // TODO: Remove this default implementation when the other formats have been ported
+ // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats.
+ throw new UnsupportedOperationException(s"buildReader is not supported for $this")
+ }
}
/**
+ * A collection of data files from a partitioned relation, along with the partition values in the
+ * form of an [[InternalRow]].
+ */
+case class Partition(values: InternalRow, files: Seq[FileStatus])
+
+/**
* An interface for objects capable of enumerating the files that comprise a relation as well
* as the partitioning characteristics of those files.
*/
trait FileCatalog {
def paths: Seq[Path]
- def partitionSpec(schema: Option[StructType]): PartitionSpec
+ def partitionSpec(): PartitionSpec
+
+ /**
+ * Returns all valid files grouped into partitions when the data is partitioned. If the data is
+ * unpartitioned, this will return a single partition with no partition values.
+ *
+ * @param filters the filters used to prune which partitions are returned. These filters must
+ * only refer to partition columns and this method will only return files
+ * where these predicates are guaranteed to evaluate to `true`. Thus, these
+ * filters will not need to be evaluated again on the returned data.
+ */
+ def listFiles(filters: Seq[Expression]): Seq[Partition]
def allFiles(): Seq[FileStatus]
@@ -478,11 +519,17 @@ trait FileCatalog {
/**
* A file catalog that caches metadata gathered by scanning all the files present in `paths`
* recursively.
+ *
+ * @param parameters a set of options to control discovery
+ * @param paths a list of paths to scan
+ * @param partitionSchema an optional partition schema that will be used to provide types for the
+ * discovered partitions
*/
class HDFSFileCatalog(
val sqlContext: SQLContext,
val parameters: Map[String, String],
- val paths: Seq[Path])
+ val paths: Seq[Path],
+ val partitionSchema: Option[StructType])
extends FileCatalog with Logging {
private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
@@ -491,9 +538,9 @@ class HDFSFileCatalog(
var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]]
var cachedPartitionSpec: PartitionSpec = _
- def partitionSpec(schema: Option[StructType]): PartitionSpec = {
+ def partitionSpec(): PartitionSpec = {
if (cachedPartitionSpec == null) {
- cachedPartitionSpec = inferPartitioning(schema)
+ cachedPartitionSpec = inferPartitioning(partitionSchema)
}
cachedPartitionSpec
@@ -501,6 +548,53 @@ class HDFSFileCatalog(
refresh()
+ override def listFiles(filters: Seq[Expression]): Seq[Partition] = {
+ if (partitionSpec().partitionColumns.isEmpty) {
+ Partition(InternalRow.empty, allFiles()) :: Nil
+ } else {
+ prunePartitions(filters, partitionSpec()).map {
+ case PartitionDirectory(values, path) => Partition(values, getStatus(path))
+ }
+ }
+ }
+
+ protected def prunePartitions(
+ predicates: Seq[Expression],
+ partitionSpec: PartitionSpec): Seq[PartitionDirectory] = {
+ val PartitionSpec(partitionColumns, partitions) = partitionSpec
+ val partitionColumnNames = partitionColumns.map(_.name).toSet
+ val partitionPruningPredicates = predicates.filter {
+ _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+ }
+
+ if (partitionPruningPredicates.nonEmpty) {
+ val predicate =
+ partitionPruningPredicates
+ .reduceOption(expressions.And)
+ .getOrElse(Literal(true))
+
+ val boundPredicate = InterpretedPredicate.create(predicate.transform {
+ case a: AttributeReference =>
+ val index = partitionColumns.indexWhere(a.name == _.name)
+ BoundReference(index, partitionColumns(index).dataType, nullable = true)
+ })
+
+ val selected = partitions.filter {
+ case PartitionDirectory(values, _) => boundPredicate(values)
+ }
+ logInfo {
+ val total = partitions.length
+ val selectedSize = selected.length
+ val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100
+ s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions."
+ }
+
+ selected
+ } else {
+ partitions
+ }
+ }
+
def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq
def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path)
@@ -560,7 +654,7 @@ class HDFSFileCatalog(
PartitionSpec(userProvidedSchema, spec.partitions.map { part =>
part.copy(values = castPartitionValuesToUserSchema(part.values))
})
- case None =>
+ case _ =>
PartitioningUtils.parsePartitions(
leafDirs,
PartitioningUtils.DEFAULT_PARTITION_NAME,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
new file mode 100644
index 0000000000..2f8129c5da
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.io.{File, FilenameFilter}
+
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.mapreduce.Job
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper}
+import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.execution.{DataSourceScan, PhysicalRDD}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.collection.BitSet
+
+class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper {
+ import testImplicits._
+
+ test("unpartitioned table, single partition") {
+ val table =
+ createTable(
+ files = Seq(
+ "file1" -> 1,
+ "file2" -> 1,
+ "file3" -> 1,
+ "file4" -> 1,
+ "file5" -> 1,
+ "file6" -> 1,
+ "file7" -> 1,
+ "file8" -> 1,
+ "file9" -> 1,
+ "file10" -> 1))
+
+ checkScan(table.select('c1)) { partitions =>
+ // 10 one byte files should fit in a single partition with 10 files.
+ assert(partitions.size == 1, "when checking partitions")
+ assert(partitions.head.files.size == 10, "when checking partition 1")
+ // 1 byte files are too small to split so we should read the whole thing.
+ assert(partitions.head.files.head.start == 0)
+ assert(partitions.head.files.head.length == 1)
+ }
+
+ checkPartitionSchema(StructType(Nil))
+ checkDataSchema(StructType(Nil).add("c1", IntegerType))
+ }
+
+ test("unpartitioned table, multiple partitions") {
+ val table =
+ createTable(
+ files = Seq(
+ "file1" -> 5,
+ "file2" -> 5,
+ "file3" -> 5))
+
+ withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") {
+ checkScan(table.select('c1)) { partitions =>
+ // 5 byte files should be laid out [(5, 5), (5)]
+ assert(partitions.size == 2, "when checking partitions")
+ assert(partitions(0).files.size == 2, "when checking partition 1")
+ assert(partitions(1).files.size == 1, "when checking partition 2")
+
+ // 5 byte files are too small to split so we should read the whole thing.
+ assert(partitions.head.files.head.start == 0)
+ assert(partitions.head.files.head.length == 5)
+ }
+
+ checkPartitionSchema(StructType(Nil))
+ checkDataSchema(StructType(Nil).add("c1", IntegerType))
+ }
+ }
+
+ test("Unpartitioned table, large file that gets split") {
+ val table =
+ createTable(
+ files = Seq(
+ "file1" -> 15,
+ "file2" -> 4))
+
+ withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") {
+ checkScan(table.select('c1)) { partitions =>
+ // Files should be laid out [(0-5), (5-10, 4)]
+ assert(partitions.size == 2, "when checking partitions")
+ assert(partitions(0).files.size == 1, "when checking partition 1")
+ assert(partitions(1).files.size == 2, "when checking partition 2")
+
+ // Start by reading 10 bytes of the first file
+ assert(partitions.head.files.head.start == 0)
+ assert(partitions.head.files.head.length == 10)
+
+ // Second partition reads the remaining 5
+ assert(partitions(1).files.head.start == 10)
+ assert(partitions(1).files.head.length == 5)
+ }
+
+ checkPartitionSchema(StructType(Nil))
+ checkDataSchema(StructType(Nil).add("c1", IntegerType))
+ }
+ }
+
+ test("partitioned table") {
+ val table =
+ createTable(
+ files = Seq(
+ "p1=1/file1" -> 10,
+ "p1=2/file2" -> 10))
+
+ // Only one file should be read.
+ checkScan(table.where("p1 = 1")) { partitions =>
+ assert(partitions.size == 1, "when checking partitions")
+ assert(partitions.head.files.size == 1, "when checking files in partition 1")
+ }
+ // We don't need to reevaluate filters that are only on partitions.
+ checkDataFilters(Set.empty)
+
+ // Only one file should be read.
+ checkScan(table.where("p1 = 1 AND c1 = 1 AND (p1 + c1) = 1")) { partitions =>
+ assert(partitions.size == 1, "when checking partitions")
+ assert(partitions.head.files.size == 1, "when checking files in partition 1")
+ assert(partitions.head.files.head.partitionValues.getInt(0) == 1,
+ "when checking partition values")
+ }
+ // Only the filters that do not contain the partition column should be pushed down
+ checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1)))
+ }
+
+ test("partitioned table - after scan filters") {
+ val table =
+ createTable(
+ files = Seq(
+ "p1=1/file1" -> 10,
+ "p1=2/file2" -> 10))
+
+ val df = table.where("p1 = 1 AND (p1 + c1) = 2 AND c1 = 1")
+ // Filters on data only are advisory, so we have to reevaluate.
+ assert(getPhysicalFilters(df) contains resolve(df, "c1 = 1"))
+ // Need to evaluate filters that are not pushed down.
+ assert(getPhysicalFilters(df) contains resolve(df, "(p1 + c1) = 2"))
+ // Don't reevaluate partition only filters.
+ assert(!(getPhysicalFilters(df) contains resolve(df, "p1 = 1")))
+ }
+
+ test("bucketed table") {
+ val table =
+ createTable(
+ files = Seq(
+ "p1=1/file1_0000" -> 1,
+ "p1=1/file2_0000" -> 1,
+ "p1=1/file3_0002" -> 1,
+ "p1=2/file4_0002" -> 1,
+ "p1=2/file5_0000" -> 1,
+ "p1=2/file6_0000" -> 1,
+ "p1=2/file7_0000" -> 1),
+ buckets = 3)
+
+ // No partition pruning
+ checkScan(table) { partitions =>
+ assert(partitions.size == 3)
+ assert(partitions(0).files.size == 5)
+ assert(partitions(1).files.size == 0)
+ assert(partitions(2).files.size == 2)
+ }
+
+ // With partition pruning
+ checkScan(table.where("p1=2")) { partitions =>
+ assert(partitions.size == 3)
+ assert(partitions(0).files.size == 3)
+ assert(partitions(1).files.size == 0)
+ assert(partitions(2).files.size == 1)
+ }
+ }
+
+ // Helpers for checking the arguments passed to the FileFormat.
+
+ protected val checkPartitionSchema =
+ checkArgument("partition schema", _.partitionSchema, _: StructType)
+ protected val checkDataSchema =
+ checkArgument("data schema", _.dataSchema, _: StructType)
+ protected val checkDataFilters =
+ checkArgument("data filters", _.filters.toSet, _: Set[Filter])
+
+ /** Helper for building checks on the arguments passed to the reader. */
+ protected def checkArgument[T](name: String, arg: LastArguments.type => T, expected: T): Unit = {
+ if (arg(LastArguments) != expected) {
+ fail(
+ s"""
+ |Wrong $name
+ |expected: $expected
+ |actual: ${arg(LastArguments)}
+ """.stripMargin)
+ }
+ }
+
+ /** Returns a resolved expression for `str` in the context of `df`. */
+ def resolve(df: DataFrame, str: String): Expression = {
+ df.select(expr(str)).queryExecution.analyzed.expressions.head.children.head
+ }
+
+ /** Returns a set with all the filters present in the physical plan. */
+ def getPhysicalFilters(df: DataFrame): ExpressionSet = {
+ ExpressionSet(
+ df.queryExecution.executedPlan.collect {
+ case execution.Filter(f, _) => splitConjunctivePredicates(f)
+ }.flatten)
+ }
+
+ /** Plans the query and calls the provided validation function with the planned partitioning. */
+ def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
+ val fileScan = df.queryExecution.executedPlan.collect {
+ case DataSourceScan(_, scan: FileScanRDD, _, _) => scan
+ }.headOption.getOrElse {
+ fail(s"No FileScan in query\n${df.queryExecution}")
+ }
+
+ func(fileScan.filePartitions)
+ }
+
+ /**
+ * Constructs a new table given a list of file names and sizes expressed in bytes. The table
+ * is written out in a temporary directory and any nested directories in the file names
+ * are automatically created.
+ *
+ * When `buckets` is > 0 the returned [[DataFrame]] will have metadata specifying that number of
+ * buckets. However, it is the responsibility of the caller to assign files to each bucket
+ * by appending the bucket id to the file names.
+ */
+ def createTable(
+ files: Seq[(String, Int)],
+ buckets: Int = 0): DataFrame = {
+ val tempDir = Utils.createTempDir()
+ files.foreach {
+ case (name, size) =>
+ val file = new File(tempDir, name)
+ assert(file.getParentFile.exists() || file.getParentFile.mkdirs())
+ util.stringToFile(file, "*" * size)
+ }
+
+ val df = sqlContext.read
+ .format(classOf[TestFileFormat].getName)
+ .load(tempDir.getCanonicalPath)
+
+ if (buckets > 0) {
+ val bucketed = df.queryExecution.analyzed transform {
+ case l @ LogicalRelation(r: HadoopFsRelation, _, _) =>
+ l.copy(relation =
+ r.copy(bucketSpec = Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil))))
+ }
+ Dataset.newDataFrame(sqlContext, bucketed)
+ } else {
+ df
+ }
+ }
+}
+
+/** Holds the last arguments passed to [[TestFileFormat]]. */
+object LastArguments {
+ var partitionSchema: StructType = _
+ var dataSchema: StructType = _
+ var filters: Seq[Filter] = _
+ var options: Map[String, String] = _
+}
+
+/** A test [[FileFormat]] that records the arguments passed to buildReader, and returns nothing. */
+class TestFileFormat extends FileFormat {
+
+ override def toString: String = "TestFileFormat"
+
+ /**
+ * When possible, this method should return the schema of the given `files`. When the format
+ * does not support inference, or no valid files are given should return None. In these cases
+ * Spark will require that user specify the schema manually.
+ */
+ override def inferSchema(
+ sqlContext: SQLContext,
+ options: Map[String, String],
+ files: Seq[FileStatus]): Option[StructType] =
+ Some(
+ StructType(Nil)
+ .add("c1", IntegerType)
+ .add("c2", IntegerType))
+
+ /**
+ * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
+ * be put here. For example, user defined output committer can be configured here
+ * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass.
+ */
+ override def prepareWrite(
+ sqlContext: SQLContext,
+ job: Job,
+ options: Map[String, String],
+ dataSchema: StructType): OutputWriterFactory = {
+ throw new NotImplementedError("JUST FOR TESTING")
+ }
+
+ override def buildInternalScan(
+ sqlContext: SQLContext,
+ dataSchema: StructType,
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ bucketSet: Option[BitSet],
+ inputFiles: Seq[FileStatus],
+ broadcastedConf: Broadcast[SerializableConfiguration],
+ options: Map[String, String]): RDD[InternalRow] = {
+ throw new NotImplementedError("JUST FOR TESTING")
+ }
+
+ override def buildReader(
+ sqlContext: SQLContext,
+ partitionSchema: StructType,
+ dataSchema: StructType,
+ filters: Seq[Filter],
+ options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
+
+ // Record the arguments so they can be checked in the test case.
+ LastArguments.partitionSchema = partitionSchema
+ LastArguments.dataSchema = dataSchema
+ LastArguments.filters = filters
+ LastArguments.options = options
+
+ (file: PartitionedFile) => { Iterator.empty }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 026191528e..f875b54cd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec}
+import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.HadoopFsRelation
import org.apache.spark.sql.test.SharedSQLContext