aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2016-08-03 11:19:55 -0700
committerDavies Liu <davies.liu@gmail.com>2016-08-03 11:19:55 -0700
commite6f226c5670d9f332b49ca40ff7b86b81a218d1b (patch)
tree4b0c3899e1026fcf2f48a55bbb08f5dd26c8a4a3
parentb55f34370f695de355b72c1518b5f2a45c324af0 (diff)
downloadspark-e6f226c5670d9f332b49ca40ff7b86b81a218d1b.tar.gz
spark-e6f226c5670d9f332b49ca40ff7b86b81a218d1b.tar.bz2
spark-e6f226c5670d9f332b49ca40ff7b86b81a218d1b.zip
[SPARK-16596] [SQL] Refactor DataSourceScanExec to do partition discovery at execution instead of planning time
## What changes were proposed in this pull request? Partition discovery is rather expensive, so we should do it at execution time instead of during physical planning. Right now there is not much benefit since ListingFileCatalog will read scan for all partitions at planning time anyways, but this can be optimized in the future. Also, there might be more information for partition pruning not available at planning time. This PR moves a lot of the file scan logic from planning to execution time. All file scan operations are handled by `FileSourceScanExec`, which handles both batched and non-batched file scans. This requires some duplication with `RowDataSourceScanExec`, but is probably worth it so that `FileSourceScanExec` does not need to depend on an input RDD. TODO: In another pr, move DataSourceScanExec to it's own file. ## How was this patch tested? Existing tests (it might be worth adding a test that catalog.listFiles() is delayed until execution, but this can be delayed until there is an actual benefit to doing so). Author: Eric Liang <ekl@databricks.com> Closes #14241 from ericl/refactor.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala395
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala200
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala11
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala4
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala4
9 files changed, 356 insertions, 291 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index cf34f4b30d..becf6945a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -300,7 +300,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
- private def cleanExpression(e: Expression): Expression = e match {
+ protected def cleanExpression(e: Expression): Expression = e match {
case a: Alias =>
// As the root of the expression, Alias will always take an arbitrary exprId, we need
// to erase that for equality testing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 491c2742ca..79d9114ff3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -17,21 +17,25 @@
package org.apache.spark.sql.execution
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.commons.lang3.StringUtils
+import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Encoder, Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
-import org.apache.spark.sql.execution.datasources.HadoopFsRelation
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.BaseRelation
+import org.apache.spark.sql.sources.{BaseRelation, Filter}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.Utils
@@ -186,20 +190,13 @@ private[sql] case class RDDScanExec(
}
}
-private[sql] trait DataSourceScanExec extends LeafExecNode {
- val rdd: RDD[InternalRow]
+private[sql] trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
val relation: BaseRelation
val metastoreTableIdentifier: Option[TableIdentifier]
override val nodeName: String = {
s"Scan $relation ${metastoreTableIdentifier.map(_.unquotedString).getOrElse("")}"
}
-
- // Ignore rdd when checking results
- override def sameResult(plan: SparkPlan): Boolean = plan match {
- case other: DataSourceScanExec => relation == other.relation && metadata == other.metadata
- case _ => false
- }
}
/** Physical plan node for scanning data from a relation. */
@@ -210,7 +207,7 @@ private[sql] case class RowDataSourceScanExec(
override val outputPartitioning: Partitioning,
override val metadata: Map[String, String],
override val metastoreTableIdentifier: Option[TableIdentifier])
- extends DataSourceScanExec with CodegenSupport {
+ extends DataSourceScanExec {
private[sql] override lazy val metrics =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -275,27 +272,125 @@ private[sql] case class RowDataSourceScanExec(
|}
""".stripMargin
}
+
+ // Ignore rdd when checking results
+ override def sameResult(plan: SparkPlan): Boolean = plan match {
+ case other: RowDataSourceScanExec => relation == other.relation && metadata == other.metadata
+ case _ => false
+ }
}
-/** Physical plan node for scanning data from a batched relation. */
-private[sql] case class BatchedDataSourceScanExec(
+/**
+ * Physical plan node for scanning data from HadoopFsRelations.
+ *
+ * @param relation The file-based relation to scan.
+ * @param output Output attributes of the scan.
+ * @param outputSchema Output schema of the scan.
+ * @param partitionFilters Predicates to use for partition pruning.
+ * @param dataFilters Data source filters to use for filtering data within partitions.
+ * @param metastoreTableIdentifier
+ */
+private[sql] case class FileSourceScanExec(
+ @transient relation: HadoopFsRelation,
output: Seq[Attribute],
- rdd: RDD[InternalRow],
- @transient relation: BaseRelation,
- override val outputPartitioning: Partitioning,
- override val metadata: Map[String, String],
+ outputSchema: StructType,
+ partitionFilters: Seq[Expression],
+ dataFilters: Seq[Filter],
override val metastoreTableIdentifier: Option[TableIdentifier])
- extends DataSourceScanExec with CodegenSupport {
+ extends DataSourceScanExec {
+
+ val supportsBatch = relation.fileFormat.supportBatch(
+ relation.sparkSession, StructType.fromAttributes(output))
+
+ val needsUnsafeRowConversion = if (relation.fileFormat.isInstanceOf[ParquetSource]) {
+ SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled
+ } else {
+ false
+ }
+
+ override val outputPartitioning: Partitioning = {
+ val bucketSpec = if (relation.sparkSession.sessionState.conf.bucketingEnabled) {
+ relation.bucketSpec
+ } else {
+ None
+ }
+ bucketSpec.map { spec =>
+ val numBuckets = spec.numBuckets
+ val bucketColumns = spec.bucketColumnNames.flatMap { n =>
+ output.find(_.name == n)
+ }
+ if (bucketColumns.size == spec.bucketColumnNames.size) {
+ HashPartitioning(bucketColumns, numBuckets)
+ } else {
+ UnknownPartitioning(0)
+ }
+ }.getOrElse {
+ UnknownPartitioning(0)
+ }
+ }
+
+ // These metadata values make scan plans uniquely identifiable for equality checking.
+ override val metadata: Map[String, String] = Map(
+ "Format" -> relation.fileFormat.toString,
+ "ReadSchema" -> outputSchema.catalogString,
+ "Batched" -> supportsBatch.toString,
+ "PartitionFilters" -> partitionFilters.mkString("[", ", ", "]"),
+ DataSourceScanExec.PUSHED_FILTERS -> dataFilters.mkString("[", ", ", "]"),
+ DataSourceScanExec.INPUT_PATHS -> relation.location.paths.mkString(", "))
+
+ private lazy val inputRDD: RDD[InternalRow] = {
+ val selectedPartitions = relation.location.listFiles(partitionFilters)
+
+ val readFile: (PartitionedFile) => Iterator[InternalRow] =
+ relation.fileFormat.buildReaderWithPartitionValues(
+ sparkSession = relation.sparkSession,
+ dataSchema = relation.dataSchema,
+ partitionSchema = relation.partitionSchema,
+ requiredSchema = outputSchema,
+ filters = dataFilters,
+ options = relation.options,
+ hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
+
+ relation.bucketSpec match {
+ case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
+ createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
+ case _ =>
+ createNonBucketedReadRDD(readFile, selectedPartitions, relation)
+ }
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ inputRDD :: Nil
+ }
private[sql] override lazy val metrics =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
protected override def doExecute(): RDD[InternalRow] = {
- // in the case of fallback, this batched scan should never fail because of:
- // 1) only primitive types are supported
- // 2) the number of columns should be smaller than spark.sql.codegen.maxFields
- WholeStageCodegenExec(this).execute()
+ if (supportsBatch) {
+ // in the case of fallback, this batched scan should never fail because of:
+ // 1) only primitive types are supported
+ // 2) the number of columns should be smaller than spark.sql.codegen.maxFields
+ WholeStageCodegenExec(this).execute()
+ } else {
+ val unsafeRows = {
+ val scan = inputRDD
+ if (needsUnsafeRowConversion) {
+ scan.mapPartitionsInternal { iter =>
+ val proj = UnsafeProjection.create(schema)
+ iter.map(proj)
+ }
+ } else {
+ scan
+ }
+ }
+ val numOutputRows = longMetric("numOutputRows")
+ unsafeRows.map { r =>
+ numOutputRows += 1
+ r
+ }
+ }
}
override def simpleString: String = {
@@ -303,34 +398,38 @@ private[sql] case class BatchedDataSourceScanExec(
key + ": " + StringUtils.abbreviate(value, 100)
}
val metadataStr = Utils.truncatedString(metadataEntries, " ", ", ", "")
- s"Batched$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr"
+ s"File$nodeName${Utils.truncatedString(output, "[", ",", "]")}$metadataStr"
}
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- rdd :: Nil
- }
-
- private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String,
- dataType: DataType, nullable: Boolean): ExprCode = {
- val javaType = ctx.javaType(dataType)
- val value = ctx.getValue(columnVar, dataType, ordinal)
- val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
- val valueVar = ctx.freshName("value")
- val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
- val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
- s"""
- boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal);
- $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value);
- """
- } else {
- s"$javaType ${valueVar} = $value;"
- }).trim
- ExprCode(code, isNullVar, valueVar)
+ override protected def doProduce(ctx: CodegenContext): String = {
+ if (supportsBatch) {
+ return doProduceVectorized(ctx)
+ }
+ val numOutputRows = metricTerm(ctx, "numOutputRows")
+ // PhysicalRDD always just has one input
+ val input = ctx.freshName("input")
+ ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+ val exprRows = output.zipWithIndex.map{ case (a, i) =>
+ new BoundReference(i, a.dataType, a.nullable)
+ }
+ val row = ctx.freshName("row")
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ val columnsRowInput = exprRows.map(_.genCode(ctx))
+ val inputRow = if (needsUnsafeRowConversion) null else row
+ s"""
+ |while ($input.hasNext()) {
+ | InternalRow $row = (InternalRow) $input.next();
+ | $numOutputRows.add(1);
+ | ${consume(ctx, columnsRowInput, inputRow).trim}
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
}
// Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen
// never requires UnsafeRow as input.
- override protected def doProduce(ctx: CodegenContext): String = {
+ private def doProduceVectorized(ctx: CodegenContext): String = {
val input = ctx.freshName("input")
// PhysicalRDD always just has one input
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
@@ -391,48 +490,190 @@ private[sql] case class BatchedDataSourceScanExec(
|$scanTimeTotalNs = 0;
""".stripMargin
}
-}
-private[sql] object DataSourceScanExec {
- // Metadata keys
- val INPUT_PATHS = "InputPaths"
- val PUSHED_FILTERS = "PushedFilters"
+ private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String,
+ dataType: DataType, nullable: Boolean): ExprCode = {
+ val javaType = ctx.javaType(dataType)
+ val value = ctx.getValue(columnVar, dataType, ordinal)
+ val isNullVar = if (nullable) { ctx.freshName("isNull") } else { "false" }
+ val valueVar = ctx.freshName("value")
+ val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
+ val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
+ s"""
+ boolean ${isNullVar} = ${columnVar}.isNullAt($ordinal);
+ $javaType ${valueVar} = ${isNullVar} ? ${ctx.defaultValue(dataType)} : ($value);
+ """
+ } else {
+ s"$javaType ${valueVar} = $value;"
+ }).trim
+ ExprCode(code, isNullVar, valueVar)
+ }
- def create(
- output: Seq[Attribute],
- rdd: RDD[InternalRow],
- relation: BaseRelation,
- metadata: Map[String, String] = Map.empty,
- metastoreTableIdentifier: Option[TableIdentifier] = None): DataSourceScanExec = {
- val outputPartitioning = {
- val bucketSpec = relation match {
- // TODO: this should be closer to bucket planning.
- case r: HadoopFsRelation
- if r.sparkSession.sessionState.conf.bucketingEnabled => r.bucketSpec
- case _ => None
+ /**
+ * Create an RDD for bucketed reads.
+ * The non-bucketed variant of this function is [[createNonBucketedReadRDD]].
+ *
+ * The algorithm is pretty simple: each RDD partition being returned should include all the files
+ * with the same bucket id from all the given Hive partitions.
+ *
+ * @param bucketSpec the bucketing spec.
+ * @param readFile a function to read each (part of a) file.
+ * @param selectedPartitions Hive-style partition that are part of the read.
+ * @param fsRelation [[HadoopFsRelation]] associated with the read.
+ */
+ private def createBucketedReadRDD(
+ bucketSpec: BucketSpec,
+ readFile: (PartitionedFile) => Iterator[InternalRow],
+ selectedPartitions: Seq[Partition],
+ fsRelation: HadoopFsRelation): RDD[InternalRow] = {
+ logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
+ val bucketed =
+ selectedPartitions.flatMap { p =>
+ p.files.map { f =>
+ val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
+ PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts)
+ }
+ }.groupBy { f =>
+ BucketingUtils
+ .getBucketId(new Path(f.filePath).getName)
+ .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}
- bucketSpec.map { spec =>
- val numBuckets = spec.numBuckets
- val bucketColumns = spec.bucketColumnNames.flatMap { n => output.find(_.name == n) }
- if (bucketColumns.size == spec.bucketColumnNames.size) {
- HashPartitioning(bucketColumns, numBuckets)
+ val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
+ FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
+ }
+
+ new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
+ }
+
+ /**
+ * Create an RDD for non-bucketed reads.
+ * The bucketed variant of this function is [[createBucketedReadRDD]].
+ *
+ * @param readFile a function to read each (part of a) file.
+ * @param selectedPartitions Hive-style partition that are part of the read.
+ * @param fsRelation [[HadoopFsRelation]] associated with the read.
+ */
+ private def createNonBucketedReadRDD(
+ readFile: (PartitionedFile) => Iterator[InternalRow],
+ selectedPartitions: Seq[Partition],
+ fsRelation: HadoopFsRelation): RDD[InternalRow] = {
+ val defaultMaxSplitBytes =
+ fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes
+ val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
+ val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism
+ val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
+ val bytesPerCore = totalBytes / defaultParallelism
+
+ val maxSplitBytes = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
+ logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
+ s"open cost is considered as scanning $openCostInBytes bytes.")
+
+ val splitFiles = selectedPartitions.flatMap { partition =>
+ partition.files.flatMap { file =>
+ val blockLocations = getBlockLocations(file)
+ if (fsRelation.fileFormat.isSplitable(
+ fsRelation.sparkSession, fsRelation.options, file.getPath)) {
+ (0L until file.getLen by maxSplitBytes).map { offset =>
+ val remaining = file.getLen - offset
+ val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
+ val hosts = getBlockHosts(blockLocations, offset, size)
+ PartitionedFile(
+ partition.values, file.getPath.toUri.toString, offset, size, hosts)
+ }
} else {
- UnknownPartitioning(0)
+ val hosts = getBlockHosts(blockLocations, 0, file.getLen)
+ Seq(PartitionedFile(
+ partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts))
}
- }.getOrElse {
- UnknownPartitioning(0)
}
+ }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
+
+ val partitions = new ArrayBuffer[FilePartition]
+ val currentFiles = new ArrayBuffer[PartitionedFile]
+ var currentSize = 0L
+
+ /** Close the current partition and move to the next. */
+ def closePartition(): Unit = {
+ if (currentFiles.nonEmpty) {
+ val newPartition =
+ FilePartition(
+ partitions.size,
+ currentFiles.toArray.toSeq) // Copy to a new Array.
+ partitions.append(newPartition)
+ }
+ currentFiles.clear()
+ currentSize = 0
}
- relation match {
- case r: HadoopFsRelation
- if r.fileFormat.supportBatch(r.sparkSession, StructType.fromAttributes(output)) =>
- BatchedDataSourceScanExec(
- output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier)
- case _ =>
- RowDataSourceScanExec(
- output, rdd, relation, outputPartitioning, metadata, metastoreTableIdentifier)
+ // Assign files to partitions using "First Fit Decreasing" (FFD)
+ // TODO: consider adding a slop factor here?
+ splitFiles.foreach { file =>
+ if (currentSize + file.length > maxSplitBytes) {
+ closePartition()
+ }
+ // Add the given file to the current partition.
+ currentSize += file.length + openCostInBytes
+ currentFiles.append(file)
+ }
+ closePartition()
+
+ new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
+ }
+
+ private def getBlockLocations(file: FileStatus): Array[BlockLocation] = file match {
+ case f: LocatedFileStatus => f.getBlockLocations
+ case f => Array.empty[BlockLocation]
+ }
+
+ // Given locations of all blocks of a single file, `blockLocations`, and an `(offset, length)`
+ // pair that represents a segment of the same file, find out the block that contains the largest
+ // fraction the segment, and returns location hosts of that block. If no such block can be found,
+ // returns an empty array.
+ private def getBlockHosts(
+ blockLocations: Array[BlockLocation], offset: Long, length: Long): Array[String] = {
+ val candidates = blockLocations.map {
+ // The fragment starts from a position within this block
+ case b if b.getOffset <= offset && offset < b.getOffset + b.getLength =>
+ b.getHosts -> (b.getOffset + b.getLength - offset).min(length)
+
+ // The fragment ends at a position within this block
+ case b if offset <= b.getOffset && offset + length < b.getLength =>
+ b.getHosts -> (offset + length - b.getOffset).min(length)
+
+ // The fragment fully contains this block
+ case b if offset <= b.getOffset && b.getOffset + b.getLength <= offset + length =>
+ b.getHosts -> b.getLength
+
+ // The fragment doesn't intersect with this block
+ case b =>
+ b.getHosts -> 0L
+ }.filter { case (hosts, size) =>
+ size > 0L
+ }
+
+ if (candidates.isEmpty) {
+ Array.empty[String]
+ } else {
+ val (hosts, _) = candidates.maxBy { case (_, size) => size }
+ hosts
}
}
+
+ override def sameResult(plan: SparkPlan): Boolean = plan match {
+ case other: FileSourceScanExec =>
+ val thisPredicates = partitionFilters.map(cleanExpression)
+ val otherPredicates = other.partitionFilters.map(cleanExpression)
+ val result = relation == other.relation && metadata == other.metadata &&
+ thisPredicates.length == otherPredicates.length &&
+ thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
+ result
+ case _ => false
+ }
+}
+
+private[sql] object DataSourceScanExec {
+ // Metadata keys
+ val INPUT_PATHS = "InputPaths"
+ val PUSHED_FILTERS = "PushedFilters"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index ca03b26e85..52b1677d7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -31,10 +31,10 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.DataSourceScanExec.PUSHED_FILTERS
-import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.{CreateDataSourceTableUtils, DDLUtils, ExecutedCommandExec}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -268,8 +268,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil
case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
- execution.DataSourceScanExec.create(
- l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil
+ RowDataSourceScanExec(
+ l.output,
+ toCatalystRDD(l, baseRelation.buildScan()),
+ baseRelation,
+ UnknownPartitioning(0),
+ Map.empty,
+ None) :: Nil
case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
part, query, overwrite, false) if part.isEmpty =>
@@ -375,20 +380,20 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Don't request columns that are only referenced by pushed filters.
.filterNot(handledSet.contains)
- val scan = execution.DataSourceScanExec.create(
+ val scan = RowDataSourceScanExec(
projects.map(_.toAttribute),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
- relation.relation, metadata, relation.metastoreTableIdentifier)
+ relation.relation, UnknownPartitioning(0), metadata, relation.metastoreTableIdentifier)
filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)
} else {
// Don't request columns that are only referenced by pushed filters.
val requestedColumns =
(projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq
- val scan = execution.DataSourceScanExec.create(
+ val scan = RowDataSourceScanExec(
requestedColumns,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
- relation.relation, metadata, relation.metastoreTableIdentifier)
+ relation.relation, UnknownPartitioning(0), metadata, relation.metastoreTableIdentifier)
execution.ProjectExec(
projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 67491302a9..3ac09d99c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.execution.datasources
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}
-
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
@@ -29,8 +25,8 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.DataSourceScanExec
-import org.apache.spark.sql.execution.DataSourceScanExec.{INPUT_PATHS, PUSHED_FILTERS}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.SparkPlan
/**
@@ -96,8 +92,6 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
val afterScanFilters = filterSet -- partitionKeyFilters
logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}")
- val selectedPartitions = fsRelation.location.listFiles(partitionKeyFilters.toSeq)
-
val filterAttributes = AttributeSet(afterScanFilters)
val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects
val requiredAttributes = AttributeSet(requiredExpressions)
@@ -106,44 +100,21 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
dataColumns
.filter(requiredAttributes.contains)
.filterNot(partitionColumns.contains)
- val prunedDataSchema = readDataColumns.toStructType
- logInfo(s"Pruned Data Schema: ${prunedDataSchema.simpleString(5)}")
+ val outputSchema = readDataColumns.toStructType
+ logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}")
val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
logInfo(s"Pushed Filters: ${pushedDownFilters.mkString(",")}")
- val readFile: (PartitionedFile) => Iterator[InternalRow] =
- fsRelation.fileFormat.buildReaderWithPartitionValues(
- sparkSession = fsRelation.sparkSession,
- dataSchema = fsRelation.dataSchema,
- partitionSchema = fsRelation.partitionSchema,
- requiredSchema = prunedDataSchema,
- filters = pushedDownFilters,
- options = fsRelation.options,
- hadoopConf =
- fsRelation.sparkSession.sessionState.newHadoopConfWithOptions(fsRelation.options))
-
- val rdd = fsRelation.bucketSpec match {
- case Some(bucketing) if fsRelation.sparkSession.sessionState.conf.bucketingEnabled =>
- createBucketedReadRDD(bucketing, readFile, selectedPartitions, fsRelation)
- case _ =>
- createNonBucketedReadRDD(readFile, selectedPartitions, fsRelation)
- }
-
- // These metadata values make scan plans uniquely identifiable for equality checking.
- val meta = Map(
- "PartitionFilters" -> partitionKeyFilters.mkString("[", ", ", "]"),
- "Format" -> fsRelation.fileFormat.toString,
- "ReadSchema" -> prunedDataSchema.simpleString,
- PUSHED_FILTERS -> pushedDownFilters.mkString("[", ", ", "]"),
- INPUT_PATHS -> fsRelation.location.paths.mkString(", "))
+ val outputAttributes = readDataColumns ++ partitionColumns
val scan =
- DataSourceScanExec.create(
- readDataColumns ++ partitionColumns,
- rdd,
+ new FileSourceScanExec(
fsRelation,
- meta,
+ outputAttributes,
+ outputSchema,
+ partitionKeyFilters.toSeq,
+ pushedDownFilters,
table)
val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And)
@@ -158,155 +129,4 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
case _ => Nil
}
-
- /**
- * Create an RDD for bucketed reads.
- * The non-bucketed variant of this function is [[createNonBucketedReadRDD]].
- *
- * The algorithm is pretty simple: each RDD partition being returned should include all the files
- * with the same bucket id from all the given Hive partitions.
- *
- * @param bucketSpec the bucketing spec.
- * @param readFile a function to read each (part of a) file.
- * @param selectedPartitions Hive-style partition that are part of the read.
- * @param fsRelation [[HadoopFsRelation]] associated with the read.
- */
- private def createBucketedReadRDD(
- bucketSpec: BucketSpec,
- readFile: (PartitionedFile) => Iterator[InternalRow],
- selectedPartitions: Seq[Partition],
- fsRelation: HadoopFsRelation): RDD[InternalRow] = {
- logInfo(s"Planning with ${bucketSpec.numBuckets} buckets")
- val bucketed =
- selectedPartitions.flatMap { p =>
- p.files.map { f =>
- val hosts = getBlockHosts(getBlockLocations(f), 0, f.getLen)
- PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts)
- }
- }.groupBy { f =>
- BucketingUtils
- .getBucketId(new Path(f.filePath).getName)
- .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
- }
-
- val filePartitions = Seq.tabulate(bucketSpec.numBuckets) { bucketId =>
- FilePartition(bucketId, bucketed.getOrElse(bucketId, Nil))
- }
-
- new FileScanRDD(fsRelation.sparkSession, readFile, filePartitions)
- }
-
- /**
- * Create an RDD for non-bucketed reads.
- * The bucketed variant of this function is [[createBucketedReadRDD]].
- *
- * @param readFile a function to read each (part of a) file.
- * @param selectedPartitions Hive-style partition that are part of the read.
- * @param fsRelation [[HadoopFsRelation]] associated with the read.
- */
- private def createNonBucketedReadRDD(
- readFile: (PartitionedFile) => Iterator[InternalRow],
- selectedPartitions: Seq[Partition],
- fsRelation: HadoopFsRelation): RDD[InternalRow] = {
- val defaultMaxSplitBytes =
- fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes
- val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes
- val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism
- val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
- val bytesPerCore = totalBytes / defaultParallelism
-
- val maxSplitBytes = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
- logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
- s"open cost is considered as scanning $openCostInBytes bytes.")
-
- val splitFiles = selectedPartitions.flatMap { partition =>
- partition.files.flatMap { file =>
- val blockLocations = getBlockLocations(file)
- if (fsRelation.fileFormat.isSplitable(
- fsRelation.sparkSession, fsRelation.options, file.getPath)) {
- (0L until file.getLen by maxSplitBytes).map { offset =>
- val remaining = file.getLen - offset
- val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
- val hosts = getBlockHosts(blockLocations, offset, size)
- PartitionedFile(
- partition.values, file.getPath.toUri.toString, offset, size, hosts)
- }
- } else {
- val hosts = getBlockHosts(blockLocations, 0, file.getLen)
- Seq(PartitionedFile(
- partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts))
- }
- }
- }.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
-
- val partitions = new ArrayBuffer[FilePartition]
- val currentFiles = new ArrayBuffer[PartitionedFile]
- var currentSize = 0L
-
- /** Close the current partition and move to the next. */
- def closePartition(): Unit = {
- if (currentFiles.nonEmpty) {
- val newPartition =
- FilePartition(
- partitions.size,
- currentFiles.toArray.toSeq) // Copy to a new Array.
- partitions.append(newPartition)
- }
- currentFiles.clear()
- currentSize = 0
- }
-
- // Assign files to partitions using "First Fit Decreasing" (FFD)
- // TODO: consider adding a slop factor here?
- splitFiles.foreach { file =>
- if (currentSize + file.length > maxSplitBytes) {
- closePartition()
- }
- // Add the given file to the current partition.
- currentSize += file.length + openCostInBytes
- currentFiles.append(file)
- }
- closePartition()
-
- new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
- }
-
- private def getBlockLocations(file: FileStatus): Array[BlockLocation] = file match {
- case f: LocatedFileStatus => f.getBlockLocations
- case f => Array.empty[BlockLocation]
- }
-
- // Given locations of all blocks of a single file, `blockLocations`, and an `(offset, length)`
- // pair that represents a segment of the same file, find out the block that contains the largest
- // fraction the segment, and returns location hosts of that block. If no such block can be found,
- // returns an empty array.
- private def getBlockHosts(
- blockLocations: Array[BlockLocation], offset: Long, length: Long): Array[String] = {
- val candidates = blockLocations.map {
- // The fragment starts from a position within this block
- case b if b.getOffset <= offset && offset < b.getOffset + b.getLength =>
- b.getHosts -> (b.getOffset + b.getLength - offset).min(length)
-
- // The fragment ends at a position within this block
- case b if offset <= b.getOffset && offset + length < b.getLength =>
- b.getHosts -> (offset + length - b.getOffset).min(length)
-
- // The fragment fully contains this block
- case b if offset <= b.getOffset && b.getOffset + b.getLength <= offset + length =>
- b.getHosts -> b.getLength
-
- // The fragment doesn't intersect with this block
- case b =>
- b.getHosts -> 0L
- }.filter { case (hosts, size) =>
- size > 0L
- }
-
- if (candidates.isEmpty) {
- Array.empty[String]
- } else {
- val (hosts, _) = candidates.maxBy { case (_, size) => size }
- hosts
- }
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 18246500f7..09fd750180 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path, RawLocalFileSystem}
import org.apache.hadoop.mapreduce.Job
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -518,8 +518,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
def getFileScanRDD(df: DataFrame): FileScanRDD = {
df.queryExecution.executedPlan.collect {
- case scan: DataSourceScanExec if scan.rdd.isInstanceOf[FileScanRDD] =>
- scan.rdd.asInstanceOf[FileScanRDD]
+ case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
+ scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 7e83bcbb6e..9dd8d9f804 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -25,7 +25,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.execution.BatchedDataSourceScanExec
+import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT, SingleElement}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -624,16 +624,15 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
// donot return batch, because whole stage codegen is disabled for wide table (>200 columns)
val df2 = spark.read.parquet(path)
- assert(df2.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isEmpty,
- "Should not return batch")
+ val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
+ assert(!fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch)
checkAnswer(df2, df)
// return batch
val columns = Seq.tabulate(9) {i => s"c$i"}
val df3 = df2.selectExpr(columns : _*)
- assert(
- df3.queryExecution.sparkPlan.find(_.isInstanceOf[BatchedDataSourceScanExec]).isDefined,
- "Should return batch")
+ val fileScan3 = df3.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
+ assert(fileScan3.asInstanceOf[FileSourceScanExec].supportsBatch)
checkAnswer(df3, df.selectExpr(columns : _*))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 9d0a2b3d5b..19c89f5c41 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -198,8 +198,8 @@ class FileStreamSinkSuite extends StreamTest {
/** Check some condition on the partitions of the FileScanRDD generated by a DF */
def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
val getFileScanRDD = df.queryExecution.executedPlan.collect {
- case scan: DataSourceScanExec if scan.rdd.isInstanceOf[FileScanRDD] =>
- scan.rdd.asInstanceOf[FileScanRDD]
+ case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
+ scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 8d161a3c46..ca2ec9f6a5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -358,11 +358,11 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
df1.write.parquet(tableDir.getAbsolutePath)
val agged = spark.table("bucketed_table").groupBy("i").count()
- val error = intercept[RuntimeException] {
+ val error = intercept[Exception] {
agged.count()
}
- assert(error.toString contains "Invalid bucket file")
+ assert(error.getCause().toString contains "Invalid bucket file")
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
index 047b08c4cc..27bb9676e9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
@@ -862,8 +862,8 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
.load(path)
val Some(fileScanRDD) = df2.queryExecution.executedPlan.collectFirst {
- case scan: DataSourceScanExec if scan.rdd.isInstanceOf[FileScanRDD] =>
- scan.rdd.asInstanceOf[FileScanRDD]
+ case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
+ scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}
val partitions = fileScanRDD.partitions