Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala      | 22 ++++++++++------------
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala | 10 ++++++++++
2 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index e32ad9c036..7c9dc8e5f8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.hash
import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.util.collection.ExternalSorter

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
@@ -35,8 +36,8 @@ private[spark] class HashShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
- val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
- Serializer.getSerializer(dep.serializer))
+ val ser = Serializer.getSerializer(dep.serializer)
+ val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
@@ -54,16 +55,13 @@ private[spark] class HashShuffleReader[K, C](
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
- // Define a Comparator for the whole record based on the key Ordering.
- val cmp = new Ordering[Product2[K, C]] {
- override def compare(o1: Product2[K, C], o2: Product2[K, C]): Int = {
- keyOrd.compare(o1._1, o2._1)
- }
- }
- val sortBuffer: Array[Product2[K, C]] = aggregatedIter.toArray
- // TODO: do external sort.
- scala.util.Sorting.quickSort(sortBuffer)(cmp)
- sortBuffer.iterator
+ // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
+ // the ExternalSorter won't spill to disk.
+ val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
+ sorter.write(aggregatedIter)
+ context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+ sorter.iterator
case None =>
aggregatedIter
}
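Note: the hunk above replaces the in-memory quickSort over a fully materialized array (the removed "TODO: do external sort." path) with an ExternalSorter, which can spill sorted runs to disk and merge them back on iteration, crediting any spillage to the task's memoryBytesSpilled and diskBytesSpilled metrics. A minimal standalone sketch of the user-facing path this change serves, assuming a Spark 1.x artifact on the classpath and a local master; the object name and app name are illustrative, not part of this patch:

import org.apache.spark.{SparkConf, SparkContext}

object SortByKeySketch {
  def main(args: Array[String]): Unit = {
    // spark.shuffle.spill defaults to true; set explicitly for clarity. With
    // "false", the ExternalSorter keeps everything in memory and never spills.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("sortByKey-sketch")
      .set("spark.shuffle.spill", "true")
    val sc = new SparkContext(conf)
    try {
      // Same shape as the new test data below: keys 0..24999, four values per key.
      val sorted = sc.parallelize(0 until 100000).map(i => (i / 4, i)).sortByKey().collect().toSeq
      assert(sorted == (0 until 100000).map(i => (i / 4, i)).toSeq)
    } finally {
      sc.stop()
    }
  }
}

Under this patch, sortByKey() on a reduce partition larger than the available shuffle memory spills to disk instead of growing an unbounded in-memory array.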
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index ddb5df4036..65a71e5a83 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -190,6 +190,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
}
}
+
+ // sortByKey - should spill ~17 times
+ val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
+ val resultE = rddE.sortByKey().collect().toSeq
+ assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
}

test("spilling in local cluster with many reduce tasks") {
@@ -256,6 +261,11 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext {
fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
}
}
+
+ // sortByKey - should spill ~8 times per executor
+ val rddE = sc.parallelize(0 until 100000).map(i => (i/4, i))
+ val resultE = rddE.sortByKey().collect().toSeq
+ assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
}

test("cleanup of intermediate files in sorter") {