diff options
13 files changed, 411 insertions, 73 deletions
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index fb8020585c..afd2250c93 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -82,10 +82,25 @@ abstract class FileCommitProtocol { * * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. */ def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String /** + * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + + /** * Commits a task after the writes succeed. Must be called on the executors when running tasks. */ def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6b0bcb8f90..b2d9b8d2a0 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -17,7 +17,9 @@ package org.apache.spark.internal.io -import java.util.Date +import java.util.{Date, UUID} + +import scala.collection.mutable import org.apache.hadoop.conf.Configurable import org.apache.hadoop.fs.Path @@ -42,6 +44,19 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. + * + * The mapping is from the temp output path to the final desired output path of the file. + */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. + */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() // If OutputFormat is Configurable, we should set conf to it. @@ -54,11 +69,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def newTaskTempFile( taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. - val split = taskContext.getTaskAttemptID.getTaskID.getId - val filename = f"part-$split%05d-$jobId$ext" + val filename = getFilename(taskContext, ext) val stagingDir: String = committer match { // For FileOutputCommitter it has its own staging path called "work path". @@ -73,6 +84,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) } } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + override def setupJob(jobContext: JobContext): Unit = { // Setup IDs val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) @@ -93,26 +126,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) } override def setupTask(taskContext: TaskAttemptContext): Unit = { committer = setupCommitter(taskContext) committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - EmptyTaskCommitMessage + new TaskCommitMessage(addedAbsPathFiles.toMap) } override def abortTask(taskContext: TaskAttemptContext): Unit = { committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } } /** Whether we are using a direct output committer */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2c4db0d2c3..3fa7bf1cdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableIdent = visitTableIdentifier(ctx.tableIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) - val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty) + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) } val overwrite = ctx.OVERWRITE != null - val overwritePartition = - if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) { - Some(partitionKeys.map(t => (t._1, t._2.get))) - } else { - None - } + val staticPartitionKeys: Map[String, String] = + partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get)) InsertIntoTable( UnresolvedRelation(tableIdent, None), partitionKeys, query, - OverwriteOptions(overwrite, overwritePartition), + OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty), ctx.EXISTS != null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index dcae7b026f..4dcc288553 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -349,13 +349,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { * Options for writing new data into a table. * * @param enabled whether to overwrite existing data in the table. - * @param specificPartition only data in the specified partition will be overwritten. + * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions + * that match this partial partition spec. If empty, all partitions + * will be overwritten. */ case class OverwriteOptions( enabled: Boolean, - specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) { - if (specificPartition.isDefined) { - assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.") + staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) { + if (staticPartitionKeys.nonEmpty) { + assert(enabled, "Overwrite must be enabled when specifying specific partitions.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5f0f6ee479..9aae520ae6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest { OverwriteOptions( overwrite, if (overwrite && partition.nonEmpty) { - Some(partition.map(kv => (kv._1, kv._2.get))) + partition.map(kv => (kv._1, kv._2.get)) } else { - None + Map.empty }), ifNotExists) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 5d663949df..65422f1495 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -417,15 +417,17 @@ case class DataSource( // will be adjusted within InsertIntoHadoopFsRelation. val plan = InsertIntoHadoopFsRelationCommand( - outputPath, - columns, - bucketSpec, - format, - _ => Unit, // No existing table needs to be refreshed. - options, - data.logicalPlan, - mode, - catalogTable) + outputPath = outputPath, + staticPartitionKeys = Map.empty, + customPartitionLocations = Map.empty, + partitionColumns = columns, + bucketSpec = bucketSpec, + fileFormat = format, + refreshFunction = _ => Unit, // No existing table needs to be refreshed. + options = options, + query = data.logicalPlan, + mode = mode, + catalogTable = catalogTable) sparkSession.sessionState.executePlan(plan).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it. copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 739aeac877..4f19a2d00b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec} +import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { "Cannot overwrite a path that is also being read from.") } - val overwritingSinglePartition = - overwrite.specificPartition.isDefined && + val partitionSchema = query.resolve( + t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + val partitionsTrackedByCatalog = t.sparkSession.sessionState.conf.manageFilesourcePartitions && + l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty && l.catalogTable.get.tracksPartitionsInCatalog - val effectiveOutputPath = if (overwritingSinglePartition) { - val partition = t.sparkSession.sessionState.catalog.getPartition( - l.catalogTable.get.identifier, overwrite.specificPartition.get) - new Path(partition.location) - } else { - outputPath - } - - val effectivePartitionSchema = if (overwritingSinglePartition) { - Nil - } else { - query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver) + var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil + var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + + // When partitions are tracked by the catalog, compute all custom partition locations that + // may be relevant to the insertion job. + if (partitionsTrackedByCatalog) { + val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions( + l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys)) + initialMatchingPartitions = matchingPartitions.map(_.spec) + customPartitionLocations = getCustomPartitionLocations( + t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions) } + // Callback for updating metastore partition metadata after the insertion job completes. + // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = { - if (l.catalogTable.isDefined && updatedPartitions.nonEmpty && - l.catalogTable.get.partitionColumnNames.nonEmpty && - l.catalogTable.get.tracksPartitionsInCatalog) { - val metastoreUpdater = AlterTableAddPartitionCommand( - l.catalogTable.get.identifier, - updatedPartitions.map(p => (p, None)), - ifNotExists = true) - metastoreUpdater.run(t.sparkSession) + if (partitionsTrackedByCatalog) { + val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(t.sparkSession) + } + if (overwrite.enabled) { + val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions + if (deletedPartitions.nonEmpty) { + AlterTableDropPartitionCommand( + l.catalogTable.get.identifier, deletedPartitions.toSeq, + ifExists = true, purge = true).run(t.sparkSession) + } + } } t.location.refresh() } val insertCmd = InsertIntoHadoopFsRelationCommand( - effectiveOutputPath, - effectivePartitionSchema, + outputPath, + if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty, + customPartitionLocations, + partitionSchema, t.bucketSpec, t.fileFormat, refreshPartitionsCallback, @@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] { insertCmd } + + /** + * Given a set of input partitions, returns those that have locations that differ from the + * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by + * the user. + * + * @return a mapping from partition specs to their custom locations + */ + private def getCustomPartitionLocations( + spark: SparkSession, + table: CatalogTable, + basePath: Path, + partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = { + val hadoopConf = spark.sessionState.newHadoopConf + val fs = basePath.getFileSystem(hadoopConf) + val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory) + partitions.flatMap { p => + val defaultLocation = qualifiedBasePath.suffix( + "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString + val catalogLocation = new Path(p.location).makeQualified( + fs.getUri, fs.getWorkingDirectory).toString + if (catalogLocation != defaultLocation) { + Some(p.spec -> catalogLocation) + } else { + None + } + }.toMap + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 69b3fa667e..4e4b0e48cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -47,6 +47,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** A helper object for writing FileFormat data out to a location. */ object FileFormatWriter extends Logging { + /** Describes how output files should be placed in the filesystem. */ + case class OutputSpec( + outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String]) + /** A shared job description for all the write tasks. */ private class WriteJobDescription( val uuid: String, // prevent collision between different (appending) write jobs @@ -56,7 +60,8 @@ object FileFormatWriter extends Logging { val partitionColumns: Seq[Attribute], val nonPartitionColumns: Seq[Attribute], val bucketSpec: Option[BucketSpec], - val path: String) + val path: String, + val customPartitionLocations: Map[TablePartitionSpec, String]) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns), @@ -83,7 +88,7 @@ object FileFormatWriter extends Logging { plan: LogicalPlan, fileFormat: FileFormat, committer: FileCommitProtocol, - outputPath: String, + outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], @@ -93,7 +98,7 @@ object FileFormatWriter extends Logging { val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) - FileOutputFormat.setOutputPath(job, new Path(outputPath)) + FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) val partitionSet = AttributeSet(partitionColumns) val dataColumns = plan.output.filterNot(partitionSet.contains) @@ -111,7 +116,8 @@ object FileFormatWriter extends Logging { partitionColumns = partitionColumns, nonPartitionColumns = dataColumns, bucketSpec = bucketSpec, - path = outputPath) + path = outputSpec.outputPath, + customPartitionLocations = outputSpec.customPartitionLocations) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { // This call shouldn't be put into the `try` block below because it only initializes and @@ -308,7 +314,17 @@ object FileFormatWriter extends Logging { } val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext) - val path = committer.newTaskTempFile(taskAttemptContext, partDir, ext) + val customPath = partDir match { + case Some(dir) => + description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) + case _ => + None + } + val path = if (customPath.isDefined) { + committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + } else { + committer.newTaskTempFile(taskAttemptContext, partDir, ext) + } val newWriter = description.outputWriterFactory.newInstance( path = path, dataSchema = description.nonPartitionColumns.toStructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index a0a8cb5024..28975e1546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.io.IOException -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ @@ -32,19 +32,32 @@ import org.apache.spark.sql.execution.command.RunnableCommand /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. * Writing to dynamic partitions is also supported. + * + * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of + * partition overwrites: when the spec is empty, all partitions are + * overwritten. When it covers a prefix of the partition keys, only + * partitions matching the prefix are overwritten. + * @param customPartitionLocations mapping of partition specs to their custom locations. The + * caller should guarantee that exactly those table partitions + * falling under the specified static partition keys are contained + * in this map, and that no other partitions are. */ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, + staticPartitionKeys: TablePartitionSpec, + customPartitionLocations: Map[TablePartitionSpec, String], partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, - refreshFunction: (Seq[TablePartitionSpec]) => Unit, + refreshFunction: Seq[TablePartitionSpec] => Unit, options: Map[String, String], @transient query: LogicalPlan, mode: SaveMode, catalogTable: Option[CatalogTable]) extends RunnableCommand { + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil override def run(sparkSession: SparkSession): Seq[Row] = { @@ -66,10 +79,7 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - if (!fs.delete(qualifiedOutputPath, true /* recursively */)) { - throw new IOException(s"Unable to clear output " + - s"directory $qualifiedOutputPath prior to writing to it") - } + deleteMatchingPartitions(fs, qualifiedOutputPath) true case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true @@ -93,7 +103,8 @@ case class InsertIntoHadoopFsRelationCommand( plan = query, fileFormat = fileFormat, committer = committer, - outputPath = qualifiedOutputPath.toString, + outputSpec = FileFormatWriter.OutputSpec( + qualifiedOutputPath.toString, customPartitionLocations), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = bucketSpec, @@ -105,4 +116,40 @@ case class InsertIntoHadoopFsRelationCommand( Seq.empty[Row] } + + /** + * Deletes all partition files that match the specified static prefix. Partitions with custom + * locations are also cleared based on the custom locations map given to this class. + */ + private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = { + val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) { + "/" + partitionColumns.flatMap { p => + staticPartitionKeys.get(p.name) match { + case Some(value) => + Some(escapePathName(p.name) + "=" + escapePathName(value)) + case None => + None + } + }.mkString("/") + } else { + "" + } + // first clear the path determined by the static partition keys (e.g. /table/foo=1) + val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix) + if (fs.exists(staticPrefixPath) && !fs.delete(staticPrefixPath, true /* recursively */)) { + throw new IOException(s"Unable to clear output " + + s"directory $staticPrefixPath prior to writing to it") + } + // now clear all custom partition locations (e.g. /custom/dir/where/foo=2/bar=4) + for ((spec, customLoc) <- customPartitionLocations) { + assert( + (staticPartitionKeys.toSet -- spec).isEmpty, + "Custom partition location did not match static partitioning keys") + val path = new Path(customLoc) + if (fs.exists(path) && !fs.delete(path, true)) { + throw new IOException(s"Unable to clear partition " + + s"directory $path prior to writing to it") + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index a28b04ca3f..bf9f318780 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -62,6 +62,7 @@ object PartitioningUtils { } import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME + import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName /** @@ -253,6 +254,15 @@ object PartitioningUtils { } /** + * This is the inverse of parsePathFragment(). + */ + def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String = { + partitionSchema.map { field => + escapePathName(field.name) + "=" + escapePathName(spec(field.name)) + }.mkString("/") + } + + /** * Normalize the column names in partition specification, w.r.t. the real partition column names * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a * partition column named `month`, and it's case insensitive, we will normalize `monTh` to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index e849cafef4..f1c5f9ab50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -80,7 +80,7 @@ class FileStreamSink( plan = data.logicalPlan, fileFormat = fileFormat, committer = committer, - outputPath = path, + outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), hadoopConf = hadoopConf, partitionColumns = partitionColumns, bucketSpec = None, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala index 1fe13fa162..92191c8b64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -96,6 +96,12 @@ class ManifestFileCommitProtocol(jobId: String, path: String) file } + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + throw new UnsupportedOperationException( + s"$this does not support adding files with an absolute path") + } + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { if (addedFiles.nonEmpty) { val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala index ac435bf619..a1aa07456f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class PartitionProviderCompatibilitySuite extends QueryTest with TestHiveSingleton with SQLTestUtils { @@ -135,7 +136,7 @@ class PartitionProviderCompatibilitySuite } } - test("insert overwrite partition of legacy datasource table overwrites entire table") { + test("insert overwrite partition of legacy datasource table") { withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") { withTable("test") { withTempDir { dir => @@ -144,9 +145,9 @@ class PartitionProviderCompatibilitySuite """insert overwrite table test |partition (partCol=1) |select * from range(100)""".stripMargin) - assert(spark.sql("select * from test").count() == 100) + assert(spark.sql("select * from test").count() == 104) - // Dynamic partitions case + // Overwriting entire table spark.sql("insert overwrite table test select id, id from range(10)".stripMargin) assert(spark.sql("select * from test").count() == 10) } @@ -186,4 +187,158 @@ class PartitionProviderCompatibilitySuite } } } + + /** + * Runs a test against a multi-level partitioned table, then validates that the custom locations + * were respected by the output writer. + * + * The initial partitioning structure is: + * /P1=0/P2=0 -- custom location a + * /P1=0/P2=1 -- custom location b + * /P1=1/P2=0 -- custom location c + * /P1=1/P2=1 -- default location + */ + private def testCustomLocations(testFn: => Unit): Unit = { + val base = Utils.createTempDir(namePrefix = "base") + val a = Utils.createTempDir(namePrefix = "a") + val b = Utils.createTempDir(namePrefix = "b") + val c = Utils.createTempDir(namePrefix = "c") + try { + spark.sql(s""" + |create table test (id long, P1 int, P2 int) + |using parquet + |options (path "${base.getAbsolutePath}") + |partitioned by (P1, P2)""".stripMargin) + spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=0, P2=1) location '${b.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=0) location '${c.getAbsolutePath}'") + spark.sql(s"alter table test add partition (P1=1, P2=1)") + + testFn + + // Now validate the partition custom locations were respected + val initialCount = spark.sql("select * from test").count() + val numA = spark.sql("select * from test where P1=0 and P2=0").count() + val numB = spark.sql("select * from test where P1=0 and P2=1").count() + val numC = spark.sql("select * from test where P1=1 and P2=0").count() + Utils.deleteRecursively(a) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA) + Utils.deleteRecursively(b) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=0 and P2=1").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA - numB) + Utils.deleteRecursively(c) + spark.sql("refresh table test") + assert(spark.sql("select * from test where P1=1 and P2=0").count() == 0) + assert(spark.sql("select * from test").count() == initialCount - numA - numB - numC) + } finally { + Utils.deleteRecursively(base) + Utils.deleteRecursively(a) + Utils.deleteRecursively(b) + Utils.deleteRecursively(c) + spark.sql("drop table test") + } + } + + test("sanity check table setup") { + testCustomLocations { + assert(spark.sql("select * from test").count() == 0) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("insert into partial dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 20) + spark.sql("insert into test partition (P1=2, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 40) + assert(spark.sql("show partitions test").count() == 30) + } + } + + test("insert into fully dynamic partitions") { + testCustomLocations { + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 12) + } + } + + test("insert into static partition") { + testCustomLocations { + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 20) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert into test partition (P1=1, P2=1) select id from range(10)") + assert(spark.sql("select * from test").count() == 30) + assert(spark.sql("show partitions test").count() == 4) + } + } + + test("overwrite partial dynamic partitions") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 12) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 7) + spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 1) + assert(spark.sql("show partitions test").count() == 3) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(10)") + assert(spark.sql("select * from test").count() == 11) + assert(spark.sql("show partitions test").count() == 11) + spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(1)") + assert(spark.sql("select * from test").count() == 2) + assert(spark.sql("show partitions test").count() == 2) + spark.sql("insert overwrite table test partition (P1=3, P2) select id, id from range(100)") + assert(spark.sql("select * from test").count() == 102) + assert(spark.sql("show partitions test").count() == 102) + } + } + + test("overwrite fully dynamic partitions") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 10) + spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 5) + } + } + + test("overwrite static partition") { + testCustomLocations { + spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(10)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(5)") + assert(spark.sql("select * from test").count() == 5) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=1, P2=1) select id from range(5)") + assert(spark.sql("select * from test").count() == 10) + assert(spark.sql("show partitions test").count() == 4) + spark.sql("insert overwrite table test partition (P1=1, P2=2) select id from range(5)") + assert(spark.sql("select * from test").count() == 15) + assert(spark.sql("show partitions test").count() == 5) + } + } } |