aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@databricks.com>2014-07-30 18:07:59 -0700
committerReynold Xin <rxin@apache.org>2014-07-30 18:07:59 -0700
commite966284409f9355e1169960e73a2215617c8cb22 (patch)
tree2e2ad582ff8fa55a8d1cc747cf8a833c3de77dff
parentda501766834453c9ac7095c7e8c930151f87cf11 (diff)
downloadspark-e966284409f9355e1169960e73a2215617c8cb22.tar.gz
spark-e966284409f9355e1169960e73a2215617c8cb22.tar.bz2
spark-e966284409f9355e1169960e73a2215617c8cb22.zip
SPARK-2045 Sort-based shuffle
This adds a new ShuffleManager based on sorting, as described in https://issues.apache.org/jira/browse/SPARK-2045. The bulk of the code is in an ExternalSorter class that is similar to ExternalAppendOnlyMap, but sorts key-value pairs by partition ID and can be used to create a single sorted file with a map task's output. (Longer-term I think this can take on the remaining functionality in ExternalAppendOnlyMap and replace it so we don't have code duplication.) The main TODOs still left are: - [x] enabling ExternalSorter to merge across spilled files - [x] with an Ordering - [x] without an Ordering, using the keys' hash codes - [x] adding more tests (e.g. a version of our shuffle suite that runs on this) - [x] rebasing on top of the size-tracking refactoring in #1165 when that is merged - [x] disabling spilling if spark.shuffle.spill is set to false Despite this though, this seems to work pretty well (running successfully in cases where the hash shuffle would OOM, such as 1000 reduce tasks on executors with only 1G memory), and it seems to be comparable in speed or faster than hash-based shuffle (it will create much fewer files for the OS to keep track of). So I'm posting it to get some early feedback. After these TODOs are done, I'd also like to enable ExternalSorter to sort data within each partition by a key as well, which will allow us to use it to implement external spilling in reduce tasks in `sortByKey`. Author: Matei Zaharia <matei@databricks.com> Closes #1499 from mateiz/sort-based-shuffle and squashes the following commits: bd841f9 [Matei Zaharia] Various review comments d1c137fd [Matei Zaharia] Various review comments a611159 [Matei Zaharia] Compile fixes due to rebase 62c56c8 [Matei Zaharia] Fix ShuffledRDD sometimes not returning Tuple2s. f617432 [Matei Zaharia] Fix a failing test (seems to be due to change in SizeTracker logic) 9464d5f [Matei Zaharia] Simplify code and fix conflicts after latest rebase 0174149 [Matei Zaharia] Add cleanup behavior and cleanup tests for sort-based shuffle eb4ee0d [Matei Zaharia] Remove customizable element type in ShuffledRDD fa2e8db [Matei Zaharia] Allow nextBatchStream to be called after we're done looking at all streams a34b352 [Matei Zaharia] Fix tracking of indices within a partition in SpillReader, and add test 03e1006 [Matei Zaharia] Add a SortShuffleSuite that runs ShuffleSuite with sort-based shuffle 3c7ff1f [Matei Zaharia] Obey the spark.shuffle.spill setting in ExternalSorter ad65fbd [Matei Zaharia] Rebase on top of Aaron's Sorter change, and use Sorter in our buffer 44d2a93 [Matei Zaharia] Use estimateSize instead of atGrowThreshold to test collection sizes 5686f71 [Matei Zaharia] Optimize merging phase for in-memory only data: 5461cbb [Matei Zaharia] Review comments and more tests (e.g. tests with 1 element per partition) e9ad356 [Matei Zaharia] Update ContextCleanerSuite to make sure shuffle cleanup tests use hash shuffle (since they were written for it) c72362a [Matei Zaharia] Added bug fix and test for when iterators are empty de1fb40 [Matei Zaharia] Make trait SizeTrackingCollection private[spark] 4988d16 [Matei Zaharia] tweak c1b7572 [Matei Zaharia] Small optimization ba7db7f [Matei Zaharia] Handle null keys in hash-based comparator, and add tests for collisions ef4e397 [Matei Zaharia] Support for partial aggregation even without an Ordering 4b7a5ce [Matei Zaharia] More tests, and ability to sort data if a total ordering is given e1f84be [Matei Zaharia] Fix disk block manager test 5a40a1c [Matei Zaharia] More tests 614f1b4 [Matei Zaharia] Add spill metrics to map tasks cc52caf [Matei Zaharia] Add more error handling and tests for error cases bbf359d [Matei Zaharia] More work 3a56341 [Matei Zaharia] More partial work towards sort-based shuffle 7a0895d [Matei Zaharia] Some more partial work towards sort-based shuffle b615476 [Matei Zaharia] Scaffolding for sort-based shuffle
-rw-r--r--core/src/main/scala/org/apache/spark/Aggregator.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala80
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala165
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockId.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala38
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala36
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala662
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala86
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala34
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala186
-rw-r--r--core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/ShuffleSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/SortShuffleSuite.scala34
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala25
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala566
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala25
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala2
-rw-r--r--graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala2
-rw-r--r--project/SparkBuild.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala2
35 files changed, 1969 insertions, 159 deletions
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ff0ca11749..79c9c451d2 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -56,18 +56,23 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
- // TODO: Make this non optional in a future release
- Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
- Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+ // Update task metrics if context is not null
+ // TODO: Make context non optional in a future release
+ Option(context).foreach { c =>
+ c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+ c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ }
combiners.iterator
}
}
@deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0")
- def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] =
+ def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] =
combineCombinersByKey(iter, null)
- def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
+ def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
+ : Iterator[(K, C)] =
+ {
if (!externalSorting) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
@@ -85,9 +90,12 @@ case class Aggregator[K, V, C] (
val pair = iter.next()
combiners.insert(pair._1, pair._2)
}
- // TODO: Make this non optional in a future release
- Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
- Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+ // Update task metrics if context is not null
+ // TODO: Make context non-optional in a future release
+ Option(context).foreach { c =>
+ c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+ c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ }
combiners.iterator
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index fb4c86716b..b25f081761 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -289,7 +289,7 @@ class SparkContext(config: SparkConf) extends Logging {
value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
executorEnvs(envKey) = value
}
- Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
+ Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
executorEnvs("SPARK_PREPEND_CLASSES") = v
}
// The Mesos scheduler backend relies on this environment variable to set executor memory.
@@ -1203,10 +1203,10 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
- * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
- * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
+ * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+ * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
* if not.
- *
+ *
* @param f the closure to clean
* @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
* @throws <tt>SparkException<tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 31bf8dced2..47708cb2e7 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -122,7 +122,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
*/
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 6388ef82cc..fabb882cdd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -17,10 +17,11 @@
package org.apache.spark.rdd
+import scala.language.existentials
+
import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
-import scala.language.existentials
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled
+ context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index d85f962783..e98bad2026 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark.{Logging, RangePartitioner}
+import org.apache.spark.annotation.DeveloperApi
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner}
*/
class OrderedRDDFunctions[K : Ordering : ClassTag,
V: ClassTag,
- P <: Product2[K, V] : ClassTag](
+ P <: Product2[K, V] : ClassTag] @DeveloperApi() (
self: RDD[P])
- extends Logging with Serializable {
-
+ extends Logging with Serializable
+{
private val ordering = implicitly[Ordering[K]]
/**
@@ -55,9 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
- def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+ // TODO: this currently doesn't work on P other than Tuple2!
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size)
+ : RDD[(K, V)] =
+ {
val part = new RangePartitioner(numPartitions, self, ascending)
- new ShuffledRDD[K, V, V, P](self, part)
+ new ShuffledRDD[K, V, V](self, part)
.setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 1af4e5f0b6..93af50c0a9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -90,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
} else {
- new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
+ new ShuffledRDD[K, V, C](self, partitioner)
.setSerializer(serializer)
.setAggregator(aggregator)
.setMapSideCombine(mapSideCombine)
@@ -425,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
if (self.partitioner == Some(partitioner)) {
self
} else {
- new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
+ new ShuffledRDD[K, V, V](self, partitioner)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 726b3f2bbe..74ac97091f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag](
val distributePartition = (index: Int, items: Iterator[T]) => {
var position = (new Random(index)).nextInt(numPartitions)
items.map { t =>
- // Note that the hash code of the key will just be the key itself. The HashPartitioner
+ // Note that the hash code of the key will just be the key itself. The HashPartitioner
// will mod it with the number of total partitions.
position = position + 1
(position, t)
@@ -341,7 +341,7 @@ abstract class RDD[T: ClassTag](
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
- new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
+ new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)),
numPartitions).values
} else {
@@ -352,8 +352,8 @@ abstract class RDD[T: ClassTag](
/**
* Return a sampled subset of this RDD.
*/
- def sample(withReplacement: Boolean,
- fraction: Double,
+ def sample(withReplacement: Boolean,
+ fraction: Double,
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Negative fraction value: " + fraction)
if (withReplacement) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index bf02f68d0d..d9fe684725 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -37,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @tparam V the value class.
* @tparam C the combiner class.
*/
+// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs
@DeveloperApi
-class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
+class ShuffledRDD[K, V, C](
@transient var prev: RDD[_ <: Product2[K, V]],
part: Partitioner)
- extends RDD[P](prev.context, Nil) {
+ extends RDD[(K, C)](prev.context, Nil) {
private var serializer: Option[Serializer] = None
@@ -52,25 +53,25 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
private var mapSideCombine: Boolean = false
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
- def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
+ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
this.serializer = Option(serializer)
this
}
/** Set key ordering for RDD's shuffle. */
- def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
+ def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = {
this.keyOrdering = Option(keyOrdering)
this
}
/** Set aggregator for RDD's shuffle. */
- def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
+ def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = {
this.aggregator = Option(aggregator)
this
}
/** Set mapSideCombine flag for RDD's shuffle. */
- def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
+ def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = {
this.mapSideCombine = mapSideCombine
this
}
@@ -85,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
}
- override def compute(split: Partition, context: TaskContext): Iterator[P] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
- .asInstanceOf[Iterator[P]]
+ .asInstanceOf[Iterator[(K, C)]]
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
index 5b0940ecce..df98d18fa8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -24,7 +24,7 @@ import org.apache.spark.shuffle._
* A ShuffleManager using hashing, that creates one output file per reduce partition on each
* mapper (possibly reusing these across waves of tasks).
*/
-class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
/* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
override def registerShuffle[K, V, C](
shuffleId: Int,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index c8059496a1..e32ad9c036 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-class HashShuffleReader[K, C](
+private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
@@ -47,7 +47,8 @@ class HashShuffleReader[K, C](
} else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
- iter
+ // Convert the Product2s to pairs since this is what downstream RDDs currently expect
+ iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
}
// Sort the output if there is a sort ordering defined.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 9b78228519..1923f7c71a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
-class HashShuffleWriter[K, V](
+private[spark] class HashShuffleWriter[K, V](
handle: BaseShuffleHandle[K, V, _],
mapId: Int,
context: TaskContext)
@@ -33,6 +33,10 @@ class HashShuffleWriter[K, V](
private val dep = handle.dependency
private val numOutputSplits = dep.partitioner.numPartitions
private val metrics = context.taskMetrics
+
+ // Are we in the process of stopping? Because map tasks can call stop() with success = true
+ // and then call stop() with success = false if they get an exception, we want to make sure
+ // we don't try deleting files, etc twice.
private var stopping = false
private val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
new file mode 100644
index 0000000000..6dcca47ea7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{DataInputStream, FileInputStream}
+
+import org.apache.spark.shuffle._
+import org.apache.spark.{TaskContext, ShuffleDependency}
+import org.apache.spark.shuffle.hash.HashShuffleReader
+import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId}
+
+private[spark] class SortShuffleManager extends ShuffleManager {
+ /**
+ * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C] = {
+ // We currently use the same block store shuffle fetcher as the hash-based shuffle.
+ new HashShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+ : ShuffleWriter[K, V] = {
+ new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {}
+
+ /** Get the location of a block in a map output file. Uses the index file we create for it. */
+ def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
+ // The block is actually going to be a range of a single map output file for this map, so
+ // figure out the ID of the consolidated file, then the offset within that from our index
+ val consolidatedId = blockId.copy(reduceId = 0)
+ val indexFile = diskManager.getFile(consolidatedId.name + ".index")
+ val in = new DataInputStream(new FileInputStream(indexFile))
+ try {
+ in.skip(blockId.reduceId * 8)
+ val offset = in.readLong()
+ val nextOffset = in.readLong()
+ new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset)
+ } finally {
+ in.close()
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
new file mode 100644
index 0000000000..42fcd07fa1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream}
+
+import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.collection.ExternalSorter
+
+private[spark] class SortShuffleWriter[K, V, C](
+ handle: BaseShuffleHandle[K, V, C],
+ mapId: Int,
+ context: TaskContext)
+ extends ShuffleWriter[K, V] with Logging {
+
+ private val dep = handle.dependency
+ private val numPartitions = dep.partitioner.numPartitions
+
+ private val blockManager = SparkEnv.get.blockManager
+ private val ser = Serializer.getSerializer(dep.serializer.orNull)
+
+ private val conf = SparkEnv.get.conf
+ private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+ private var sorter: ExternalSorter[K, V, _] = null
+ private var outputFile: File = null
+
+ // Are we in the process of stopping? Because map tasks can call stop() with success = true
+ // and then call stop() with success = false if they get an exception, we want to make sure
+ // we don't try deleting files, etc twice.
+ private var stopping = false
+
+ private var mapStatus: MapStatus = null
+
+ /** Write a bunch of records to this task's output */
+ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ // Get an iterator with the elements for each partition ID
+ val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
+ if (dep.mapSideCombine) {
+ if (!dep.aggregator.isDefined) {
+ throw new IllegalStateException("Aggregator is empty for map-side combine")
+ }
+ sorter = new ExternalSorter[K, V, C](
+ dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+ sorter.write(records)
+ sorter.partitionedIterator
+ } else {
+ // In this case we pass neither an aggregator nor an ordering to the sorter, because we
+ // don't care whether the keys get sorted in each partition; that will be done on the
+ // reduce side if the operation being run is sortByKey.
+ sorter = new ExternalSorter[K, V, V](
+ None, Some(dep.partitioner), None, dep.serializer)
+ sorter.write(records)
+ sorter.partitionedIterator
+ }
+ }
+
+ // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
+ // serve different ranges of this file using an index file that we create at the end.
+ val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
+ outputFile = blockManager.diskBlockManager.getFile(blockId)
+
+ // Track location of each range in the output file
+ val offsets = new Array[Long](numPartitions + 1)
+ val lengths = new Array[Long](numPartitions)
+
+ // Statistics
+ var totalBytes = 0L
+ var totalTime = 0L
+
+ for ((id, elements) <- partitions) {
+ if (elements.hasNext) {
+ val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+ for (elem <- elements) {
+ writer.write(elem)
+ }
+ writer.commit()
+ writer.close()
+ val segment = writer.fileSegment()
+ offsets(id + 1) = segment.offset + segment.length
+ lengths(id) = segment.length
+ totalTime += writer.timeWriting()
+ totalBytes += segment.length
+ } else {
+ // The partition is empty; don't create a new writer to avoid writing headers, etc
+ offsets(id + 1) = offsets(id)
+ }
+ }
+
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
+ context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
+ context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+
+ // Write an index file with the offsets of each block, plus a final offset at the end for the
+ // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
+ // out where each block begins and ends.
+
+ val diskBlockManager = blockManager.diskBlockManager
+ val indexFile = diskBlockManager.getFile(blockId.name + ".index")
+ val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+ try {
+ var i = 0
+ while (i < numPartitions + 1) {
+ out.writeLong(offsets(i))
+ i += 1
+ }
+ } finally {
+ out.close()
+ }
+
+ // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
+ blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
+
+ mapStatus = new MapStatus(blockManager.blockManagerId,
+ lengths.map(MapOutputTracker.compressSize))
+ }
+
+ /** Close this writer, passing along whether the map completed */
+ override def stop(success: Boolean): Option[MapStatus] = {
+ try {
+ if (stopping) {
+ return None
+ }
+ stopping = true
+ if (success) {
+ return Option(mapStatus)
+ } else {
+ // The map task failed, so delete our output file if we created one
+ if (outputFile != null) {
+ outputFile.delete()
+ }
+ return None
+ }
+ } finally {
+ // Clean up our sorter, which may have its own intermediate files
+ if (sorter != null) {
+ sorter.stop()
+ sorter = null
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 42ec181b00..c1756ac905 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -54,12 +54,16 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
}
@DeveloperApi
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
- extends BlockId {
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
@DeveloperApi
+case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+ def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
+}
+
+@DeveloperApi
case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}
@@ -88,6 +92,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -99,6 +104,8 @@ object BlockId {
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
+ ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId, field) =>
BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 2e7ed7538e..4d66ccea21 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -21,10 +21,11 @@ import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random, UUID}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkEnv, Logging}
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.sort.SortShuffleManager
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -34,11 +35,13 @@ import org.apache.spark.util.Utils
*
* @param rootDirs The directories to use for storing block files. Data will be hashed among these.
*/
-private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String)
extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64)
+
+ private val subDirsPerLocalDir =
+ shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
@@ -54,13 +57,19 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
addShutdownHook()
/**
- * Returns the physical file segment in which the given BlockId is located.
- * If the BlockId has been mapped to a specific FileSegment, that will be returned.
- * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ * Returns the physical file segment in which the given BlockId is located. If the BlockId has
+ * been mapped to a specific FileSegment by the shuffle layer, that will be returned.
+ * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId.
*/
def getBlockLocation(blockId: BlockId): FileSegment = {
- if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
- shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+ val env = SparkEnv.get // NOTE: can be null in unit tests
+ if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) {
+ // For sort-based shuffle, let it figure out its blocks
+ val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager]
+ sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this)
+ } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) {
+ // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this
+ shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
} else {
val file = getFile(blockId.name)
new FileSegment(file, 0, file.length())
@@ -99,13 +108,18 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
getBlockLocation(blockId).file.exists()
}
- /** List all the blocks currently stored on disk by the disk manager. */
- def getAllBlocks(): Seq[BlockId] = {
+ /** List all the files currently stored on disk by the disk manager. */
+ def getAllFiles(): Seq[File] = {
// Get all the files inside the array of array of directories
subDirs.flatten.filter(_ != null).flatMap { dir =>
- val files = dir.list()
+ val files = dir.listFiles()
if (files != null) files else Seq.empty
- }.map(BlockId.apply)
+ }
+ }
+
+ /** List all the blocks currently stored on disk by the disk manager. */
+ def getAllBlocks(): Seq[BlockId] = {
+ getAllFiles().map(f => BlockId(f.getName))
}
/** Produces a unique block id and File suitable for intermediate results. */
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 35910e552f..7beb55c411 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -28,6 +28,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.shuffle.sort.SortShuffleManager
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
@@ -58,6 +59,7 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
* files within a ShuffleFileGroups associated with the block's reducer.
*/
+// TODO: Factor this into a separate class for each ShuffleManager implementation
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
def conf = blockManager.conf
@@ -67,6 +69,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
val consolidateShuffleFiles =
conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ // Are we using sort-based shuffle?
+ val sortBasedShuffle =
+ conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName
+
private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
/**
@@ -91,6 +97,20 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
+ /**
+ * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle
+ * because it just writes a single file by itself.
+ */
+ def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = {
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
+ val shuffleState = shuffleStates(shuffleId)
+ shuffleState.completedMapTasks.add(mapId)
+ }
+
+ /**
+ * Get a ShuffleWriterGroup for the given map task, which will register it as complete
+ * when the writers are closed successfully
+ */
def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleWriterGroup {
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
@@ -182,7 +202,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
shuffleStates.get(shuffleId) match {
case Some(state) =>
- if (consolidateShuffleFiles) {
+ if (sortBasedShuffle) {
+ // There's a single block ID for each map, plus an index file for it
+ for (mapId <- state.completedMapTasks) {
+ val blockId = new ShuffleBlockId(shuffleId, mapId, 0)
+ blockManager.diskBlockManager.getFile(blockId).delete()
+ blockManager.diskBlockManager.getFile(blockId.name + ".index").delete()
+ }
+ } else if (consolidateShuffleFiles) {
for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
file.delete()
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 6f263c39d1..b34512ef9e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -79,12 +79,16 @@ class ExternalAppendOnlyMap[K, V, C](
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
- // Number of pairs in the in-memory map
- private var numPairsInMemory = 0L
+ // Number of pairs inserted since last spill; note that we count them even if a value is merged
+ // with a previous key in case we're doing something like groupBy where the result grows
+ private var elementsRead = 0L
// Number of in-memory pairs inserted before tracking the map's shuffle memory usage
private val trackMemoryThreshold = 1000
+ // How much of the shared memory pool this collection has claimed
+ private var myMemoryThreshold = 0L
+
/**
* Size of object batches when reading/writing from serializers.
*
@@ -106,7 +110,6 @@ class ExternalAppendOnlyMap[K, V, C](
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()
- private val threadId = Thread.currentThread().getId
/**
* Insert the given key and value into the map.
@@ -134,31 +137,35 @@ class ExternalAppendOnlyMap[K, V, C](
while (entries.hasNext) {
curEntry = entries.next()
- if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) {
- val mapSize = currentMap.estimateSize()
+ if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
+ currentMap.estimateSize() >= myMemoryThreshold)
+ {
+ val currentSize = currentMap.estimateSize()
var shouldSpill = false
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
// Atomically check whether there is sufficient memory in the global pool for
// this map to grow and, if possible, allocate the required amount
shuffleMemoryMap.synchronized {
+ val threadId = Thread.currentThread().getId
val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
val availableMemory = maxMemoryThreshold -
(shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
- // Assume map growth factor is 2x
- shouldSpill = availableMemory < mapSize * 2
+ // Try to allocate at least 2x more memory, otherwise spill
+ shouldSpill = availableMemory < currentSize * 2
if (!shouldSpill) {
- shuffleMemoryMap(threadId) = mapSize * 2
+ shuffleMemoryMap(threadId) = currentSize * 2
+ myMemoryThreshold = currentSize * 2
}
}
// Do not synchronize spills
if (shouldSpill) {
- spill(mapSize)
+ spill(currentSize)
}
}
currentMap.changeValue(curEntry._1, update)
- numPairsInMemory += 1
+ elementsRead += 1
}
}
@@ -178,9 +185,10 @@ class ExternalAppendOnlyMap[K, V, C](
/**
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
- private def spill(mapSize: Long) {
+ private def spill(mapSize: Long): Unit = {
spillCount += 1
- logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
+ val threadId = Thread.currentThread().getId
+ logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
.format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
@@ -227,7 +235,9 @@ class ExternalAppendOnlyMap[K, V, C](
shuffleMemoryMap.synchronized {
shuffleMemoryMap(Thread.currentThread().getId) = 0
}
- numPairsInMemory = 0
+ myMemoryThreshold = 0
+
+ elementsRead = 0
_memoryBytesSpilled += mapSize
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
new file mode 100644
index 0000000000..54c3310744
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -0,0 +1,662 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io._
+import java.util.Comparator
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.BlockId
+
+/**
+ * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
+ * pairs of type (K, C). Uses a Partitioner to first group the keys into partitions, and then
+ * optionally sorts keys within each partition using a custom Comparator. Can output a single
+ * partitioned file with a different byte range for each partition, suitable for shuffle fetches.
+ *
+ * If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
+ *
+ * @param aggregator optional Aggregator with combine functions to use for merging data
+ * @param partitioner optional Partitioner; if given, sort by partition ID and then key
+ * @param ordering optional Ordering to sort keys within each partition; should be a total ordering
+ * @param serializer serializer to use when spilling to disk
+ *
+ * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really
+ * want the output keys to be sorted. In a map task without map-side combine for example, you
+ * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do
+ * want to do combining, having an Ordering is more efficient than not having it.
+ *
+ * At a high level, this class works as follows:
+ *
+ * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
+ * we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers,
+ * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
+ * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
+ *
+ * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
+ * by partition ID and possibly second by key or by hash code of the key, if we want to do
+ * aggregation. For each file, we track how many objects were in each partition in memory, so we
+ * don't have to write out the partition ID for every element.
+ *
+ * - When the user requests an iterator, the spilled files are merged, along with any remaining
+ * in-memory data, using the same sort order defined above (unless both sorting and aggregation
+ * are disabled). If we need to aggregate by key, we either use a total ordering from the
+ * ordering parameter, or read the keys with the same hash code and compare them with each other
+ * for equality to merge values.
+ *
+ * - Users are expected to call stop() at the end to delete all the intermediate files.
+ */
+private[spark] class ExternalSorter[K, V, C](
+ aggregator: Option[Aggregator[K, V, C]] = None,
+ partitioner: Option[Partitioner] = None,
+ ordering: Option[Ordering[K]] = None,
+ serializer: Option[Serializer] = None) extends Logging {
+
+ private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
+ private val shouldPartition = numPartitions > 1
+
+ private val blockManager = SparkEnv.get.blockManager
+ private val diskBlockManager = blockManager.diskBlockManager
+ private val ser = Serializer.getSerializer(serializer)
+ private val serInstance = ser.newInstance()
+
+ private val conf = SparkEnv.get.conf
+ private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
+ private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+ // Size of object batches when reading/writing from serializers.
+ //
+ // Objects are written in batches, with each batch using its own serialization stream. This
+ // cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+ //
+ // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
+ // grow internal data structures by growing + copying every time the number of objects doubles.
+ private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
+
+ private def getPartition(key: K): Int = {
+ if (shouldPartition) partitioner.get.getPartition(key) else 0
+ }
+
+ // Data structures to store in-memory objects before we spill. Depending on whether we have an
+ // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
+ // store them in an array buffer.
+ private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+ private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
+
+ // Number of pairs read from input since last spill; note that we count them even if a value is
+ // merged with a previous key in case we're doing something like groupBy where the result grows
+ private var elementsRead = 0L
+
+ // What threshold of elementsRead we start estimating map size at.
+ private val trackMemoryThreshold = 1000
+
+ // Spilling statistics
+ private var spillCount = 0
+ private var _memoryBytesSpilled = 0L
+ private var _diskBytesSpilled = 0L
+
+ // Collective memory threshold shared across all running tasks
+ private val maxMemoryThreshold = {
+ val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
+ val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
+ (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+ }
+
+ // How much of the shared memory pool this collection has claimed
+ private var myMemoryThreshold = 0L
+
+ // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
+ // Can be a partial ordering by hash code if a total ordering is not provided through by the
+ // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
+ // non-equal keys also have this, so we need to do a later pass to find truly equal keys).
+ // Note that we ignore this if no aggregator and no ordering are given.
+ private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
+ override def compare(a: K, b: K): Int = {
+ val h1 = if (a == null) 0 else a.hashCode()
+ val h2 = if (b == null) 0 else b.hashCode()
+ h1 - h2
+ }
+ })
+
+ // A comparator for (Int, K) elements that orders them by partition and then possibly by key
+ private val partitionKeyComparator: Comparator[(Int, K)] = {
+ if (ordering.isDefined || aggregator.isDefined) {
+ // Sort by partition ID then key comparator
+ new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ val partitionDiff = a._1 - b._1
+ if (partitionDiff != 0) {
+ partitionDiff
+ } else {
+ keyComparator.compare(a._2, b._2)
+ }
+ }
+ }
+ } else {
+ // Just sort it by partition ID
+ new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ a._1 - b._1
+ }
+ }
+ }
+ }
+
+ // Information about a spilled file. Includes sizes in bytes of "batches" written by the
+ // serializer as we periodically reset its stream, as well as number of elements in each
+ // partition, used to efficiently keep track of partitions when merging.
+ private[this] case class SpilledFile(
+ file: File,
+ blockId: BlockId,
+ serializerBatchSizes: Array[Long],
+ elementsPerPartition: Array[Long])
+ private val spills = new ArrayBuffer[SpilledFile]
+
+ def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ // TODO: stop combining if we find that the reduction factor isn't high
+ val shouldCombine = aggregator.isDefined
+
+ if (shouldCombine) {
+ // Combine values in-memory first using our AppendOnlyMap
+ val mergeValue = aggregator.get.mergeValue
+ val createCombiner = aggregator.get.createCombiner
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (records.hasNext) {
+ elementsRead += 1
+ kv = records.next()
+ map.changeValue((getPartition(kv._1), kv._1), update)
+ maybeSpill(usingMap = true)
+ }
+ } else {
+ // Stick values into our buffer
+ while (records.hasNext) {
+ elementsRead += 1
+ val kv = records.next()
+ buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+ maybeSpill(usingMap = false)
+ }
+ }
+ }
+
+ /**
+ * Spill the current in-memory collection to disk if needed.
+ *
+ * @param usingMap whether we're using a map or buffer as our current in-memory collection
+ */
+ private def maybeSpill(usingMap: Boolean): Unit = {
+ if (!spillingEnabled) {
+ return
+ }
+
+ val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+
+ // TODO: factor this out of both here and ExternalAppendOnlyMap
+ if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
+ collection.estimateSize() >= myMemoryThreshold)
+ {
+ // TODO: This logic doesn't work if there are two external collections being used in the same
+ // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711]
+
+ val currentSize = collection.estimateSize()
+ var shouldSpill = false
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+
+ // Atomically check whether there is sufficient memory in the global pool for
+ // us to double our threshold
+ shuffleMemoryMap.synchronized {
+ val threadId = Thread.currentThread().getId
+ val previouslyClaimedMemory = shuffleMemoryMap.get(threadId)
+ val availableMemory = maxMemoryThreshold -
+ (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L))
+
+ // Try to allocate at least 2x more memory, otherwise spill
+ shouldSpill = availableMemory < currentSize * 2
+ if (!shouldSpill) {
+ shuffleMemoryMap(threadId) = currentSize * 2
+ myMemoryThreshold = currentSize * 2
+ }
+ }
+ // Do not hold lock during spills
+ if (shouldSpill) {
+ spill(currentSize, usingMap)
+ }
+ }
+ }
+
+ /**
+ * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
+ *
+ * @param usingMap whether we're using a map or buffer as our current in-memory collection
+ */
+ private def spill(memorySize: Long, usingMap: Boolean): Unit = {
+ val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+ val memorySize = collection.estimateSize()
+
+ spillCount += 1
+ val threadId = Thread.currentThread().getId
+ logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
+ .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+ val (blockId, file) = diskBlockManager.createTempBlock()
+ var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+ var objectsWritten = 0 // Objects written since the last flush
+
+ // List of batch sizes (bytes) in the order they are written to disk
+ val batchSizes = new ArrayBuffer[Long]
+
+ // How many elements we have in each partition
+ val elementsPerPartition = new Array[Long](numPartitions)
+
+ // Flush the disk writer's contents to disk, and update relevant variables
+ def flush() = {
+ writer.commit()
+ val bytesWritten = writer.bytesWritten
+ batchSizes.append(bytesWritten)
+ _diskBytesSpilled += bytesWritten
+ objectsWritten = 0
+ }
+
+ try {
+ val it = collection.destructiveSortedIterator(partitionKeyComparator)
+ while (it.hasNext) {
+ val elem = it.next()
+ val partitionId = elem._1._1
+ val key = elem._1._2
+ val value = elem._2
+ writer.write(key)
+ writer.write(value)
+ elementsPerPartition(partitionId) += 1
+ objectsWritten += 1
+
+ if (objectsWritten == serializerBatchSize) {
+ flush()
+ writer.close()
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+ }
+ }
+ if (objectsWritten > 0) {
+ flush()
+ }
+ writer.close()
+ } catch {
+ case e: Exception =>
+ writer.close()
+ file.delete()
+ throw e
+ }
+
+ if (usingMap) {
+ map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+ } else {
+ buffer = new SizeTrackingPairBuffer[(Int, K), C]
+ }
+
+ // Reset the amount of shuffle memory used by this map in the global pool
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap(Thread.currentThread().getId) = 0
+ }
+ myMemoryThreshold = 0
+
+ spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+ _memoryBytesSpilled += memorySize
+ }
+
+ /**
+ * Merge a sequence of sorted files, giving an iterator over partitions and then over elements
+ * inside each partition. This can be used to either write out a new file or return data to
+ * the user.
+ *
+ * Returns an iterator over all the data written to this object, grouped by partition. For each
+ * partition we then have an iterator over its contents, and these are expected to be accessed
+ * in order (you can't "skip ahead" to one partition without reading the previous one).
+ * Guaranteed to return a key-value pair for each partition, in order of partition ID.
+ */
+ private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
+ : Iterator[(Int, Iterator[Product2[K, C]])] = {
+ val readers = spills.map(new SpillReader(_))
+ val inMemBuffered = inMemory.buffered
+ (0 until numPartitions).iterator.map { p =>
+ val inMemIterator = new IteratorForPartition(p, inMemBuffered)
+ val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
+ if (aggregator.isDefined) {
+ // Perform partial aggregation across partitions
+ (p, mergeWithAggregation(
+ iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
+ } else if (ordering.isDefined) {
+ // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
+ // sort the elements without trying to merge them
+ (p, mergeSort(iterators, ordering.get))
+ } else {
+ (p, iterators.iterator.flatten)
+ }
+ }
+ }
+
+ /**
+ * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
+ */
+ private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
+ : Iterator[Product2[K, C]] =
+ {
+ val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
+ type Iter = BufferedIterator[Product2[K, C]]
+ val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
+ // Use the reverse of comparator.compare because PriorityQueue dequeues the max
+ override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+ })
+ heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true
+ new Iterator[Product2[K, C]] {
+ override def hasNext: Boolean = !heap.isEmpty
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val firstBuf = heap.dequeue()
+ val firstPair = firstBuf.next()
+ if (firstBuf.hasNext) {
+ heap.enqueue(firstBuf)
+ }
+ firstPair
+ }
+ }
+ }
+
+ /**
+ * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
+ * iterator is sorted by key with a given comparator. If the comparator is not a total ordering
+ * (e.g. when we sort objects by hash code and different keys may compare as equal although
+ * they're not), we still merge them by doing equality tests for all keys that compare as equal.
+ */
+ private def mergeWithAggregation(
+ iterators: Seq[Iterator[Product2[K, C]]],
+ mergeCombiners: (C, C) => C,
+ comparator: Comparator[K],
+ totalOrder: Boolean)
+ : Iterator[Product2[K, C]] =
+ {
+ if (!totalOrder) {
+ // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
+ // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
+ // need to read all keys considered equal by the ordering at once and compare them.
+ new Iterator[Iterator[Product2[K, C]]] {
+ val sorted = mergeSort(iterators, comparator).buffered
+
+ // Buffers reused across elements to decrease memory allocation
+ val keys = new ArrayBuffer[K]
+ val combiners = new ArrayBuffer[C]
+
+ override def hasNext: Boolean = sorted.hasNext
+
+ override def next(): Iterator[Product2[K, C]] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ keys.clear()
+ combiners.clear()
+ val firstPair = sorted.next()
+ keys += firstPair._1
+ combiners += firstPair._2
+ val key = firstPair._1
+ while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
+ val pair = sorted.next()
+ var i = 0
+ var foundKey = false
+ while (i < keys.size && !foundKey) {
+ if (keys(i) == pair._1) {
+ combiners(i) = mergeCombiners(combiners(i), pair._2)
+ foundKey = true
+ }
+ i += 1
+ }
+ if (!foundKey) {
+ keys += pair._1
+ combiners += pair._2
+ }
+ }
+
+ // Note that we return an iterator of elements since we could've had many keys marked
+ // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
+ keys.iterator.zip(combiners.iterator)
+ }
+ }.flatMap(i => i)
+ } else {
+ // We have a total ordering, so the objects with the same key are sequential.
+ new Iterator[Product2[K, C]] {
+ val sorted = mergeSort(iterators, comparator).buffered
+
+ override def hasNext: Boolean = sorted.hasNext
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val elem = sorted.next()
+ val k = elem._1
+ var c = elem._2
+ while (sorted.hasNext && sorted.head._1 == k) {
+ c = mergeCombiners(c, sorted.head._2)
+ }
+ (k, c)
+ }
+ }
+ }
+ }
+
+ /**
+ * An internal class for reading a spilled file partition by partition. Expects all the
+ * partitions to be requested in order.
+ */
+ private[this] class SpillReader(spill: SpilledFile) {
+ val fileStream = new FileInputStream(spill.file)
+ val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+
+ // Track which partition and which batch stream we're in. These will be the indices of
+ // the next element we will read. We'll also store the last partition read so that
+ // readNextPartition() can figure out what partition that was from.
+ var partitionId = 0
+ var indexInPartition = 0L
+ var batchStreamsRead = 0
+ var indexInBatch = 0
+ var lastPartitionId = 0
+
+ skipToNextPartition()
+
+ // An intermediate stream that reads from exactly one batch
+ // This guards against pre-fetching and other arbitrary behavior of higher level streams
+ var batchStream = nextBatchStream()
+ var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+ var deserStream = serInstance.deserializeStream(compressedStream)
+ var nextItem: (K, C) = null
+ var finished = false
+
+ /** Construct a stream that only reads from the next batch */
+ def nextBatchStream(): InputStream = {
+ if (batchStreamsRead < spill.serializerBatchSizes.length) {
+ batchStreamsRead += 1
+ ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
+ } else {
+ // No more batches left; give an empty stream
+ bufferedStream
+ }
+ }
+
+ /**
+ * Update partitionId if we have reached the end of our current partition, possibly skipping
+ * empty partitions on the way.
+ */
+ private def skipToNextPartition() {
+ while (partitionId < numPartitions &&
+ indexInPartition == spill.elementsPerPartition(partitionId)) {
+ partitionId += 1
+ indexInPartition = 0L
+ }
+ }
+
+ /**
+ * Return the next (K, C) pair from the deserialization stream and update partitionId,
+ * indexInPartition, indexInBatch and such to match its location.
+ *
+ * If the current batch is drained, construct a stream for the next batch and read from it.
+ * If no more pairs are left, return null.
+ */
+ private def readNextItem(): (K, C) = {
+ if (finished) {
+ return null
+ }
+ val k = deserStream.readObject().asInstanceOf[K]
+ val c = deserStream.readObject().asInstanceOf[C]
+ lastPartitionId = partitionId
+ // Start reading the next batch if we're done with this one
+ indexInBatch += 1
+ if (indexInBatch == serializerBatchSize) {
+ batchStream = nextBatchStream()
+ compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+ deserStream = serInstance.deserializeStream(compressedStream)
+ indexInBatch = 0
+ }
+ // Update the partition location of the element we're reading
+ indexInPartition += 1
+ skipToNextPartition()
+ // If we've finished reading the last partition, remember that we're done
+ if (partitionId == numPartitions) {
+ finished = true
+ deserStream.close()
+ }
+ (k, c)
+ }
+
+ var nextPartitionToRead = 0
+
+ def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
+ val myPartition = nextPartitionToRead
+ nextPartitionToRead += 1
+
+ override def hasNext: Boolean = {
+ if (nextItem == null) {
+ nextItem = readNextItem()
+ if (nextItem == null) {
+ return false
+ }
+ }
+ assert(lastPartitionId >= myPartition)
+ // Check that we're still in the right partition; note that readNextItem will have returned
+ // null at EOF above so we would've returned false there
+ lastPartitionId == myPartition
+ }
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val item = nextItem
+ nextItem = null
+ item
+ }
+ }
+ }
+
+ /**
+ * Return an iterator over all the data written to this object, grouped by partition and
+ * aggregated by the requested aggregator. For each partition we then have an iterator over its
+ * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
+ * partition without reading the previous one). Guaranteed to return a key-value pair for each
+ * partition, in order of partition ID.
+ *
+ * For now, we just merge all the spilled files in once pass, but this can be modified to
+ * support hierarchical merging.
+ */
+ def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+ val usingMap = aggregator.isDefined
+ val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+ if (spills.isEmpty) {
+ // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
+ // we don't even need to sort by anything other than partition ID
+ if (!ordering.isDefined) {
+ // The user isn't requested sorted keys, so only sort by partition ID, not key
+ val partitionComparator = new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ a._1 - b._1
+ }
+ }
+ groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+ } else {
+ // We do need to sort by both partition ID and key
+ groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
+ }
+ } else {
+ // General case: merge spilled and in-memory data
+ merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
+ }
+ }
+
+ /**
+ * Return an iterator over all the data written to this object, aggregated by our aggregator.
+ */
+ def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
+
+ def stop(): Unit = {
+ spills.foreach(s => s.file.delete())
+ spills.clear()
+ }
+
+ def memoryBytesSpilled: Long = _memoryBytesSpilled
+
+ def diskBytesSpilled: Long = _diskBytesSpilled
+
+ /**
+ * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
+ * group together the pairs for each partition into a sub-iterator.
+ *
+ * @param data an iterator of elements, assumed to already be sorted by partition ID
+ */
+ private def groupByPartition(data: Iterator[((Int, K), C)])
+ : Iterator[(Int, Iterator[Product2[K, C]])] =
+ {
+ val buffered = data.buffered
+ (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
+ }
+
+ /**
+ * An iterator that reads only the elements for a given partition ID from an underlying buffered
+ * stream, assuming this partition is the next one to be read. Used to make it easier to return
+ * partitioned iterators from our in-memory collection.
+ */
+ private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)])
+ extends Iterator[Product2[K, C]]
+ {
+ override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val elem = data.next()
+ (elem._1._2, elem._2)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
index de61e1d17f..eb4de41386 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -20,8 +20,9 @@ package org.apache.spark.util.collection
/**
* An append-only map that keeps track of its estimated size in bytes.
*/
-private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] with SizeTracker {
-
+private[spark] class SizeTrackingAppendOnlyMap[K, V]
+ extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V]
+{
override def update(key: K, value: V): Unit = {
super.update(key, value)
super.afterUpdate()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
new file mode 100644
index 0000000000..9e9c16c5a2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes.
+ */
+private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
+ extends SizeTracker with SizeTrackingPairCollection[K, V]
+{
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+
+ // Basic growable array data structure. We use a single array of AnyRef to hold both the keys
+ // and the values, so that we can sort them efficiently with KVArraySortDataFormat.
+ private var capacity = initialCapacity
+ private var curSize = 0
+ private var data = new Array[AnyRef](2 * initialCapacity)
+
+ /** Add an element into the buffer */
+ def insert(key: K, value: V): Unit = {
+ if (curSize == capacity) {
+ growArray()
+ }
+ data(2 * curSize) = key.asInstanceOf[AnyRef]
+ data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
+ curSize += 1
+ afterUpdate()
+ }
+
+ /** Total number of elements in buffer */
+ override def size: Int = curSize
+
+ /** Iterate over the elements of the buffer */
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
+ var pos = 0
+
+ override def hasNext: Boolean = pos < curSize
+
+ override def next(): (K, V) = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ pos += 1
+ pair
+ }
+ }
+
+ /** Double the size of the array because we've reached capacity */
+ private def growArray(): Unit = {
+ if (capacity == (1 << 29)) {
+ // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
+ throw new Exception("Can't grow buffer beyond 2^29 elements")
+ }
+ val newCapacity = capacity * 2
+ val newArray = new Array[AnyRef](2 * newCapacity)
+ System.arraycopy(data, 0, newArray, 0, 2 * capacity)
+ data = newArray
+ capacity = newCapacity
+ resetSamples()
+ }
+
+ /** Iterate through the data in a given order. For this class this is not really destructive. */
+ override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
+ new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator)
+ iterator
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
new file mode 100644
index 0000000000..faa4e2b12d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * A common interface for our size-tracking collections of key-value pairs, which are used in
+ * external operations. These all support estimating the size and obtaining a memory-efficient
+ * sorted iterator.
+ */
+// TODO: should extend Iterable[Product2[K, V]] instead of (K, V)
+private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] {
+ /** Estimate the collection's current memory usage in bytes. */
+ def estimateSize(): Long
+
+ /** Iterate through the data in a given key order. This may destroy the underlying collection. */
+ def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)]
+}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d1cb2d9d3a..a41914a1a9 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("ShuffledRDD") {
testRDD(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
- new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
+ new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
})
}
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index ad20f9b937..4bc4346c0a 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -19,9 +19,6 @@ package org.apache.spark
import java.lang.ref.WeakReference
-import org.apache.spark.broadcast.Broadcast
-
-import scala.collection.mutable
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.language.existentials
import scala.language.postfixOps
@@ -34,15 +31,28 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
-
-class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
-
+import org.apache.spark.storage._
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.storage.BroadcastBlockId
+import org.apache.spark.storage.RDDBlockId
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.ShuffleIndexBlockId
+
+/**
+ * An abstract base class for context cleaner tests, which sets up a context with a config
+ * suitable for cleaner tests and provides some utility functions. Subclasses can use different
+ * config options, in particular, a different shuffle manager class
+ */
+abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager])
+ extends FunSuite with BeforeAndAfter with LocalSparkContext
+{
implicit val defaultTimeout = timeout(10000 millis)
val conf = new SparkConf()
.setMaster("local[2]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.shuffle.manager", shuffleManager.getName)
before {
sc = new SparkContext(conf)
@@ -55,6 +65,59 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
}
+ //------ Helper functions ------
+
+ protected def newRDD() = sc.makeRDD(1 to 10)
+ protected def newPairRDD() = newRDD().map(_ -> 1)
+ protected def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
+ protected def newBroadcast() = sc.broadcast(1 to 100)
+
+ protected def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+ def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
+ rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
+ getAllDependencies(dep.rdd)
+ }
+ }
+ val rdd = newShuffleRDD()
+
+ // Get all the shuffle dependencies
+ val shuffleDeps = getAllDependencies(rdd)
+ .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
+ (rdd, shuffleDeps)
+ }
+
+ protected def randomRdd() = {
+ val rdd: RDD[_] = Random.nextInt(3) match {
+ case 0 => newRDD()
+ case 1 => newShuffleRDD()
+ case 2 => newPairRDD.join(newPairRDD())
+ }
+ if (Random.nextBoolean()) rdd.persist()
+ rdd.count()
+ rdd
+ }
+
+ /** Run GC and make sure it actually has run */
+ protected def runGC() {
+ val weakRef = new WeakReference(new Object())
+ val startTime = System.currentTimeMillis
+ System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+ // Wait until a weak reference object has been GCed
+ while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+ System.gc()
+ Thread.sleep(200)
+ }
+ }
+
+ protected def cleaner = sc.cleaner.get
+}
+
+
+/**
+ * Basic ContextCleanerSuite, which uses sort-based shuffle
+ */
+class ContextCleanerSuite extends ContextCleanerSuiteBase {
test("cleanup RDD") {
val rdd = newRDD().persist()
val collected = rdd.collect().toList
@@ -147,7 +210,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = broadcastBuffer.map(_.id)
@@ -180,12 +243,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
.setMaster("local-cluster[2, 1, 512]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.shuffle.manager", shuffleManager.getName)
sc = new SparkContext(conf2)
val numRdds = 10
val numBroadcasts = 4 // Broadcasts are more costly
val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = broadcastBuffer.map(_.id)
@@ -210,57 +274,82 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
case _ => false
}, askSlaves = true).isEmpty)
}
+}
- //------ Helper functions ------
- private def newRDD() = sc.makeRDD(1 to 10)
- private def newPairRDD() = newRDD().map(_ -> 1)
- private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
- private def newBroadcast() = sc.broadcast(1 to 100)
+/**
+ * A copy of the shuffle tests for sort-based shuffle
+ */
+class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[SortShuffleManager]) {
+ test("cleanup shuffle") {
+ val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
+ val collected = rdd.collect().toList
+ val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
- private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
- def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
- rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
- getAllDependencies(dep.rdd)
- }
- }
- val rdd = newShuffleRDD()
+ // Explicit cleanup
+ shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
+ tester.assertCleanup()
- // Get all the shuffle dependencies
- val shuffleDeps = getAllDependencies(rdd)
- .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
- .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
- (rdd, shuffleDeps)
+ // Verify that shuffles can be re-executed after cleaning up
+ assert(rdd.collect().toList.equals(collected))
}
- private def randomRdd() = {
- val rdd: RDD[_] = Random.nextInt(3) match {
- case 0 => newRDD()
- case 1 => newShuffleRDD()
- case 2 => newPairRDD.join(newPairRDD())
- }
- if (Random.nextBoolean()) rdd.persist()
+ test("automatically cleanup shuffle") {
+ var rdd = newShuffleRDD()
rdd.count()
- rdd
- }
- private def randomBroadcast() = {
- sc.broadcast(Random.nextInt(Int.MaxValue))
+ // Test that GC does not cause shuffle cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes shuffle cleanup after dereferencing the RDD
+ val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope
+ runGC()
+ postGCTester.assertCleanup()
}
- /** Run GC and make sure it actually has run */
- private def runGC() {
- val weakRef = new WeakReference(new Object())
- val startTime = System.currentTimeMillis
- System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
- // Wait until a weak reference object has been GCed
- while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
- System.gc()
- Thread.sleep(200)
+ test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
+ sc.stop()
+
+ val conf2 = new SparkConf()
+ .setMaster("local-cluster[2, 1, 512]")
+ .setAppName("ContextCleanerSuite")
+ .set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.shuffle.manager", shuffleManager.getName)
+ sc = new SparkContext(conf2)
+
+ val numRdds = 10
+ val numBroadcasts = 4 // Broadcasts are more costly
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast).toBuffer
+ val rddIds = sc.persistentRdds.keys.toSeq
+ val shuffleIds = 0 until sc.newShuffleId()
+ val broadcastIds = broadcastBuffer.map(_.id)
+
+ val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
}
- }
- private def cleaner = sc.cleaner.get
+ // Test that GC triggers the cleanup of all variables after the dereferencing them
+ val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ broadcastBuffer.clear()
+ rddBuffer.clear()
+ runGC()
+ postGCTester.assertCleanup()
+
+ // Make sure the broadcasted task closure no longer exists after GC.
+ val taskClosureBroadcastId = broadcastIds.max + 1
+ assert(sc.env.blockManager.master.getMatchingBlockIds({
+ case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+ case _ => false
+ }, askSlaves = true).isEmpty)
+ }
}
@@ -418,6 +507,7 @@ class CleanerTester(
private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
blockManager.master.getMatchingBlockIds( _ match {
case ShuffleBlockId(`shuffleId`, _, _) => true
+ case ShuffleIndexBlockId(`shuffleId`, _, _) => true
case _ => false
}, askSlaves = true)
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index 47df00050c..d7b2d2e1e3 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -28,6 +28,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
}
override def afterAll() {
- System.setProperty("spark.shuffle.use.netty", "false")
+ System.clearProperty("spark.shuffle.use.netty")
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index eae67c7747..b13ddf96bc 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -58,8 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int,
NonJavaSerializableClass,
- NonJavaSerializableClass,
- (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS))
+ NonJavaSerializableClass](b, new HashPartitioner(NUM_BLOCKS))
c.setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -83,8 +82,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int,
NonJavaSerializableClass,
- NonJavaSerializableClass,
- (Int, NonJavaSerializableClass)](b, new HashPartitioner(3))
+ NonJavaSerializableClass](b, new HashPartitioner(3))
c.setSerializer(new KryoSerializer(conf))
assert(c.count === 10)
}
@@ -100,7 +98,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
- val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -126,7 +124,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val b = a.map(x => (x, x*2))
// NOTE: The default Java serializer should create zero-sized blocks
- val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
@@ -141,19 +139,19 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
- test("shuffle using mutable pairs") {
+ test("shuffle on mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test")
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
- val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs,
+ val results = new ShuffledRDD[Int, Int, Int](pairs,
new HashPartitioner(2)).collect()
- data.foreach { pair => results should contain (pair) }
+ data.foreach { pair => results should contain ((pair._1, pair._2)) }
}
- test("sorting using mutable pairs") {
+ test("sorting on mutable pairs") {
// This is not in SortingSuite because of the local cluster setup.
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test")
@@ -162,10 +160,10 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs)
.sortByKey().collect()
- results(0) should be (p(1, 11))
- results(1) should be (p(2, 22))
- results(2) should be (p(3, 33))
- results(3) should be (p(100, 100))
+ results(0) should be ((1, 11))
+ results(1) should be ((2, 22))
+ results(2) should be ((3, 33))
+ results(3) should be ((100, 100))
}
test("cogroup using mutable pairs") {
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
new file mode 100644
index 0000000000..5c02c00586
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.BeforeAndAfterAll
+
+class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with sort-based shuffle.
+
+ override def beforeAll() {
+ System.setProperty("spark.shuffle.manager",
+ "org.apache.spark.shuffle.sort.SortShuffleManager")
+ }
+
+ override def afterAll() {
+ System.clearProperty("spark.shuffle.manager")
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 4953d565ae..8966eedd80 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -270,7 +270,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// we can optionally shuffle to keep the upstream parallel
val coalesced5 = data.coalesce(1, shuffle = true)
val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd.
- asInstanceOf[ShuffledRDD[_, _, _, _]] != null
+ asInstanceOf[ShuffledRDD[_, _, _]] != null
assert(isEquals)
// when shuffling, we can increase the number of partitions
@@ -730,9 +730,9 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// Any ancestors before the shuffle are not considered
assert(ancestors4.size === 0)
- assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0)
+ assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 0)
assert(ancestors5.size === 3)
- assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1)
+ assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0)
assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 0b7ad184a4..7de5df6e1c 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -208,11 +208,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val resultA = rddA.reduceByKey(math.max).collect()
assert(resultA.length == 50000)
resultA.foreach { case(k, v) =>
- k match {
- case 0 => assert(v == 1)
- case 25000 => assert(v == 50001)
- case 49999 => assert(v == 99999)
- case _ =>
+ if (v != k * 2 + 1) {
+ fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
}
}
@@ -221,11 +218,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val resultB = rddB.groupByKey().collect()
assert(resultB.length == 25000)
resultB.foreach { case(i, seq) =>
- i match {
- case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
- case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
- case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
- case _ =>
+ val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+ if (seq.toSet != expected) {
+ fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
}
}
@@ -239,6 +234,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
case 0 =>
assert(seq1.toSet == Set[Int](0))
assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ case 1 =>
+ assert(seq1.toSet == Set[Int](1))
+ assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
case 5000 =>
assert(seq1.toSet == Set[Int](5000))
assert(seq2.toSet == Set[Int]())
@@ -369,10 +367,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
}
}
-
-/**
- * A dummy class that always returns the same hash code, to easily test hash collisions
- */
-case class FixedHashObject(v: Int, h: Int) extends Serializable {
- override def hashCode(): Int = h
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
new file mode 100644
index 0000000000..ddb5df4036
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -0,0 +1,566 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+ test("empty data stream") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+
+ // Both aggregator and ordering
+ val sorter = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+ assert(sorter.iterator.toSeq === Seq())
+ sorter.stop()
+
+ // Only aggregator
+ val sorter2 = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(3)), None, None)
+ assert(sorter2.iterator.toSeq === Seq())
+ sorter2.stop()
+
+ // Only ordering
+ val sorter3 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), Some(ord), None)
+ assert(sorter3.iterator.toSeq === Seq())
+ sorter3.stop()
+
+ // Neither aggregator nor ordering
+ val sorter4 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), None, None)
+ assert(sorter4.iterator.toSeq === Seq())
+ sorter4.stop()
+ }
+
+ test("few elements per partition") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+ val elements = Set((1, 1), (2, 2), (5, 5))
+ val expected = Set(
+ (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()),
+ (5, Set((5, 5))), (6, Set()))
+
+ // Both aggregator and ordering
+ val sorter = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+ sorter.write(elements.iterator)
+ assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter.stop()
+
+ // Only aggregator
+ val sorter2 = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(7)), None, None)
+ sorter2.write(elements.iterator)
+ assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter2.stop()
+
+ // Only ordering
+ val sorter3 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(7)), Some(ord), None)
+ sorter3.write(elements.iterator)
+ assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter3.stop()
+
+ // Neither aggregator nor ordering
+ val sorter4 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(7)), None, None)
+ sorter4.write(elements.iterator)
+ assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter4.stop()
+ }
+
+ test("empty partitions with spilling") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+ val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+
+ val sorter = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(7)), None, None)
+ sorter.write(elements)
+ assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
+ val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+ assert(iter.next() === (0, Nil))
+ assert(iter.next() === (1, List((1, 1))))
+ assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
+ assert(iter.next() === (3, Nil))
+ assert(iter.next() === (4, Nil))
+ assert(iter.next() === (5, List((5, 5))))
+ assert(iter.next() === (6, Nil))
+ sorter.stop()
+ }
+
+ test("spilling in local cluster") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ // reduceByKey - should spill ~8 times
+ val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
+ val resultA = rddA.reduceByKey(math.max).collect()
+ assert(resultA.length == 50000)
+ resultA.foreach { case(k, v) =>
+ if (v != k * 2 + 1) {
+ fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
+ }
+ }
+
+ // groupByKey - should spill ~17 times
+ val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
+ val resultB = rddB.groupByKey().collect()
+ assert(resultB.length == 25000)
+ resultB.foreach { case(i, seq) =>
+ val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+ if (seq.toSet != expected) {
+ fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
+ }
+ }
+
+ // cogroup - should spill ~7 times
+ val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
+ val resultC = rddC1.cogroup(rddC2).collect()
+ assert(resultC.length == 10000)
+ resultC.foreach { case(i, (seq1, seq2)) =>
+ i match {
+ case 0 =>
+ assert(seq1.toSet == Set[Int](0))
+ assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ case 1 =>
+ assert(seq1.toSet == Set[Int](1))
+ assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
+ case 5000 =>
+ assert(seq1.toSet == Set[Int](5000))
+ assert(seq2.toSet == Set[Int]())
+ case 9999 =>
+ assert(seq1.toSet == Set[Int](9999))
+ assert(seq2.toSet == Set[Int]())
+ case _ =>
+ }
+ }
+
+ // larger cogroup - should spill ~7 times
+ val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val resultD = rddD1.cogroup(rddD2).collect()
+ assert(resultD.length == 5000)
+ resultD.foreach { case(i, (seq1, seq2)) =>
+ val expected = Set(i * 2, i * 2 + 1)
+ if (seq1.toSet != expected) {
+ fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+ }
+ if (seq2.toSet != expected) {
+ fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+ }
+ }
+ }
+
+ test("spilling in local cluster with many reduce tasks") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+
+ // reduceByKey - should spill ~4 times per executor
+ val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
+ val resultA = rddA.reduceByKey(math.max _, 100).collect()
+ assert(resultA.length == 50000)
+ resultA.foreach { case(k, v) =>
+ if (v != k * 2 + 1) {
+ fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
+ }
+ }
+
+ // groupByKey - should spill ~8 times per executor
+ val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
+ val resultB = rddB.groupByKey(100).collect()
+ assert(resultB.length == 25000)
+ resultB.foreach { case(i, seq) =>
+ val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+ if (seq.toSet != expected) {
+ fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
+ }
+ }
+
+ // cogroup - should spill ~4 times per executor
+ val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
+ val resultC = rddC1.cogroup(rddC2, 100).collect()
+ assert(resultC.length == 10000)
+ resultC.foreach { case(i, (seq1, seq2)) =>
+ i match {
+ case 0 =>
+ assert(seq1.toSet == Set[Int](0))
+ assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ case 1 =>
+ assert(seq1.toSet == Set[Int](1))
+ assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
+ case 5000 =>
+ assert(seq1.toSet == Set[Int](5000))
+ assert(seq2.toSet == Set[Int]())
+ case 9999 =>
+ assert(seq1.toSet == Set[Int](9999))
+ assert(seq2.toSet == Set[Int]())
+ case _ =>
+ }
+ }
+
+ // larger cogroup - should spill ~4 times per executor
+ val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val resultD = rddD1.cogroup(rddD2).collect()
+ assert(resultD.length == 5000)
+ resultD.foreach { case(i, (seq1, seq2)) =>
+ val expected = Set(i * 2, i * 2 + 1)
+ if (seq1.toSet != expected) {
+ fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+ }
+ if (seq2.toSet != expected) {
+ fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+ }
+ }
+ }
+
+ test("cleanup of intermediate files in sorter") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100000).iterator.map(i => (i, i)))
+ assert(diskBlockManager.getAllFiles().length > 0)
+ sorter.stop()
+ assert(diskBlockManager.getAllBlocks().length === 0)
+
+ val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ sorter2.write((0 until 100000).iterator.map(i => (i, i)))
+ assert(diskBlockManager.getAllFiles().length > 0)
+ assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
+ sorter2.stop()
+ assert(diskBlockManager.getAllBlocks().length === 0)
+ }
+
+ test("cleanup of intermediate files in sorter if there are errors") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ intercept[SparkException] {
+ sorter.write((0 until 100000).iterator.map(i => {
+ if (i == 99990) {
+ throw new SparkException("Intentional failure")
+ }
+ (i, i)
+ }))
+ }
+ assert(diskBlockManager.getAllFiles().length > 0)
+ sorter.stop()
+ assert(diskBlockManager.getAllBlocks().length === 0)
+ }
+
+ test("cleanup of intermediate files in shuffle") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val data = sc.parallelize(0 until 100000, 2).map(i => (i, i))
+ assert(data.reduceByKey(_ + _).count() === 100000)
+
+ // After the shuffle, there should be only 4 files on disk: our two map output files and
+ // their index files. All other intermediate files should've been deleted.
+ assert(diskBlockManager.getAllFiles().length === 4)
+ }
+
+ test("cleanup of intermediate files in shuffle with errors") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val data = sc.parallelize(0 until 100000, 2).map(i => {
+ if (i == 99990) {
+ throw new Exception("Intentional failure")
+ }
+ (i, i)
+ })
+ intercept[SparkException] {
+ data.reduceByKey(_ + _).count()
+ }
+
+ // After the shuffle, there should be only 2 files on disk: the output of task 1 and its index.
+ // All other files (map 2's output and intermediate merge files) should've been deleted.
+ assert(diskBlockManager.getAllFiles().length === 2)
+ }
+
+ test("no partial aggregation or sorting") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100000).iterator.map(i => (i / 4, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("partial aggregation without spill") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100).iterator.map(i => (i / 2, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("partial aggregation with spill, no ordering") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("partial aggregation with spill, with ordering") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+ val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+ sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("sorting without aggregation, no spill") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val ord = implicitly[Ordering[Int]]
+ val sorter = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), Some(ord), None)
+ sorter.write((0 until 100).iterator.map(i => (i, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+ }).toSeq
+ assert(results === expected)
+ }
+
+ test("sorting without aggregation, with spill") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val ord = implicitly[Ordering[Int]]
+ val sorter = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), Some(ord), None)
+ sorter.write((0 until 100000).iterator.map(i => (i, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+ }).toSeq
+ assert(results === expected)
+ }
+
+ test("spilling with hash collisions") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ def createCombiner(i: String) = ArrayBuffer[String](i)
+ def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+ def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) =
+ buffer1 ++= buffer2
+
+ val agg = new Aggregator[String, String, ArrayBuffer[String]](
+ createCombiner _, mergeValue _, mergeCombiners _)
+
+ val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+ Some(agg), None, None, None)
+
+ val collisionPairs = Seq(
+ ("Aa", "BB"), // 2112
+ ("to", "v1"), // 3707
+ ("variants", "gelato"), // -1249574770
+ ("Teheran", "Siblings"), // 231609873
+ ("misused", "horsemints"), // 1069518484
+ ("isohel", "epistolaries"), // -1179291542
+ ("righto", "buzzards"), // -931102253
+ ("hierarch", "crinolines"), // -1732884796
+ ("inwork", "hypercatalexes"), // -1183663690
+ ("wainages", "presentencing"), // 240183619
+ ("trichothecenes", "locular"), // 339006536
+ ("pomatoes", "eructation") // 568647356
+ )
+
+ collisionPairs.foreach { case (w1, w2) =>
+ // String.hashCode is documented to use a specific algorithm, but check just in case
+ assert(w1.hashCode === w2.hashCode)
+ }
+
+ val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
+ collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
+
+ sorter.write(toInsert)
+
+ // A map of collision pairs in both directions
+ val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
+
+ // Avoid map.size or map.iterator.length because this destructively sorts the underlying map
+ var count = 0
+
+ val it = sorter.iterator
+ while (it.hasNext) {
+ val kv = it.next()
+ val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1))
+ assert(kv._2.equals(expectedValue))
+ count += 1
+ }
+ assert(count === 100000 + collisionPairs.size * 2)
+ }
+
+ test("spilling with many hash collisions") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.0001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
+ val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
+
+ // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
+ // problems if the map fails to group together the objects with the same code (SPARK-2043).
+ val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
+ sorter.write(toInsert.iterator)
+
+ val it = sorter.iterator
+ var count = 0
+ while (it.hasNext) {
+ val kv = it.next()
+ assert(kv._2 === 10)
+ count += 1
+ }
+ assert(count === 10000)
+ }
+
+ test("spilling with hash collisions using the Int.MaxValue key") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ def createCombiner(i: Int) = ArrayBuffer[Int](i)
+ def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i
+ def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2
+
+ val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
+ val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
+
+ sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
+
+ val it = sorter.iterator
+ while (it.hasNext) {
+ // Should not throw NoSuchElementException
+ it.next()
+ }
+ }
+
+ test("spilling with null keys and values") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ def createCombiner(i: String) = ArrayBuffer[String](i)
+ def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+ def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2
+
+ val agg = new Aggregator[String, String, ArrayBuffer[String]](
+ createCombiner, mergeValue, mergeCombiners)
+
+ val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+ Some(agg), None, None, None)
+
+ sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
+ (null.asInstanceOf[String], "1"),
+ ("1", null.asInstanceOf[String]),
+ (null.asInstanceOf[String], null.asInstanceOf[String])
+ ))
+
+ val it = sorter.iterator
+ while (it.hasNext) {
+ // Should not throw NullPointerException
+ it.next()
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
new file mode 100644
index 0000000000..c787b5f066
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/**
+ * A dummy class that always returns the same hash code, to easily test hash collisions
+ */
+case class FixedHashObject(v: Int, h: Int) extends Serializable {
+ override def hashCode(): Int = h
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
index 5318b8da64..714f3b81c9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.{ShuffledRDD, RDD}
private[graphx]
class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
- val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner)
+ val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner)
// Set a custom serializer if the data is of int or double type.
if (classTag[VD] == ClassTag.Int) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index a565d3b28b..b27485953f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -33,7 +33,7 @@ private[graphx]
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
- new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
+ new ShuffledRDD[VertexId, Int, Int](
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
}
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 672343fbbe..a8bbd55861 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -295,6 +295,7 @@ object Unidoc {
.map(_.filterNot(_.getCanonicalPath.contains("akka")))
.map(_.filterNot(_.getCanonicalPath.contains("deploy")))
.map(_.filterNot(_.getCanonicalPath.contains("network")))
+ .map(_.filterNot(_.getCanonicalPath.contains("shuffle")))
.map(_.filterNot(_.getCanonicalPath.contains("executor")))
.map(_.filterNot(_.getCanonicalPath.contains("python")))
.map(_.filterNot(_.getCanonicalPath.contains("collection")))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 392a7f3be3..30712f03ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -49,7 +49,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
val part = new HashPartitioner(numPartitions)
- val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part)
+ val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)
@@ -62,7 +62,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
iter.map(row => mutablePair.update(row, null))
}
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
- val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part)
+ val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._1)
@@ -73,7 +73,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
iter.map(r => mutablePair.update(null, r))
}
val partitioner = new HashPartitioner(1)
- val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner)
+ val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 174eda8f1a..0027f3cf1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -148,7 +148,7 @@ case class Limit(limit: Int, child: SparkPlan)
iter.take(limit).map(row => mutablePair.update(false, row))
}
val part = new HashPartitioner(1)
- val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part)
+ val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.mapPartitions(_.take(limit).map(_._2))
}