-rw-r--r--  core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala                       |  15
-rw-r--r--  core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala            |  63
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala               |  12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala |  10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala          |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala             |  20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala     |  94
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala       |  26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala |  61
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala      |  10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala           |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala |   6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala     | 161
13 files changed, 411 insertions(+), 73 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index fb8020585c..afd2250c93 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -82,10 +82,25 @@ abstract class FileCommitProtocol {
*
* The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest
* are left to the commit protocol implementation to decide.
+ *
+ * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+ * if a task is going to write out multiple files to the same dir. The file commit protocol only
+ * guarantees that files written by different tasks will not conflict.
*/
def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
/**
+ * Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
+ * Depending on the implementation, there may be weaker guarantees around adding files this way.
+ *
+ * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+ * if a task is going to write out multiple files to the same dir. The file commit protocol only
+ * guarantees that files written by different tasks will not conflict.
+ */
+ def newTaskTempFileAbsPath(
+ taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String
+
+ /**
* Commits a task after the writes succeed. Must be called on the executors when running tasks.
*/
def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage
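The contract above leaves "ext" uniqueness to the caller. Below is a minimal standalone sketch (plain Scala, not Spark code) of one way a task-side caller might keep "ext" unique when it writes several files into the same directory; the ".c%03d" counter suffix is a hypothetical convention, not part of the protocol.

object UniqueExtSketch {
  def main(args: Array[String]): Unit = {
    var fileCounter = 0
    // Append a per-task counter before the format extension, e.g. ".c000.gz.parquet",
    // so repeated newTaskTempFile() calls from one task never collide in the same dir.
    def nextExt(baseExt: String): String = {
      val ext = f".c$fileCounter%03d$baseExt"
      fileCounter += 1
      ext
    }
    println(nextExt(".gz.parquet")) // .c000.gz.parquet
    println(nextExt(".gz.parquet")) // .c001.gz.parquet
  }
}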
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index 6b0bcb8f90..b2d9b8d2a0 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -17,7 +17,9 @@
package org.apache.spark.internal.io
-import java.util.Date
+import java.util.{Date, UUID}
+
+import scala.collection.mutable
import org.apache.hadoop.conf.Configurable
import org.apache.hadoop.fs.Path
@@ -42,6 +44,19 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
/** OutputCommitter from Hadoop is not serializable so marking it transient. */
@transient private var committer: OutputCommitter = _
+ /**
+ * Tracks files staged by this task for absolute output paths. These outputs are not managed by
+ * the Hadoop OutputCommitter, so we must move these to their final locations on job commit.
+ *
+ * The mapping is from the temp output path to the final desired output path of the file.
+ */
+ @transient private var addedAbsPathFiles: mutable.Map[String, String] = null
+
+ /**
+ * The staging directory for all files committed with absolute output paths.
+ */
+ private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId)
+
protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
val format = context.getOutputFormatClass.newInstance()
// If OutputFormat is Configurable, we should set conf to it.
@@ -54,11 +69,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
override def newTaskTempFile(
taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
- // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
- // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
- // the file name is fine and won't overflow.
- val split = taskContext.getTaskAttemptID.getTaskID.getId
- val filename = f"part-$split%05d-$jobId$ext"
+ val filename = getFilename(taskContext, ext)
val stagingDir: String = committer match {
// For FileOutputCommitter it has its own staging path called "work path".
@@ -73,6 +84,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
}
}
+ override def newTaskTempFileAbsPath(
+ taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
+ val filename = getFilename(taskContext, ext)
+ val absOutputPath = new Path(absoluteDir, filename).toString
+
+ // Include a UUID here to prevent file collisions for one task writing to different dirs.
+ // In principle we could include hash(absoluteDir) instead but this is simpler.
+ val tmpOutputPath = new Path(
+ absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString
+
+ addedAbsPathFiles(tmpOutputPath) = absOutputPath
+ tmpOutputPath
+ }
+
+ private def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+ // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
+ // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
+ // the file name is fine and won't overflow.
+ val split = taskContext.getTaskAttemptID.getTaskID.getId
+ f"part-$split%05d-$jobId$ext"
+ }
+
override def setupJob(jobContext: JobContext): Unit = {
// Setup IDs
val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0)
@@ -93,26 +126,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
committer.commitJob(jobContext)
+ val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]])
+ .foldLeft(Map[String, String]())(_ ++ _)
+ logDebug(s"Committing files staged for absolute locations $filesToMove")
+ val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+ for ((src, dst) <- filesToMove) {
+ fs.rename(new Path(src), new Path(dst))
+ }
+ fs.delete(absPathStagingDir, true)
}
override def abortJob(jobContext: JobContext): Unit = {
committer.abortJob(jobContext, JobStatus.State.FAILED)
+ val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+ fs.delete(absPathStagingDir, true)
}
override def setupTask(taskContext: TaskAttemptContext): Unit = {
committer = setupCommitter(taskContext)
committer.setupTask(taskContext)
+ addedAbsPathFiles = mutable.Map[String, String]()
}
override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
val attemptId = taskContext.getTaskAttemptID
SparkHadoopMapRedUtil.commitTask(
committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
- EmptyTaskCommitMessage
+ new TaskCommitMessage(addedAbsPathFiles.toMap)
}
override def abortTask(taskContext: TaskAttemptContext): Unit = {
committer.abortTask(taskContext)
+ // best effort cleanup of other staged files
+ for ((src, _) <- addedAbsPathFiles) {
+ val tmp = new Path(src)
+ tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false)
+ }
}
/** Whether we are using a direct output committer */
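A local-filesystem sketch of the staging pattern implemented above, using java.nio instead of Hadoop's FileSystem: each task stages absolute-path outputs under a job-scoped staging dir and reports a temp-to-final mapping, and the driver applies the renames on job commit before dropping the staging dir. All names here are illustrative.

import java.nio.file.{Files, Paths, StandardCopyOption}
import java.util.UUID

object AbsPathStagingSketch {
  def main(args: Array[String]): Unit = {
    val stagingDir = Files.createTempDirectory("_temporary-job0")
    val finalDir = Files.createTempDirectory("custom-partition-location")

    // "Task side": stage one file destined for an absolute location outside the table path
    // and record where it ultimately belongs.
    val tmpFile = stagingDir.resolve(UUID.randomUUID().toString + "-part-00000.parquet")
    Files.write(tmpFile, "data".getBytes("UTF-8"))
    val filesToMove = Map(tmpFile.toString -> finalDir.resolve("part-00000.parquet").toString)

    // "Job commit": move each staged file to its final destination, then drop the staging dir.
    for ((src, dst) <- filesToMove) {
      Files.move(Paths.get(src), Paths.get(dst), StandardCopyOption.REPLACE_EXISTING)
    }
    Files.deleteIfExists(stagingDir)
    println(Files.exists(Paths.get(filesToMove.values.head))) // true
  }
}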
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 2c4db0d2c3..3fa7bf1cdb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
- val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty)
+ val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " +
"partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx)
}
val overwrite = ctx.OVERWRITE != null
- val overwritePartition =
- if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) {
- Some(partitionKeys.map(t => (t._1, t._2.get)))
- } else {
- None
- }
+ val staticPartitionKeys: Map[String, String] =
+ partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get))
InsertIntoTable(
UnresolvedRelation(tableIdent, None),
partitionKeys,
query,
- OverwriteOptions(overwrite, overwritePartition),
+ OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty),
ctx.EXISTS != null)
}
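A tiny standalone illustration of the split performed in the hunk above: a clause like PARTITION (a=1, b) produces one static key and one dynamic key, and only the static keys are carried into OverwriteOptions.

object PartitionSpecSplitSketch {
  def main(args: Array[String]): Unit = {
    val partitionKeys: Map[String, Option[String]] = Map("a" -> Some("1"), "b" -> None)
    val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty)
    val staticPartitionKeys = partitionKeys.filter(_._2.nonEmpty).map { case (k, v) => (k, v.get) }
    println(dynamicPartitionKeys.keys.mkString("[", ",", "]")) // [b]
    println(staticPartitionKeys)                               // Map(a -> 1)
  }
}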
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index dcae7b026f..4dcc288553 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -349,13 +349,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
* Options for writing new data into a table.
*
* @param enabled whether to overwrite existing data in the table.
- * @param specificPartition only data in the specified partition will be overwritten.
+ * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions
+ * that match this partial partition spec. If empty, all partitions
+ * will be overwritten.
*/
case class OverwriteOptions(
enabled: Boolean,
- specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) {
- if (specificPartition.isDefined) {
- assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.")
+ staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) {
+ if (staticPartitionKeys.nonEmpty) {
+ assert(enabled, "Overwrite must be enabled when specifying specific partitions.")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 5f0f6ee479..9aae520ae6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest {
OverwriteOptions(
overwrite,
if (overwrite && partition.nonEmpty) {
- Some(partition.map(kv => (kv._1, kv._2.get)))
+ partition.map(kv => (kv._1, kv._2.get))
} else {
- None
+ Map.empty
}),
ifNotExists)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 5d663949df..65422f1495 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -417,15 +417,17 @@ case class DataSource(
// will be adjusted within InsertIntoHadoopFsRelation.
val plan =
InsertIntoHadoopFsRelationCommand(
- outputPath,
- columns,
- bucketSpec,
- format,
- _ => Unit, // No existing table needs to be refreshed.
- options,
- data.logicalPlan,
- mode,
- catalogTable)
+ outputPath = outputPath,
+ staticPartitionKeys = Map.empty,
+ customPartitionLocations = Map.empty,
+ partitionColumns = columns,
+ bucketSpec = bucketSpec,
+ fileFormat = format,
+ refreshFunction = _ => Unit, // No existing table needs to be refreshed.
+ options = options,
+ query = data.logicalPlan,
+ mode = mode,
+ catalogTable = catalogTable)
sparkSession.sessionState.executePlan(plan).toRdd
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 739aeac877..4f19a2d00b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
"Cannot overwrite a path that is also being read from.")
}
- val overwritingSinglePartition =
- overwrite.specificPartition.isDefined &&
+ val partitionSchema = query.resolve(
+ t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+ val partitionsTrackedByCatalog =
t.sparkSession.sessionState.conf.manageFilesourcePartitions &&
+ l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
l.catalogTable.get.tracksPartitionsInCatalog
- val effectiveOutputPath = if (overwritingSinglePartition) {
- val partition = t.sparkSession.sessionState.catalog.getPartition(
- l.catalogTable.get.identifier, overwrite.specificPartition.get)
- new Path(partition.location)
- } else {
- outputPath
- }
-
- val effectivePartitionSchema = if (overwritingSinglePartition) {
- Nil
- } else {
- query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+ var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
+ var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
+
+ // When partitions are tracked by the catalog, compute all custom partition locations that
+ // may be relevant to the insertion job.
+ if (partitionsTrackedByCatalog) {
+ val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
+ l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys))
+ initialMatchingPartitions = matchingPartitions.map(_.spec)
+ customPartitionLocations = getCustomPartitionLocations(
+ t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
}
+ // Callback for updating metastore partition metadata after the insertion job completes.
+ // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand
def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
- if (l.catalogTable.isDefined && updatedPartitions.nonEmpty &&
- l.catalogTable.get.partitionColumnNames.nonEmpty &&
- l.catalogTable.get.tracksPartitionsInCatalog) {
- val metastoreUpdater = AlterTableAddPartitionCommand(
- l.catalogTable.get.identifier,
- updatedPartitions.map(p => (p, None)),
- ifNotExists = true)
- metastoreUpdater.run(t.sparkSession)
+ if (partitionsTrackedByCatalog) {
+ val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
+ if (newPartitions.nonEmpty) {
+ AlterTableAddPartitionCommand(
+ l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
+ ifNotExists = true).run(t.sparkSession)
+ }
+ if (overwrite.enabled) {
+ val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
+ if (deletedPartitions.nonEmpty) {
+ AlterTableDropPartitionCommand(
+ l.catalogTable.get.identifier, deletedPartitions.toSeq,
+ ifExists = true, purge = true).run(t.sparkSession)
+ }
+ }
}
t.location.refresh()
}
val insertCmd = InsertIntoHadoopFsRelationCommand(
- effectiveOutputPath,
- effectivePartitionSchema,
+ outputPath,
+ if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty,
+ customPartitionLocations,
+ partitionSchema,
t.bucketSpec,
t.fileFormat,
refreshPartitionsCallback,
@@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
insertCmd
}
+
+ /**
+ * Given a set of input partitions, returns those that have locations that differ from the
+ * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by
+ * the user.
+ *
+ * @return a mapping from partition specs to their custom locations
+ */
+ private def getCustomPartitionLocations(
+ spark: SparkSession,
+ table: CatalogTable,
+ basePath: Path,
+ partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
+ val hadoopConf = spark.sessionState.newHadoopConf
+ val fs = basePath.getFileSystem(hadoopConf)
+ val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ partitions.flatMap { p =>
+ val defaultLocation = qualifiedBasePath.suffix(
+ "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString
+ val catalogLocation = new Path(p.location).makeQualified(
+ fs.getUri, fs.getWorkingDirectory).toString
+ if (catalogLocation != defaultLocation) {
+ Some(p.spec -> catalogLocation)
+ } else {
+ None
+ }
+ }.toMap
+ }
}
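A simplified, dependency-free sketch of the comparison getCustomPartitionLocations makes above: a partition counts as having a custom location only if its catalog location differs from the Hive-style default <basePath>/k1=v1/k2=v2. Path qualification and escaping, which the real code handles via Hadoop's Path and escapePathName, are omitted here.

object CustomLocationSketch {
  def main(args: Array[String]): Unit = {
    val basePath = "file:/warehouse/test"
    val partitionCols = Seq("p1", "p2")
    def defaultLocation(spec: Map[String, String]): String =
      basePath + "/" + partitionCols.map(c => s"$c=${spec(c)}").mkString("/")

    val partitions = Seq(
      (Map("p1" -> "0", "p2" -> "0"), "file:/custom/loc/a"),                // custom location
      (Map("p1" -> "1", "p2" -> "1"), "file:/warehouse/test/p1=1/p2=1"))    // default location

    // Keep only the partitions whose catalog location is not the default one.
    val custom = partitions.flatMap { case (spec, loc) =>
      if (loc != defaultLocation(spec)) Some(spec -> loc) else None
    }.toMap
    println(custom) // Map(Map(p1 -> 0, p2 -> 0) -> file:/custom/loc/a)
  }
}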
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 69b3fa667e..4e4b0e48cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -47,6 +47,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
+ /** Describes how output files should be placed in the filesystem. */
+ case class OutputSpec(
+ outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
+
/** A shared job description for all the write tasks. */
private class WriteJobDescription(
val uuid: String, // prevent collision between different (appending) write jobs
@@ -56,7 +60,8 @@ object FileFormatWriter extends Logging {
val partitionColumns: Seq[Attribute],
val nonPartitionColumns: Seq[Attribute],
val bucketSpec: Option[BucketSpec],
- val path: String)
+ val path: String,
+ val customPartitionLocations: Map[TablePartitionSpec, String])
extends Serializable {
assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -83,7 +88,7 @@ object FileFormatWriter extends Logging {
plan: LogicalPlan,
fileFormat: FileFormat,
committer: FileCommitProtocol,
- outputPath: String,
+ outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
@@ -93,7 +98,7 @@ object FileFormatWriter extends Logging {
val job = Job.getInstance(hadoopConf)
job.setOutputKeyClass(classOf[Void])
job.setOutputValueClass(classOf[InternalRow])
- FileOutputFormat.setOutputPath(job, new Path(outputPath))
+ FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = plan.output.filterNot(partitionSet.contains)
@@ -111,7 +116,8 @@ object FileFormatWriter extends Logging {
partitionColumns = partitionColumns,
nonPartitionColumns = dataColumns,
bucketSpec = bucketSpec,
- path = outputPath)
+ path = outputSpec.outputPath,
+ customPartitionLocations = outputSpec.customPartitionLocations)
SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
// This call shouldn't be put into the `try` block below because it only initializes and
@@ -308,7 +314,17 @@ object FileFormatWriter extends Logging {
}
val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
- val path = committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+ val customPath = partDir match {
+ case Some(dir) =>
+ description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+ case _ =>
+ None
+ }
+ val path = if (customPath.isDefined) {
+ committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
+ } else {
+ committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+ }
val newWriter = description.outputWriterFactory.newInstance(
path = path,
dataSchema = description.nonPartitionColumns.toStructType,
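A standalone sketch of the dispatch added above in the write task: the partition directory fragment is parsed back into a partition spec and looked up in customPartitionLocations; a hit routes the file through newTaskTempFileAbsPath, a miss through the ordinary newTaskTempFile. Escaping is omitted and the committer calls are stubbed out as strings.

object CustomPathDispatchSketch {
  def main(args: Array[String]): Unit = {
    val customPartitionLocations = Map(Map("p1" -> "0", "p2" -> "0") -> "/custom/loc/a")

    // Inverse of the "k1=v1/k2=v2" directory fragment (no unescaping here).
    def parsePathFragment(dir: String): Map[String, String] =
      dir.split("/").map { kv => val Array(k, v) = kv.split("=", 2); k -> v }.toMap

    def choosePath(partDir: Option[String]): String = partDir match {
      case Some(dir) =>
        customPartitionLocations.get(parsePathFragment(dir)) match {
          case Some(absDir) => s"newTaskTempFileAbsPath under $absDir"
          case None         => s"newTaskTempFile under <table path>/$dir"
        }
      case None => "newTaskTempFile under <table path>"
    }

    println(choosePath(Some("p1=0/p2=0"))) // absolute-path branch
    println(choosePath(Some("p1=1/p2=1"))) // default branch
  }
}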
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index a0a8cb5024..28975e1546 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
import java.io.IOException
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql._
@@ -32,19 +32,32 @@ import org.apache.spark.sql.execution.command.RunnableCommand
/**
* A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
* Writing to dynamic partitions is also supported.
+ *
+ * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of
+ * partition overwrites: when the spec is empty, all partitions are
+ * overwritten. When it covers a prefix of the partition keys, only
+ * partitions matching the prefix are overwritten.
+ * @param customPartitionLocations mapping of partition specs to their custom locations. The
+ * caller should guarantee that exactly those table partitions
+ * falling under the specified static partition keys are contained
+ * in this map, and that no other partitions are.
*/
case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
+ staticPartitionKeys: TablePartitionSpec,
+ customPartitionLocations: Map[TablePartitionSpec, String],
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
- refreshFunction: (Seq[TablePartitionSpec]) => Unit,
+ refreshFunction: Seq[TablePartitionSpec] => Unit,
options: Map[String, String],
@transient query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable])
extends RunnableCommand {
+ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+
override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -66,10 +79,7 @@ case class InsertIntoHadoopFsRelationCommand(
case (SaveMode.ErrorIfExists, true) =>
throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
case (SaveMode.Overwrite, true) =>
- if (!fs.delete(qualifiedOutputPath, true /* recursively */)) {
- throw new IOException(s"Unable to clear output " +
- s"directory $qualifiedOutputPath prior to writing to it")
- }
+ deleteMatchingPartitions(fs, qualifiedOutputPath)
true
case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
true
@@ -93,7 +103,8 @@ case class InsertIntoHadoopFsRelationCommand(
plan = query,
fileFormat = fileFormat,
committer = committer,
- outputPath = qualifiedOutputPath.toString,
+ outputSpec = FileFormatWriter.OutputSpec(
+ qualifiedOutputPath.toString, customPartitionLocations),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
@@ -105,4 +116,40 @@ case class InsertIntoHadoopFsRelationCommand(
Seq.empty[Row]
}
+
+ /**
+ * Deletes all partition files that match the specified static prefix. Partitions with custom
+ * locations are also cleared based on the custom locations map given to this class.
+ */
+ private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = {
+ val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) {
+ "/" + partitionColumns.flatMap { p =>
+ staticPartitionKeys.get(p.name) match {
+ case Some(value) =>
+ Some(escapePathName(p.name) + "=" + escapePathName(value))
+ case None =>
+ None
+ }
+ }.mkString("/")
+ } else {
+ ""
+ }
+ // first clear the path determined by the static partition keys (e.g. /table/foo=1)
+ val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix)
+ if (fs.exists(staticPrefixPath) && !fs.delete(staticPrefixPath, true /* recursively */)) {
+ throw new IOException(s"Unable to clear output " +
+ s"directory $staticPrefixPath prior to writing to it")
+ }
+ // now clear all custom partition locations (e.g. /custom/dir/where/foo=2/bar=4)
+ for ((spec, customLoc) <- customPartitionLocations) {
+ assert(
+ (staticPartitionKeys.toSet -- spec).isEmpty,
+ "Custom partition location did not match static partitioning keys")
+ val path = new Path(customLoc)
+ if (fs.exists(path) && !fs.delete(path, true)) {
+ throw new IOException(s"Unable to clear partition " +
+ s"directory $path prior to writing to it")
+ }
+ }
+ }
}
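A dependency-free sketch of how deleteMatchingPartitions derives the static prefix above: only partition columns that have static values contribute path segments, so an empty spec clears the whole table directory while a partial spec clears just that subtree. escapePathName is omitted for brevity.

object StaticPrefixSketch {
  def main(args: Array[String]): Unit = {
    val partitionColumns = Seq("p1", "p2")
    def staticPrefix(staticPartitionKeys: Map[String, String]): String =
      if (staticPartitionKeys.isEmpty) ""
      else "/" + partitionColumns
        .flatMap(c => staticPartitionKeys.get(c).map(v => s"$c=$v"))
        .mkString("/")

    println(staticPrefix(Map.empty))                      // ""          -> clear the table dir
    println(staticPrefix(Map("p1" -> "1")))               // "/p1=1"     -> clear one subtree
    println(staticPrefix(Map("p1" -> "1", "p2" -> "2")))  // "/p1=1/p2=2"
  }
}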
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index a28b04ca3f..bf9f318780 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -62,6 +62,7 @@ object PartitioningUtils {
}
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME
+ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
/**
@@ -253,6 +254,15 @@ object PartitioningUtils {
}
/**
+ * This is the inverse of parsePathFragment().
+ */
+ def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String = {
+ partitionSchema.map { field =>
+ escapePathName(field.name) + "=" + escapePathName(spec(field.name))
+ }.mkString("/")
+ }
+
+ /**
* Normalize the column names in partition specification, w.r.t. the real partition column names
* and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
* partition column named `month`, and it's case insensitive, we will normalize `monTh` to
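A toy round trip for the getPathFragment helper added above (escaping omitted): the fragment follows the order of the partition schema, not the order of the spec map, making it the inverse of parsePathFragment.

object PathFragmentSketch {
  def main(args: Array[String]): Unit = {
    val partitionSchema = Seq("year", "month")
    val spec = Map("month" -> "7", "year" -> "2016")
    // Build the directory fragment in schema order.
    val fragment = partitionSchema.map(f => s"$f=${spec(f)}").mkString("/")
    println(fragment) // year=2016/month=7
  }
}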
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index e849cafef4..f1c5f9ab50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -80,7 +80,7 @@ class FileStreamSink(
plan = data.logicalPlan,
fileFormat = fileFormat,
committer = committer,
- outputPath = path,
+ outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = None,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
index 1fe13fa162..92191c8b64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
@@ -96,6 +96,12 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
file
}
+ override def newTaskTempFileAbsPath(
+ taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
+ throw new UnsupportedOperationException(
+ s"$this does not support adding files with an absolute path")
+ }
+
override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
if (addedFiles.nonEmpty) {
val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
index ac435bf619..a1aa07456f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
class PartitionProviderCompatibilitySuite
extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -135,7 +136,7 @@ class PartitionProviderCompatibilitySuite
}
}
- test("insert overwrite partition of legacy datasource table overwrites entire table") {
+ test("insert overwrite partition of legacy datasource table") {
withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") {
withTable("test") {
withTempDir { dir =>
@@ -144,9 +145,9 @@ class PartitionProviderCompatibilitySuite
"""insert overwrite table test
|partition (partCol=1)
|select * from range(100)""".stripMargin)
- assert(spark.sql("select * from test").count() == 100)
+ assert(spark.sql("select * from test").count() == 104)
- // Dynamic partitions case
+ // Overwriting entire table
spark.sql("insert overwrite table test select id, id from range(10)".stripMargin)
assert(spark.sql("select * from test").count() == 10)
}
@@ -186,4 +187,158 @@ class PartitionProviderCompatibilitySuite
}
}
}
+
+ /**
+ * Runs a test against a multi-level partitioned table, then validates that the custom locations
+ * were respected by the output writer.
+ *
+ * The initial partitioning structure is:
+ * /P1=0/P2=0 -- custom location a
+ * /P1=0/P2=1 -- custom location b
+ * /P1=1/P2=0 -- custom location c
+ * /P1=1/P2=1 -- default location
+ */
+ private def testCustomLocations(testFn: => Unit): Unit = {
+ val base = Utils.createTempDir(namePrefix = "base")
+ val a = Utils.createTempDir(namePrefix = "a")
+ val b = Utils.createTempDir(namePrefix = "b")
+ val c = Utils.createTempDir(namePrefix = "c")
+ try {
+ spark.sql(s"""
+ |create table test (id long, P1 int, P2 int)
+ |using parquet
+ |options (path "${base.getAbsolutePath}")
+ |partitioned by (P1, P2)""".stripMargin)
+ spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.getAbsolutePath}'")
+ spark.sql(s"alter table test add partition (P1=0, P2=1) location '${b.getAbsolutePath}'")
+ spark.sql(s"alter table test add partition (P1=1, P2=0) location '${c.getAbsolutePath}'")
+ spark.sql(s"alter table test add partition (P1=1, P2=1)")
+
+ testFn
+
+ // Now validate the partition custom locations were respected
+ val initialCount = spark.sql("select * from test").count()
+ val numA = spark.sql("select * from test where P1=0 and P2=0").count()
+ val numB = spark.sql("select * from test where P1=0 and P2=1").count()
+ val numC = spark.sql("select * from test where P1=1 and P2=0").count()
+ Utils.deleteRecursively(a)
+ spark.sql("refresh table test")
+ assert(spark.sql("select * from test where P1=0 and P2=0").count() == 0)
+ assert(spark.sql("select * from test").count() == initialCount - numA)
+ Utils.deleteRecursively(b)
+ spark.sql("refresh table test")
+ assert(spark.sql("select * from test where P1=0 and P2=1").count() == 0)
+ assert(spark.sql("select * from test").count() == initialCount - numA - numB)
+ Utils.deleteRecursively(c)
+ spark.sql("refresh table test")
+ assert(spark.sql("select * from test where P1=1 and P2=0").count() == 0)
+ assert(spark.sql("select * from test").count() == initialCount - numA - numB - numC)
+ } finally {
+ Utils.deleteRecursively(base)
+ Utils.deleteRecursively(a)
+ Utils.deleteRecursively(b)
+ Utils.deleteRecursively(c)
+ spark.sql("drop table test")
+ }
+ }
+
+ test("sanity check table setup") {
+ testCustomLocations {
+ assert(spark.sql("select * from test").count() == 0)
+ assert(spark.sql("show partitions test").count() == 4)
+ }
+ }
+
+ test("insert into partial dynamic partitions") {
+ testCustomLocations {
+ spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 10)
+ assert(spark.sql("show partitions test").count() == 12)
+ spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 20)
+ assert(spark.sql("show partitions test").count() == 12)
+ spark.sql("insert into test partition (P1=1, P2) select id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 30)
+ assert(spark.sql("show partitions test").count() == 20)
+ spark.sql("insert into test partition (P1=2, P2) select id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 40)
+ assert(spark.sql("show partitions test").count() == 30)
+ }
+ }
+
+ test("insert into fully dynamic partitions") {
+ testCustomLocations {
+ spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 10)
+ assert(spark.sql("show partitions test").count() == 12)
+ spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 20)
+ assert(spark.sql("show partitions test").count() == 12)
+ }
+ }
+
+ test("insert into static partition") {
+ testCustomLocations {
+ spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)")
+ assert(spark.sql("select * from test").count() == 10)
+ assert(spark.sql("show partitions test").count() == 4)
+ spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)")
+ assert(spark.sql("select * from test").count() == 20)
+ assert(spark.sql("show partitions test").count() == 4)
+ spark.sql("insert into test partition (P1=1, P2=1) select id from range(10)")
+ assert(spark.sql("select * from test").count() == 30)
+ assert(spark.sql("show partitions test").count() == 4)
+ }
+ }
+
+ test("overwrite partial dynamic partitions") {
+ testCustomLocations {
+ spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 10)
+ assert(spark.sql("show partitions test").count() == 12)
+ spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(5)")
+ assert(spark.sql("select * from test").count() == 5)
+ assert(spark.sql("show partitions test").count() == 7)
+ spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(1)")
+ assert(spark.sql("select * from test").count() == 1)
+ assert(spark.sql("show partitions test").count() == 3)
+ spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 11)
+ assert(spark.sql("show partitions test").count() == 11)
+ spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(1)")
+ assert(spark.sql("select * from test").count() == 2)
+ assert(spark.sql("show partitions test").count() == 2)
+ spark.sql("insert overwrite table test partition (P1=3, P2) select id, id from range(100)")
+ assert(spark.sql("select * from test").count() == 102)
+ assert(spark.sql("show partitions test").count() == 102)
+ }
+ }
+
+ test("overwrite fully dynamic partitions") {
+ testCustomLocations {
+ spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(10)")
+ assert(spark.sql("select * from test").count() == 10)
+ assert(spark.sql("show partitions test").count() == 10)
+ spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(5)")
+ assert(spark.sql("select * from test").count() == 5)
+ assert(spark.sql("show partitions test").count() == 5)
+ }
+ }
+
+ test("overwrite static partition") {
+ testCustomLocations {
+ spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(10)")
+ assert(spark.sql("select * from test").count() == 10)
+ assert(spark.sql("show partitions test").count() == 4)
+ spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(5)")
+ assert(spark.sql("select * from test").count() == 5)
+ assert(spark.sql("show partitions test").count() == 4)
+ spark.sql("insert overwrite table test partition (P1=1, P2=1) select id from range(5)")
+ assert(spark.sql("select * from test").count() == 10)
+ assert(spark.sql("show partitions test").count() == 4)
+ spark.sql("insert overwrite table test partition (P1=1, P2=2) select id from range(5)")
+ assert(spark.sql("select * from test").count() == 15)
+ assert(spark.sql("show partitions test").count() == 5)
+ }
+ }
}