 core/src/main/scala/org/apache/spark/storage/BlockId.scala                        | 11 ++++++++---
 core/src/main/scala/org/apache/spark/storage/BlockManager.scala                   |  3 ++-
 core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala               | 17 +++++++++++++----
 core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala  |  2 +-
 core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala         | 15 +++++++++++++--
 core/src/test/scala/org/apache/spark/ShuffleSuite.scala                           | 24 ++++++++++++++++++++++++
 6 files changed, 61 insertions(+), 11 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index a83a3f468a..8df5ec6bde 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -83,9 +83,14 @@ case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
def name = "input-" + streamId + "-" + uniqueId
}
-/** Id associated with temporary data managed as blocks. Not serializable. */
-private[spark] case class TempBlockId(id: UUID) extends BlockId {
- def name = "temp_" + id
+/** Id associated with temporary local data managed as blocks. Not serializable. */
+private[spark] case class TempLocalBlockId(id: UUID) extends BlockId {
+ def name = "temp_local_" + id
+}
+
+/** Id associated with temporary shuffle data managed as blocks. Not serializable. */
+private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId {
+ def name = "temp_shuffle_" + id
}
// Intended only for testing purposes
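
A minimal standalone sketch (illustrative, not part of the patch): the two new temporary id types differ only in their name prefix, and it is the type distinction, rather than the name, that later lets BlockManager pick the right compression flag. The simplified classes below are stand-ins for the ones added above.

    import java.util.UUID

    // Simplified stand-ins for the case classes added above.
    case class TempLocalBlockId(id: UUID)   { def name = "temp_local_" + id }
    case class TempShuffleBlockId(id: UUID) { def name = "temp_shuffle_" + id }

    val local   = TempLocalBlockId(UUID.randomUUID())
    val shuffle = TempShuffleBlockId(UUID.randomUUID())
    println(local.name)   // e.g. temp_local_<uuid>
    println(shuffle.name) // e.g. temp_shuffle_<uuid>
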
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 0ce2a3f631..4cc9792365 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -1071,7 +1071,8 @@ private[spark] class BlockManager(
case _: ShuffleBlockId => compressShuffle
case _: BroadcastBlockId => compressBroadcast
case _: RDDBlockId => compressRdds
- case _: TempBlockId => compressShuffleSpill
+ case _: TempLocalBlockId => compressShuffleSpill
+ case _: TempShuffleBlockId => compressShuffle
case _ => false
}
}
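
To make the dispatch above concrete, a hedged sketch of which setting governs each block id type. The helper is illustrative only (it would have to live under org.apache.spark because the Temp*BlockId classes are private[spark]); the config keys are the standard ones backing compressShuffle, compressBroadcast, compressRdds and compressShuffleSpill.

    package org.apache.spark.storage

    // Illustrative mapping from block id type to the config flag that controls
    // its on-disk compression, mirroring the shouldCompress match above.
    object CompressionSettingSketch {
      def settingFor(id: BlockId): Option[String] = id match {
        case _: ShuffleBlockId     => Some("spark.shuffle.compress")
        case _: BroadcastBlockId   => Some("spark.broadcast.compress")
        case _: RDDBlockId         => Some("spark.rdd.compress")
        case _: TempLocalBlockId   => Some("spark.shuffle.spill.compress")
        case _: TempShuffleBlockId => Some("spark.shuffle.compress")
        case _                     => None
      }
    }
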
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index a715594f19..6633a1db57 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -98,11 +98,20 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
getAllFiles().map(f => BlockId(f.getName))
}
- /** Produces a unique block id and File suitable for intermediate results. */
- def createTempBlock(): (TempBlockId, File) = {
- var blockId = new TempBlockId(UUID.randomUUID())
+ /** Produces a unique block id and File suitable for storing local intermediate results. */
+ def createTempLocalBlock(): (TempLocalBlockId, File) = {
+ var blockId = new TempLocalBlockId(UUID.randomUUID())
while (getFile(blockId).exists()) {
- blockId = new TempBlockId(UUID.randomUUID())
+ blockId = new TempLocalBlockId(UUID.randomUUID())
+ }
+ (blockId, getFile(blockId))
+ }
+
+ /** Produces a unique block id and File suitable for storing shuffled intermediate results. */
+ def createTempShuffleBlock(): (TempShuffleBlockId, File) = {
+ var blockId = new TempShuffleBlockId(UUID.randomUUID())
+ while (getFile(blockId).exists()) {
+ blockId = new TempShuffleBlockId(UUID.randomUUID())
}
(blockId, getFile(blockId))
}
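
Both new factory methods share the same retry-until-unused pattern; a hedged, self-contained sketch of that pattern with generic names (not Spark's API):

    import java.io.File
    import java.util.UUID

    // Generic form of the loop in createTempLocalBlock/createTempShuffleBlock:
    // keep drawing random UUIDs until the corresponding file does not yet exist.
    def createUniqueTempFile(dir: File, prefix: String): (String, File) = {
      var name = prefix + UUID.randomUUID()
      while (new File(dir, name).exists()) {
        name = prefix + UUID.randomUUID()
      }
      (name, new File(dir, name))
    }

    // Example: createUniqueTempFile(new File("/tmp"), "temp_shuffle_")
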
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 0c088da46a..26fa0cb6d7 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -153,7 +153,7 @@ class ExternalAppendOnlyMap[K, V, C](
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
override protected[this] def spill(collection: SizeTracker): Unit = {
- val (blockId, file) = diskBlockManager.createTempBlock()
+ val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
curWriteMetrics)
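
ExternalAppendOnlyMap keeps using a local temp block here because its spill files are written and read back only by the same task and are never served as shuffle output, so their compression stays under spark.shuffle.spill.compress (via the TempLocalBlockId case in the BlockManager hunk above).
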
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index d1b06d14ac..c1ce13683b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -38,6 +38,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
*
* If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
*
+ * Note: Although ExternalSorter is a fairly generic sorter, some of its configuration is tied
+ * to its use in sort-based shuffle (for example, its block compression is controlled by
+ * `spark.shuffle.compress`). We may need to revisit this if ExternalSorter is used in other
+ * non-shuffle contexts where we might want to use different configuration settings.
+ *
* @param aggregator optional Aggregator with combine functions to use for merging data
* @param partitioner optional Partitioner; if given, sort by partition ID and then key
* @param ordering optional Ordering to sort keys within each partition; should be a total ordering
@@ -259,7 +264,10 @@ private[spark] class ExternalSorter[K, V, C](
private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
assert(!bypassMergeSort)
- val (blockId, file) = diskBlockManager.createTempBlock()
+ // Because these files may be read during shuffle, their compression must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more context.
+ val (blockId, file) = diskBlockManager.createTempShuffleBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
var objectsWritten = 0 // Objects written since the last flush
@@ -338,7 +346,10 @@ private[spark] class ExternalSorter[K, V, C](
if (partitionWriters == null) {
curWriteMetrics = new ShuffleWriteMetrics()
partitionWriters = Array.fill(numPartitions) {
- val (blockId, file) = diskBlockManager.createTempBlock()
+ // Because these files may be read during shuffle, their compression must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more context.
+ val (blockId, file) = diskBlockManager.createTempShuffleBlock()
blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
}
}
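
The motivation behind the two ExternalSorter hunks, as a hedged sketch: the sorter's merge and per-partition files can later be read back as shuffle data, so the write and read sides must agree on a single flag. The SparkConf below shows one of the mismatched combinations SPARK-3426 hits when those files are written as plain spill blocks (values are illustrative).

    import org.apache.spark.SparkConf

    // Before this patch, sort-shuffle spill files were written under
    // spark.shuffle.spill.compress but could be read back under
    // spark.shuffle.compress; with settings like these the two disagree and
    // the read side sees a corrupted stream. After the patch both sides use
    // spark.shuffle.compress for these TempShuffleBlockId files.
    val conf = new SparkConf()
      .set("spark.shuffle.spill.compress", "false")
      .set("spark.shuffle.compress", "true")
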
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 15aa4d8380..2bdd84ce69 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -242,6 +242,30 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.toLowerCase.contains("serializable"))
}
+
+ test("shuffle with different compression settings (SPARK-3426)") {
+ for (
+ shuffleSpillCompress <- Set(true, false);
+ shuffleCompress <- Set(true, false)
+ ) {
+ val conf = new SparkConf()
+ .setAppName("test")
+ .setMaster("local")
+ .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString)
+ .set("spark.shuffle.compress", shuffleCompress.toString)
+ .set("spark.shuffle.memoryFraction", "0.001")
+ resetSparkContext()
+ sc = new SparkContext(conf)
+ try {
+ sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect()
+ } catch {
+ case e: Exception =>
+ val errMsg = s"Failed with spark.shuffle.spill.compress=$shuffleSpillCompress," +
+ s" spark.shuffle.compress=$shuffleCompress"
+ throw new Exception(errMsg, e)
+ }
+ }
+ }
}
object ShuffleSuite {
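
For completeness, a hedged standalone sketch of the same check outside the test harness, as a local-mode driver (the object and app names are arbitrary; like the test, it shrinks spark.shuffle.memoryFraction to force spills):

    import org.apache.spark.{SparkConf, SparkContext}

    // Runs a spilling groupByKey under all four compression combinations.
    object CompressionMatrixCheck {
      def main(args: Array[String]): Unit = {
        for (spill <- Seq(true, false); shuffle <- Seq(true, false)) {
          val conf = new SparkConf()
            .setAppName("compression-matrix-check")
            .setMaster("local")
            .set("spark.shuffle.spill.compress", spill.toString)
            .set("spark.shuffle.compress", shuffle.toString)
            .set("spark.shuffle.memoryFraction", "0.001")
          val sc = new SparkContext(conf)
          try {
            sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect()
          } finally {
            sc.stop()
          }
        }
      }
    }
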