 core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 23 ++++++++++++++++++++---
 core/src/test/scala/org/apache/spark/ShuffleSuite.scala                   | 12 ++++++------
 sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala     | 16 +++++++++++++++-
 3 files changed, 41 insertions(+), 10 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c617ff5c51..15bda1c9cc 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -205,6 +205,13 @@ private[spark] class ExternalSorter[K, V, C](
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
+ } else if (bypassMergeSort) {
+ // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
+ if (records.hasNext) {
+ spillToPartitionFiles(records.map { kv =>
+ ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+ })
+ }
} else {
// Stick values into our buffer
while (records.hasNext) {
@@ -336,6 +343,10 @@ private[spark] class ExternalSorter[K, V, C](
* @param collection whichever collection we're using (map or buffer)
*/
private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+ spillToPartitionFiles(collection.iterator)
+ }
+
+ private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = {
assert(bypassMergeSort)
// Create our file writers if we haven't done so yet
@@ -350,9 +361,9 @@ private[spark] class ExternalSorter[K, V, C](
}
}
- val it = collection.iterator // No need to sort stuff, just write each element out
- while (it.hasNext) {
- val elem = it.next()
+ // No need to sort stuff, just write each element out
+ while (iterator.hasNext) {
+ val elem = iterator.next()
val partitionId = elem._1._1
val key = elem._1._2
val value = elem._2
@@ -748,6 +759,12 @@ private[spark] class ExternalSorter[K, V, C](
context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+ context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
+ if (curWriteMetrics != null) {
+ m.shuffleBytesWritten += curWriteMetrics.shuffleBytesWritten
+ m.shuffleWriteTime += curWriteMetrics.shuffleWriteTime
+ }
+ }
lengths
}
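
Note: the hunks above let the bypass-merge-sort path hand incoming records straight to the per-partition writers instead of buffering them first, which is what makes it safe for upstream operators to skip their defensive copies. Below is a minimal, self-contained sketch of that "stream straight to per-partition files" idea; the names (BypassSketch, writePartitioned) and the use of ObjectOutputStream are illustrative only and are not Spark's internal API.

import java.io.{File, FileOutputStream, ObjectOutputStream}

object BypassSketch {
  /** Routes each record to its partition's file without buffering in memory. */
  def writePartitioned[K, V](
      records: Iterator[(K, V)],
      numPartitions: Int,
      getPartition: K => Int,
      dir: File): Seq[File] = {
    val files = (0 until numPartitions).map(i => new File(dir, s"part-$i"))
    val writers = files.map(f => new ObjectOutputStream(new FileOutputStream(f)))
    try {
      // Each record is written out as soon as it arrives, so the caller never
      // needs to make defensive copies of mutable records (cf. SPARK-4479).
      records.foreach { case (k, v) => writers(getPartition(k)).writeObject((k, v)) }
    } finally {
      writers.foreach(_.close())
    }
    files
  }

  def main(args: Array[String]): Unit = {
    val dir = java.nio.file.Files.createTempDirectory("bypass-sketch").toFile
    val records = Iterator((1, "a"), (2, "b"), (3, "c"), (4, "d"))
    val files = writePartitioned(records, 2, (k: Int) => k % 2, dir)
    files.foreach(f => println(s"${f.getName}: ${f.length} bytes"))
  }
}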
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index cda942e15a..85e5f9ab44 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -95,14 +95,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
- // 10 partitions from 4 keys
- val NUM_BLOCKS = 10
+ // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys
+ val NUM_BLOCKS = 201
val a = sc.parallelize(1 to 4, NUM_BLOCKS)
val b = a.map(x => (x, x*2))
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
- val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS))
.setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -122,13 +122,13 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
- // 10 partitions from 4 keys
- val NUM_BLOCKS = 10
+ // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys
+ val NUM_BLOCKS = 201
val a = sc.parallelize(1 to 4, NUM_BLOCKS)
val b = a.map(x => (x, x*2))
// NOTE: The default Java serializer should create zero-sized blocks
- val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(NUM_BLOCKS))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
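
Note: both test changes above hinge on the same arithmetic: the default value of spark.shuffle.sort.bypassMergeThreshold is 200, so raising NUM_BLOCKS from 10 to 201 pushes these shuffles off the bypass path that this patch modifies. A hedged sketch of that check follows; the helper name wouldBypassMergeSort is illustrative, and the real decision in ExternalSorter also requires that no aggregation or ordering is needed.

import org.apache.spark.SparkConf

object BypassThresholdSketch {
  // Returns true when the partition count alone would allow sort-based shuffle
  // to skip merge-sort and write one file per reduce partition.
  def wouldBypassMergeSort(conf: SparkConf, numPartitions: Int): Boolean = {
    val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
    numPartitions <= threshold
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
    println(wouldBypassMergeSort(conf, 10))   // true:  the old test value stayed on the bypass path
    println(wouldBypassMergeSort(conf, 201))  // false: the new test value exercises the merge path
  }
}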
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index cff7a01269..d7c811ca89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -41,11 +41,21 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode
/** We must copy rows when sort based shuffle is on */
protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+ private val bypassMergeThreshold =
+ child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+
override def execute() = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
- val rdd = if (sortBasedShuffleOn) {
+ // This is a workaround for SPARK-4479. When:
+ // 1. sort based shuffle is on, and
+ // 2. the partition number is under the merge threshold, and
+ // 3. no ordering is required,
+ // we can avoid the defensive copies to improve performance. In the long run, we probably
+ // want to include information in shuffle dependencies to indicate whether elements in the
+ // source RDD should be copied.
+ val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
child.execute().mapPartitions { iter =>
val hashExpressions = newMutableProjection(expressions, child.output)()
iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -82,6 +92,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode
shuffled.map(_._1)
case SinglePartition =>
+ // SPARK-4479: Can't turn off the defensive copy as we do for `HashPartitioning`, since
+ // operators like `TakeOrdered` may require an ordering within the partition, and currently
+ // `SinglePartition` doesn't include ordering information.
+ // TODO Add `SingleOrderedPartition` for operators like `TakeOrdered`
val rdd = if (sortBasedShuffleOn) {
child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
} else {
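
Note: the comments above are easier to follow with the copy semantics in mind: SQL operators may reuse a single mutable row, so any consumer that buffers rows (as the sort path does above the bypass threshold, or when an ordering is needed) must copy them first, whereas a consumer that writes each row out immediately can pass them through untouched. A minimal sketch of that hazard, with illustrative names only (MutableRecord stands in for a reused mutable row):

object DefensiveCopySketch {
  final class MutableRecord(var value: Int) {
    def copy(): MutableRecord = new MutableRecord(value)
    override def toString: String = value.toString
  }

  // Simulates an operator that reuses a single mutable record for every element.
  def producer(n: Int): Iterator[MutableRecord] = {
    val reused = new MutableRecord(0)
    Iterator.tabulate(n) { i => reused.value = i; reused }
  }

  def main(args: Array[String]): Unit = {
    // Buffering without copying: every buffered element aliases the same object.
    val buffered = producer(3).toArray
    println(buffered.mkString(" "))          // prints "2 2 2" -- values corrupted

    // Buffering with defensive copies: values are preserved.
    val copied = producer(3).map(_.copy()).toArray
    println(copied.mkString(" "))            // prints "0 1 2"

    // Streaming consumption (like the bypass path): no buffering, no copy needed.
    producer(3).foreach(r => print(s"$r "))  // prints "0 1 2 "
    println()
  }
}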