aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala221
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala24
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala41
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala22
4 files changed, 182 insertions, 126 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 3b064a5bc4..7e12bbb212 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -239,48 +239,50 @@ private[sql] class DefaultWriterContainer(
extends BaseWriterContainer(relation, job, isAppend) {
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- executorSideSetup(taskContext)
- val configuration = taskAttemptContext.getConfiguration
- configuration.set("spark.sql.sources.output.path", outputPath)
- var writer = newOutputWriter(getWorkPath)
- writer.initConverter(dataSchema)
-
- // If anything below fails, we should abort the task.
- try {
- Utils.tryWithSafeFinallyAndFailureCallbacks {
- while (iterator.hasNext) {
- val internalRow = iterator.next()
- writer.writeInternal(internalRow)
- }
- commitTask()
- }(catchBlock = abortTask())
- } catch {
- case t: Throwable =>
- throw new SparkException("Task failed while writing rows", t)
- }
+ if (iterator.hasNext) {
+ executorSideSetup(taskContext)
+ val configuration = taskAttemptContext.getConfiguration
+ configuration.set("spark.sql.sources.output.path", outputPath)
+ var writer = newOutputWriter(getWorkPath)
+ writer.initConverter(dataSchema)
- def commitTask(): Unit = {
+ // If anything below fails, we should abort the task.
try {
- if (writer != null) {
- writer.close()
- writer = null
- }
- super.commitTask()
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
+ while (iterator.hasNext) {
+ val internalRow = iterator.next()
+ writer.writeInternal(internalRow)
+ }
+ commitTask()
+ }(catchBlock = abortTask())
} catch {
- case cause: Throwable =>
- // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
- // will cause `abortTask()` to be invoked.
- throw new RuntimeException("Failed to commit task", cause)
+ case t: Throwable =>
+ throw new SparkException("Task failed while writing rows", t)
}
- }
- def abortTask(): Unit = {
- try {
- if (writer != null) {
- writer.close()
+ def commitTask(): Unit = {
+ try {
+ if (writer != null) {
+ writer.close()
+ writer = null
+ }
+ super.commitTask()
+ } catch {
+ case cause: Throwable =>
+ // This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
+ // will cause `abortTask()` to be invoked.
+ throw new RuntimeException("Failed to commit task", cause)
+ }
+ }
+
+ def abortTask(): Unit = {
+ try {
+ if (writer != null) {
+ writer.close()
+ }
+ } finally {
+ super.abortTask()
}
- } finally {
- super.abortTask()
}
}
}
@@ -363,84 +365,87 @@ private[sql] class DynamicPartitionWriterContainer(
}
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- executorSideSetup(taskContext)
-
- // We should first sort by partition columns, then bucket id, and finally sorting columns.
- val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
- val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
-
- val sortingKeySchema = StructType(sortingExpressions.map {
- case a: Attribute => StructField(a.name, a.dataType, a.nullable)
- // The sorting expressions are all `Attribute` except bucket id.
- case _ => StructField("bucketId", IntegerType, nullable = false)
- })
-
- // Returns the data columns to be written given an input row
- val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
-
- // Returns the partition path given a partition key.
- val getPartitionString =
- UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
-
- // Sorts the data before write, so that we only need one writer at the same time.
- // TODO: inject a local sort operator in planning.
- val sorter = new UnsafeKVExternalSorter(
- sortingKeySchema,
- StructType.fromAttributes(dataColumns),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
-
- while (iterator.hasNext) {
- val currentRow = iterator.next()
- sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
- }
- logInfo(s"Sorting complete. Writing out partition files one at a time.")
-
- val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
- identity
- } else {
- UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
- case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+ if (iterator.hasNext) {
+ executorSideSetup(taskContext)
+
+ // We should first sort by partition columns, then bucket id, and finally sorting columns.
+ val sortingExpressions: Seq[Expression] =
+ partitionColumns ++ bucketIdExpression ++ sortColumns
+ val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
+
+ val sortingKeySchema = StructType(sortingExpressions.map {
+ case a: Attribute => StructField(a.name, a.dataType, a.nullable)
+ // The sorting expressions are all `Attribute` except bucket id.
+ case _ => StructField("bucketId", IntegerType, nullable = false)
})
- }
- val sortedIterator = sorter.sortedIterator()
+ // Returns the data columns to be written given an input row
+ val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
+
+ // Returns the partition path given a partition key.
+ val getPartitionString =
+ UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
+
+ // Sorts the data before write, so that we only need one writer at the same time.
+ // TODO: inject a local sort operator in planning.
+ val sorter = new UnsafeKVExternalSorter(
+ sortingKeySchema,
+ StructType.fromAttributes(dataColumns),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+ while (iterator.hasNext) {
+ val currentRow = iterator.next()
+ sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+ }
+ logInfo(s"Sorting complete. Writing out partition files one at a time.")
+
+ val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+ identity
+ } else {
+ UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+ case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+ })
+ }
- // If anything below fails, we should abort the task.
- var currentWriter: OutputWriter = null
- try {
- Utils.tryWithSafeFinallyAndFailureCallbacks {
- var currentKey: UnsafeRow = null
- while (sortedIterator.next()) {
- val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
- if (currentKey != nextKey) {
- if (currentWriter != null) {
- currentWriter.close()
- currentWriter = null
- }
- currentKey = nextKey.copy()
- logDebug(s"Writing partition: $currentKey")
+ val sortedIterator = sorter.sortedIterator()
- currentWriter = newOutputWriter(currentKey, getPartitionString)
+ // If anything below fails, we should abort the task.
+ var currentWriter: OutputWriter = null
+ try {
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
+ var currentKey: UnsafeRow = null
+ while (sortedIterator.next()) {
+ val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+ if (currentKey != nextKey) {
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
+ }
+ currentKey = nextKey.copy()
+ logDebug(s"Writing partition: $currentKey")
+
+ currentWriter = newOutputWriter(currentKey, getPartitionString)
+ }
+ currentWriter.writeInternal(sortedIterator.getValue)
+ }
+ if (currentWriter != null) {
+ currentWriter.close()
+ currentWriter = null
}
- currentWriter.writeInternal(sortedIterator.getValue)
- }
- if (currentWriter != null) {
- currentWriter.close()
- currentWriter = null
- }
- commitTask()
- }(catchBlock = {
- if (currentWriter != null) {
- currentWriter.close()
- }
- abortTask()
- })
- } catch {
- case t: Throwable =>
- throw new SparkException("Task failed while writing rows", t)
+ commitTask()
+ }(catchBlock = {
+ if (currentWriter != null) {
+ currentWriter.close()
+ }
+ abortTask()
+ })
+ } catch {
+ case t: Throwable =>
+ throw new SparkException("Task failed while writing rows", t)
+ }
}
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index 794fe264ea..706fdbc260 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -178,19 +178,21 @@ private[hive] class SparkHiveWriterContainer(
// this function is executed on executor side
def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
- val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
- executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
-
- iterator.foreach { row =>
- var i = 0
- while (i < fieldOIs.length) {
- outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
- i += 1
+ if (iterator.hasNext) {
+ val (serializer, standardOI, fieldOIs, dataTypes, wrappers, outputData) = prepareForWrite()
+ executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
+
+ iterator.foreach { row =>
+ var i = 0
+ while (i < fieldOIs.length) {
+ outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
+ i += 1
+ }
+ writer.write(serializer.serialize(outputData, standardOI))
}
- writer.write(serializer.serialize(outputData, standardOI))
- }
- close()
+ close()
+ }
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 82d3e49f92..883cdac110 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive
import java.io.File
-import org.apache.hadoop.hive.conf.HiveConf
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
-import org.apache.spark.sql.{QueryTest, _}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -118,10 +118,10 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql(
s"""
- |CREATE TABLE table_with_partition(c1 string)
- |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
- |location '${tmpDir.toURI.toString}'
- """.stripMargin)
+ |CREATE TABLE table_with_partition(c1 string)
+ |PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string)
+ |location '${tmpDir.toURI.toString}'
+ """.stripMargin)
sql(
"""
|INSERT OVERWRITE TABLE table_with_partition
@@ -216,6 +216,35 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
sql("DROP TABLE hiveTableWithStructValue")
}
+ test("SPARK-10216: Avoid empty files during overwrite into Hive table with group by query") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ val testDataset = hiveContext.sparkContext.parallelize(
+ (1 to 2).map(i => TestData(i, i.toString))).toDF()
+ testDataset.createOrReplaceTempView("testDataset")
+
+ val tmpDir = Utils.createTempDir()
+ sql(
+ s"""
+ |CREATE TABLE table1(key int,value string)
+ |location '${tmpDir.toURI.toString}'
+ """.stripMargin)
+ sql(
+ """
+ |INSERT OVERWRITE TABLE table1
+ |SELECT count(key), value FROM testDataset GROUP BY value
+ """.stripMargin)
+
+ val overwrittenFiles = tmpDir.listFiles()
+ .filter(f => f.isFile && !f.getName.endsWith(".crc"))
+ .sortBy(_.getName)
+ val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
+
+ assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
+
+ sql("DROP TABLE table1")
+ }
+ }
+
test("Reject partitioning that does not match table") {
withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
index f4d63334b6..78d2dc28d6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
@@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql._
import org.apache.spark.sql.execution.DataSourceScanExec
-import org.apache.spark.sql.execution.datasources.{FileScanRDD, HadoopFsRelation, LocalityTestFileSystem, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{FileScanRDD, LocalityTestFileSystem}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
@@ -879,6 +879,26 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
}
}
}
+
+ test("SPARK-10216: Avoid empty files during overwriting with group by query") {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ withTempPath { path =>
+ val df = spark.range(0, 5)
+ val groupedDF = df.groupBy("id").count()
+ groupedDF.write
+ .format(dataSourceName)
+ .mode(SaveMode.Overwrite)
+ .save(path.getCanonicalPath)
+
+ val overwrittenFiles = path.listFiles()
+ .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
+ .sortBy(_.getName)
+ val overwrittenFilesWithoutEmpty = overwrittenFiles.filter(_.length > 0)
+
+ assert(overwrittenFiles === overwrittenFilesWithoutEmpty)
+ }
+ }
+ }
}
// This class is used to test SPARK-8578. We should not use any custom output committer when