 core/src/main/scala/org/apache/spark/Aggregator.scala | 24
 core/src/main/scala/org/apache/spark/SparkContext.scala | 8
 core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala | 2
 core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala | 7
 core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala | 14
 core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 4
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8
 core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala | 17
 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala | 2
 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala | 5
 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala | 6
 core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala | 80
 core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala | 165
 core/src/main/scala/org/apache/spark/storage/BlockId.scala | 11
 core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala | 38
 core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala | 29
 core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 36
 core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala | 662
 core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala | 5
 core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala | 86
 core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala | 34
 core/src/test/scala/org/apache/spark/CheckpointSuite.scala | 2
 core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala | 186
 core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala | 2
 core/src/test/scala/org/apache/spark/ShuffleSuite.scala | 26
 core/src/test/scala/org/apache/spark/SortShuffleSuite.scala | 34
 core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 6
 core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala | 25
 core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala | 566
 core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala | 25
 graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala | 2
 graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala | 2
 project/SparkBuild.scala | 1
 sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala | 6
 sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 2
35 files changed, 1969 insertions, 159 deletions
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ff0ca11749..79c9c451d2 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -56,18 +56,23 @@ case class Aggregator[K, V, C] (
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
- // TODO: Make this non optional in a future release
- Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
- Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+ // Update task metrics if context is not null
+ // TODO: Make context non optional in a future release
+ Option(context).foreach { c =>
+ c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+ c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ }
combiners.iterator
}
}
@deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0")
- def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] =
+ def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] =
combineCombinersByKey(iter, null)
- def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
+ def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
+ : Iterator[(K, C)] =
+ {
if (!externalSorting) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
@@ -85,9 +90,12 @@ case class Aggregator[K, V, C] (
val pair = iter.next()
combiners.insert(pair._1, pair._2)
}
- // TODO: Make this non optional in a future release
- Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
- Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+ // Update task metrics if context is not null
+ // TODO: Make context non-optional in a future release
+ Option(context).foreach { c =>
+ c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+ c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ }
combiners.iterator
}
}
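
A minimal usage sketch of the widened combineCombinersByKey signature above. It assumes it runs inside a Spark task (Aggregator consults SparkEnv for spark.shuffle.spill); the keys, values and the null context are hypothetical. Any Product2 subtype, not just Tuple2, is now accepted for the input iterator.

    import org.apache.spark.Aggregator

    val agg = Aggregator[String, Int, Int]((v: Int) => v, _ + _, _ + _)
    val merged: Iterator[(String, Int)] =
      agg.combineCombinersByKey(Iterator(("a", 1), ("a", 2), ("b", 3)), null)
    // merged yields ("a", 3) and ("b", 3). When a real TaskContext is passed, the
    // spill metrics are now added to the existing counters rather than overwritten.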
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index fb4c86716b..b25f081761 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -289,7 +289,7 @@ class SparkContext(config: SparkConf) extends Logging {
value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
executorEnvs(envKey) = value
}
- Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
+ Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
executorEnvs("SPARK_PREPEND_CLASSES") = v
}
// The Mesos scheduler backend relies on this environment variable to set executor memory.
@@ -1203,10 +1203,10 @@ class SparkContext(config: SparkConf) extends Logging {
/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
- * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
- * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
+ * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+ * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
* if not.
- *
+ *
* @param f the closure to clean
* @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
* @throws <tt>SparkException<tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 31bf8dced2..47708cb2e7 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -122,7 +122,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
sample(withReplacement, fraction, Utils.random.nextLong)
-
+
/**
* Return a sampled subset of this RDD.
*/
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 6388ef82cc..fabb882cdd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -17,10 +17,11 @@
package org.apache.spark.rdd
+import scala.language.existentials
+
import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
-import scala.language.existentials
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled
+ context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index d85f962783..e98bad2026 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark.{Logging, RangePartitioner}
+import org.apache.spark.annotation.DeveloperApi
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner}
*/
class OrderedRDDFunctions[K : Ordering : ClassTag,
V: ClassTag,
- P <: Product2[K, V] : ClassTag](
+ P <: Product2[K, V] : ClassTag] @DeveloperApi() (
self: RDD[P])
- extends Logging with Serializable {
-
+ extends Logging with Serializable
+{
private val ordering = implicitly[Ordering[K]]
/**
@@ -55,9 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
- def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+ // TODO: this currently doesn't work on P other than Tuple2!
+ def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size)
+ : RDD[(K, V)] =
+ {
val part = new RangePartitioner(numPartitions, self, ascending)
- new ShuffledRDD[K, V, V, P](self, part)
+ new ShuffledRDD[K, V, V](self, part)
.setKeyOrdering(if (ascending) ordering else ordering.reverse)
}
}
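
A usage sketch of the revised sortByKey, which now always returns an RDD[(K, V)] built on the three-parameter ShuffledRDD. It assumes an existing SparkContext named sc and the usual pair-RDD implicits.

    import org.apache.spark.SparkContext._
    import org.apache.spark.rdd.RDD

    val pairs = sc.parallelize(Seq(("b", 2), ("a", 1), ("c", 3)))
    val sorted: RDD[(String, Int)] = pairs.sortByKey(ascending = true, numPartitions = 2)
    // sorted.collect() returns the pairs ordered by key: (a,1), (b,2), (c,3)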
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 1af4e5f0b6..93af50c0a9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -90,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
}, preservesPartitioning = true)
} else {
- new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
+ new ShuffledRDD[K, V, C](self, partitioner)
.setSerializer(serializer)
.setAggregator(aggregator)
.setMapSideCombine(mapSideCombine)
@@ -425,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
if (self.partitioner == Some(partitioner)) {
self
} else {
- new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
+ new ShuffledRDD[K, V, V](self, partitioner)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 726b3f2bbe..74ac97091f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag](
val distributePartition = (index: Int, items: Iterator[T]) => {
var position = (new Random(index)).nextInt(numPartitions)
items.map { t =>
- // Note that the hash code of the key will just be the key itself. The HashPartitioner
+ // Note that the hash code of the key will just be the key itself. The HashPartitioner
// will mod it with the number of total partitions.
position = position + 1
(position, t)
@@ -341,7 +341,7 @@ abstract class RDD[T: ClassTag](
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(
- new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
+ new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
new HashPartitioner(numPartitions)),
numPartitions).values
} else {
@@ -352,8 +352,8 @@ abstract class RDD[T: ClassTag](
/**
* Return a sampled subset of this RDD.
*/
- def sample(withReplacement: Boolean,
- fraction: Double,
+ def sample(withReplacement: Boolean,
+ fraction: Double,
seed: Long = Utils.random.nextLong): RDD[T] = {
require(fraction >= 0.0, "Negative fraction value: " + fraction)
if (withReplacement) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index bf02f68d0d..d9fe684725 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -37,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @tparam V the value class.
* @tparam C the combiner class.
*/
+// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs
@DeveloperApi
-class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
+class ShuffledRDD[K, V, C](
@transient var prev: RDD[_ <: Product2[K, V]],
part: Partitioner)
- extends RDD[P](prev.context, Nil) {
+ extends RDD[(K, C)](prev.context, Nil) {
private var serializer: Option[Serializer] = None
@@ -52,25 +53,25 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
private var mapSideCombine: Boolean = false
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
- def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
+ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
this.serializer = Option(serializer)
this
}
/** Set key ordering for RDD's shuffle. */
- def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
+ def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = {
this.keyOrdering = Option(keyOrdering)
this
}
/** Set aggregator for RDD's shuffle. */
- def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
+ def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = {
this.aggregator = Option(aggregator)
this
}
/** Set mapSideCombine flag for RDD's shuffle. */
- def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
+ def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = {
this.mapSideCombine = mapSideCombine
this
}
@@ -85,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
}
- override def compute(split: Partition, context: TaskContext): Iterator[P] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
- .asInstanceOf[Iterator[P]]
+ .asInstanceOf[Iterator[(K, C)]]
}
override def clearDependencies() {
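
A sketch of constructing the new three-type-parameter ShuffledRDD directly, mirroring the partitionBy path in PairRDDFunctions above; it assumes an existing SparkContext named sc.

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.{RDD, ShuffledRDD}

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
    // K = String, V = Int, and C = Int because no map-side combine is performed.
    val shuffled: RDD[(String, Int)] =
      new ShuffledRDD[String, Int, Int](pairs, new HashPartitioner(4))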
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
index 5b0940ecce..df98d18fa8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -24,7 +24,7 @@ import org.apache.spark.shuffle._
* A ShuffleManager using hashing, that creates one output file per reduce partition on each
* mapper (possibly reusing these across waves of tasks).
*/
-class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
/* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
override def registerShuffle[K, V, C](
shuffleId: Int,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index c8059496a1..e32ad9c036 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-class HashShuffleReader[K, C](
+private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
@@ -47,7 +47,8 @@ class HashShuffleReader[K, C](
} else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
- iter
+ // Convert the Product2s to pairs since this is what downstream RDDs currently expect
+ iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
}
// Sort the output if there is a sort ordering defined.
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 9b78228519..1923f7c71a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
-class HashShuffleWriter[K, V](
+private[spark] class HashShuffleWriter[K, V](
handle: BaseShuffleHandle[K, V, _],
mapId: Int,
context: TaskContext)
@@ -33,6 +33,10 @@ class HashShuffleWriter[K, V](
private val dep = handle.dependency
private val numOutputSplits = dep.partitioner.numPartitions
private val metrics = context.taskMetrics
+
+ // Are we in the process of stopping? Because map tasks can call stop() with success = true
+ // and then call stop() with success = false if they get an exception, we want to make sure
+ // we don't try deleting files, etc twice.
private var stopping = false
private val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
new file mode 100644
index 0000000000..6dcca47ea7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{DataInputStream, FileInputStream}
+
+import org.apache.spark.shuffle._
+import org.apache.spark.{TaskContext, ShuffleDependency}
+import org.apache.spark.shuffle.hash.HashShuffleReader
+import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId}
+
+private[spark] class SortShuffleManager extends ShuffleManager {
+ /**
+ * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C] = {
+ // We currently use the same block store shuffle fetcher as the hash-based shuffle.
+ new HashShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+ : ShuffleWriter[K, V] = {
+ new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {}
+
+ /** Get the location of a block in a map output file. Uses the index file we create for it. */
+ def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
+ // The block is actually going to be a range of a single map output file for this map, so
+ // figure out the ID of the consolidated file, then the offset within that from our index
+ val consolidatedId = blockId.copy(reduceId = 0)
+ val indexFile = diskManager.getFile(consolidatedId.name + ".index")
+ val in = new DataInputStream(new FileInputStream(indexFile))
+ try {
+ in.skip(blockId.reduceId * 8)
+ val offset = in.readLong()
+ val nextOffset = in.readLong()
+ new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset)
+ } finally {
+ in.close()
+ }
+ }
+}
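
A standalone sketch of the index-file layout that getBlockLocation relies on: the map task writes numPartitions + 1 longs, and the data for reduce partition r occupies bytes [offsets(r), offsets(r + 1)) of the single output file. The helper name is hypothetical.

    import java.io.{DataInputStream, File, FileInputStream}

    def segmentFor(indexFile: File, reduceId: Int): (Long, Long) = {
      val in = new DataInputStream(new FileInputStream(indexFile))
      try {
        in.skip(reduceId * 8)           // each stored offset is an 8-byte long
        val offset = in.readLong()
        val nextOffset = in.readLong()
        (offset, nextOffset - offset)   // (start offset, length) of the segment
      } finally {
        in.close()
      }
    }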
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
new file mode 100644
index 0000000000..42fcd07fa1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream}
+
+import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.collection.ExternalSorter
+
+private[spark] class SortShuffleWriter[K, V, C](
+ handle: BaseShuffleHandle[K, V, C],
+ mapId: Int,
+ context: TaskContext)
+ extends ShuffleWriter[K, V] with Logging {
+
+ private val dep = handle.dependency
+ private val numPartitions = dep.partitioner.numPartitions
+
+ private val blockManager = SparkEnv.get.blockManager
+ private val ser = Serializer.getSerializer(dep.serializer.orNull)
+
+ private val conf = SparkEnv.get.conf
+ private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+ private var sorter: ExternalSorter[K, V, _] = null
+ private var outputFile: File = null
+
+ // Are we in the process of stopping? Because map tasks can call stop() with success = true
+ // and then call stop() with success = false if they get an exception, we want to make sure
+ // we don't try deleting files, etc twice.
+ private var stopping = false
+
+ private var mapStatus: MapStatus = null
+
+ /** Write a bunch of records to this task's output */
+ override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ // Get an iterator with the elements for each partition ID
+ val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
+ if (dep.mapSideCombine) {
+ if (!dep.aggregator.isDefined) {
+ throw new IllegalStateException("Aggregator is empty for map-side combine")
+ }
+ sorter = new ExternalSorter[K, V, C](
+ dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+ sorter.write(records)
+ sorter.partitionedIterator
+ } else {
+ // In this case we pass neither an aggregator nor an ordering to the sorter, because we
+ // don't care whether the keys get sorted in each partition; that will be done on the
+ // reduce side if the operation being run is sortByKey.
+ sorter = new ExternalSorter[K, V, V](
+ None, Some(dep.partitioner), None, dep.serializer)
+ sorter.write(records)
+ sorter.partitionedIterator
+ }
+ }
+
+ // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
+ // serve different ranges of this file using an index file that we create at the end.
+ val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
+ outputFile = blockManager.diskBlockManager.getFile(blockId)
+
+ // Track location of each range in the output file
+ val offsets = new Array[Long](numPartitions + 1)
+ val lengths = new Array[Long](numPartitions)
+
+ // Statistics
+ var totalBytes = 0L
+ var totalTime = 0L
+
+ for ((id, elements) <- partitions) {
+ if (elements.hasNext) {
+ val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+ for (elem <- elements) {
+ writer.write(elem)
+ }
+ writer.commit()
+ writer.close()
+ val segment = writer.fileSegment()
+ offsets(id + 1) = segment.offset + segment.length
+ lengths(id) = segment.length
+ totalTime += writer.timeWriting()
+ totalBytes += segment.length
+ } else {
+ // The partition is empty; don't create a new writer to avoid writing headers, etc
+ offsets(id + 1) = offsets(id)
+ }
+ }
+
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
+ context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
+ context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+
+ // Write an index file with the offsets of each block, plus a final offset at the end for the
+ // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
+ // out where each block begins and ends.
+
+ val diskBlockManager = blockManager.diskBlockManager
+ val indexFile = diskBlockManager.getFile(blockId.name + ".index")
+ val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+ try {
+ var i = 0
+ while (i < numPartitions + 1) {
+ out.writeLong(offsets(i))
+ i += 1
+ }
+ } finally {
+ out.close()
+ }
+
+ // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
+ blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
+
+ mapStatus = new MapStatus(blockManager.blockManagerId,
+ lengths.map(MapOutputTracker.compressSize))
+ }
+
+ /** Close this writer, passing along whether the map completed */
+ override def stop(success: Boolean): Option[MapStatus] = {
+ try {
+ if (stopping) {
+ return None
+ }
+ stopping = true
+ if (success) {
+ return Option(mapStatus)
+ } else {
+ // The map task failed, so delete our output file if we created one
+ if (outputFile != null) {
+ outputFile.delete()
+ }
+ return None
+ }
+ } finally {
+ // Clean up our sorter, which may have its own intermediate files
+ if (sorter != null) {
+ sorter.stop()
+ sorter = null
+ }
+ }
+ }
+}
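
A tiny illustration, with hypothetical numbers, of the bookkeeping in write(): offsets has numPartitions + 1 entries, an empty partition carries the previous offset forward, and each partition's length is the difference of adjacent offsets, which is exactly what the index file encodes.

    val offsets = Array(0L, 100L, 100L, 250L)   // partition 1 produced no output
    val lengths = Array(100L, 0L, 150L)
    assert(lengths.indices.forall(i => lengths(i) == offsets(i + 1) - offsets(i)))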
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 42ec181b00..c1756ac905 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -54,12 +54,16 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
}
@DeveloperApi
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
- extends BlockId {
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
@DeveloperApi
+case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+ def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
+}
+
+@DeveloperApi
case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}
@@ -88,6 +92,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -99,6 +104,8 @@ object BlockId {
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
+ ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId, field) =>
BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>
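
A sketch of round-tripping the new index block through BlockId.apply, based on the name format and the SHUFFLE_INDEX pattern above.

    import org.apache.spark.storage.{BlockId, ShuffleIndexBlockId}

    val parsed = BlockId("shuffle_2_5_0.index")
    assert(parsed == ShuffleIndexBlockId(shuffleId = 2, mapId = 5, reduceId = 0))
    assert(parsed.name == "shuffle_2_5_0.index")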
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 2e7ed7538e..4d66ccea21 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -21,10 +21,11 @@ import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random, UUID}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkEnv, Logging}
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.sort.SortShuffleManager
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -34,11 +35,13 @@ import org.apache.spark.util.Utils
*
* @param rootDirs The directories to use for storing block files. Data will be hashed among these.
*/
-private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String)
extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64)
+
+ private val subDirsPerLocalDir =
+ shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
@@ -54,13 +57,19 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
addShutdownHook()
/**
- * Returns the physical file segment in which the given BlockId is located.
- * If the BlockId has been mapped to a specific FileSegment, that will be returned.
- * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ * Returns the physical file segment in which the given BlockId is located. If the BlockId has
+ * been mapped to a specific FileSegment by the shuffle layer, that will be returned.
+ * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId.
*/
def getBlockLocation(blockId: BlockId): FileSegment = {
- if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
- shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+ val env = SparkEnv.get // NOTE: can be null in unit tests
+ if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) {
+ // For sort-based shuffle, let it figure out its blocks
+ val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager]
+ sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this)
+ } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) {
+ // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this
+ shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
} else {
val file = getFile(blockId.name)
new FileSegment(file, 0, file.length())
@@ -99,13 +108,18 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
getBlockLocation(blockId).file.exists()
}
- /** List all the blocks currently stored on disk by the disk manager. */
- def getAllBlocks(): Seq[BlockId] = {
+ /** List all the files currently stored on disk by the disk manager. */
+ def getAllFiles(): Seq[File] = {
// Get all the files inside the array of array of directories
subDirs.flatten.filter(_ != null).flatMap { dir =>
- val files = dir.list()
+ val files = dir.listFiles()
if (files != null) files else Seq.empty
- }.map(BlockId.apply)
+ }
+ }
+
+ /** List all the blocks currently stored on disk by the disk manager. */
+ def getAllBlocks(): Seq[BlockId] = {
+ getAllFiles().map(f => BlockId(f.getName))
}
/** Produces a unique block id and File suitable for intermediate results. */
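
The relationship introduced above, stated as a sketch: every file the disk manager stores has a name that parses back into a BlockId, so getAllBlocks is getAllFiles mapped through BlockId.apply. It assumes an existing DiskBlockManager instance named diskBlockManager and code living inside the org.apache.spark package, since the class is private[spark].

    import java.io.File
    import org.apache.spark.storage.BlockId

    val files: Seq[File] = diskBlockManager.getAllFiles()
    val blocks: Seq[BlockId] = diskBlockManager.getAllBlocks()
    assert(blocks == files.map(f => BlockId(f.getName)))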
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 35910e552f..7beb55c411 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -28,6 +28,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.shuffle.sort.SortShuffleManager
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
@@ -58,6 +59,7 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
* files within a ShuffleFileGroups associated with the block's reducer.
*/
+// TODO: Factor this into a separate class for each ShuffleManager implementation
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
def conf = blockManager.conf
@@ -67,6 +69,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
val consolidateShuffleFiles =
conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ // Are we using sort-based shuffle?
+ val sortBasedShuffle =
+ conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName
+
private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
/**
@@ -91,6 +97,20 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
+ /**
+ * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle
+ * because it just writes a single file by itself.
+ */
+ def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = {
+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
+ val shuffleState = shuffleStates(shuffleId)
+ shuffleState.completedMapTasks.add(mapId)
+ }
+
+ /**
+ * Get a ShuffleWriterGroup for the given map task, which will register it as complete
+ * when the writers are closed successfully
+ */
def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleWriterGroup {
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
@@ -182,7 +202,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
shuffleStates.get(shuffleId) match {
case Some(state) =>
- if (consolidateShuffleFiles) {
+ if (sortBasedShuffle) {
+ // There's a single block ID for each map, plus an index file for it
+ for (mapId <- state.completedMapTasks) {
+ val blockId = new ShuffleBlockId(shuffleId, mapId, 0)
+ blockManager.diskBlockManager.getFile(blockId).delete()
+ blockManager.diskBlockManager.getFile(blockId.name + ".index").delete()
+ }
+ } else if (consolidateShuffleFiles) {
for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
file.delete()
}
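
How the sortBasedShuffle flag above gets switched on: a configuration sketch that points spark.shuffle.manager at the fully qualified SortShuffleManager class name introduced in this patch. The app name and local master are placeholders.

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("sort-shuffle-example")
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
    // Shuffles in this context go through the sort-based path, including the
    // data-plus-index cleanup branch added to removeShuffleBlocks above.
    val sc = new SparkContext(conf)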
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 6f263c39d1..b34512ef9e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -79,12 +79,16 @@ class ExternalAppendOnlyMap[K, V, C](
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
- // Number of pairs in the in-memory map
- private var numPairsInMemory = 0L
+ // Number of pairs inserted since last spill; note that we count them even if a value is merged
+ // with a previous key in case we're doing something like groupBy where the result grows
+ private var elementsRead = 0L
// Number of in-memory pairs inserted before tracking the map's shuffle memory usage
private val trackMemoryThreshold = 1000
+ // How much of the shared memory pool this collection has claimed
+ private var myMemoryThreshold = 0L
+
/**
* Size of object batches when reading/writing from serializers.
*
@@ -106,7 +110,6 @@ class ExternalAppendOnlyMap[K, V, C](
private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()
- private val threadId = Thread.currentThread().getId
/**
* Insert the given key and value into the map.
@@ -134,31 +137,35 @@ class ExternalAppendOnlyMap[K, V, C](
while (entries.hasNext) {
curEntry = entries.next()
- if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) {
- val mapSize = currentMap.estimateSize()
+ if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
+ currentMap.estimateSize() >= myMemoryThreshold)
+ {
+ val currentSize = currentMap.estimateSize()
var shouldSpill = false
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
// Atomically check whether there is sufficient memory in the global pool for
// this map to grow and, if possible, allocate the required amount
shuffleMemoryMap.synchronized {
+ val threadId = Thread.currentThread().getId
val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
val availableMemory = maxMemoryThreshold -
(shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
- // Assume map growth factor is 2x
- shouldSpill = availableMemory < mapSize * 2
+ // Try to allocate at least 2x more memory, otherwise spill
+ shouldSpill = availableMemory < currentSize * 2
if (!shouldSpill) {
- shuffleMemoryMap(threadId) = mapSize * 2
+ shuffleMemoryMap(threadId) = currentSize * 2
+ myMemoryThreshold = currentSize * 2
}
}
// Do not synchronize spills
if (shouldSpill) {
- spill(mapSize)
+ spill(currentSize)
}
}
currentMap.changeValue(curEntry._1, update)
- numPairsInMemory += 1
+ elementsRead += 1
}
}
@@ -178,9 +185,10 @@ class ExternalAppendOnlyMap[K, V, C](
/**
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
- private def spill(mapSize: Long) {
+ private def spill(mapSize: Long): Unit = {
spillCount += 1
- logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
+ val threadId = Thread.currentThread().getId
+ logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
.format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
val (blockId, file) = diskBlockManager.createTempBlock()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
@@ -227,7 +235,9 @@ class ExternalAppendOnlyMap[K, V, C](
shuffleMemoryMap.synchronized {
shuffleMemoryMap(Thread.currentThread().getId) = 0
}
- numPairsInMemory = 0
+ myMemoryThreshold = 0
+
+ elementsRead = 0
_memoryBytesSpilled += mapSize
}
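
A hypothetical standalone sketch of the spill decision used above: once past the 1000-element threshold, every 32nd insert checks whether the estimated size has reached the amount already claimed and, if so, tries (under the shared pool's lock) to reserve twice the current size; if the pool cannot cover it, the caller spills. The names and the map parameter are stand-ins, not Spark APIs.

    import scala.collection.mutable

    // claimedByThread plays the role of SparkEnv's shuffleMemoryMap (threadId -> bytes).
    def shouldSpill(currentSize: Long,
                    claimedByThread: mutable.Map[Long, Long],
                    maxMemory: Long): Boolean = claimedByThread.synchronized {
      val threadId = Thread.currentThread().getId
      val mine = claimedByThread.getOrElse(threadId, 0L)
      val available = maxMemory - (claimedByThread.values.sum - mine)
      if (available < currentSize * 2) {
        true                                        // cannot double the claim: spill
      } else {
        claimedByThread(threadId) = currentSize * 2 // reserve 2x the estimated size
        false
      }
    }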
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
new file mode 100644
index 0000000000..54c3310744
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -0,0 +1,662 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io._
+import java.util.Comparator
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.BlockId
+
+/**
+ * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
+ * pairs of type (K, C). Uses a Partitioner to first group the keys into partitions, and then
+ * optionally sorts keys within each partition using a custom Comparator. Can output a single
+ * partitioned file with a different byte range for each partition, suitable for shuffle fetches.
+ *
+ * If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
+ *
+ * @param aggregator optional Aggregator with combine functions to use for merging data
+ * @param partitioner optional Partitioner; if given, sort by partition ID and then key
+ * @param ordering optional Ordering to sort keys within each partition; should be a total ordering
+ * @param serializer serializer to use when spilling to disk
+ *
+ * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really
+ * want the output keys to be sorted. In a map task without map-side combine for example, you
+ * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do
+ * want to do combining, having an Ordering is more efficient than not having it.
+ *
+ * At a high level, this class works as follows:
+ *
+ * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
+ * we want to combine by key, or a simple SizeTrackingBuffer if we don't. Inside these buffers,
+ * we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
+ * avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
+ *
+ * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
+ * by partition ID and possibly second by key or by hash code of the key, if we want to do
+ * aggregation. For each file, we track how many objects were in each partition in memory, so we
+ * don't have to write out the partition ID for every element.
+ *
+ * - When the user requests an iterator, the spilled files are merged, along with any remaining
+ * in-memory data, using the same sort order defined above (unless both sorting and aggregation
+ * are disabled). If we need to aggregate by key, we either use a total ordering from the
+ * ordering parameter, or read the keys with the same hash code and compare them with each other
+ * for equality to merge values.
+ *
+ * - Users are expected to call stop() at the end to delete all the intermediate files.
+ */
+private[spark] class ExternalSorter[K, V, C](
+ aggregator: Option[Aggregator[K, V, C]] = None,
+ partitioner: Option[Partitioner] = None,
+ ordering: Option[Ordering[K]] = None,
+ serializer: Option[Serializer] = None) extends Logging {
+
+ private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
+ private val shouldPartition = numPartitions > 1
+
+ private val blockManager = SparkEnv.get.blockManager
+ private val diskBlockManager = blockManager.diskBlockManager
+ private val ser = Serializer.getSerializer(serializer)
+ private val serInstance = ser.newInstance()
+
+ private val conf = SparkEnv.get.conf
+ private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
+ private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+ // Size of object batches when reading/writing from serializers.
+ //
+ // Objects are written in batches, with each batch using its own serialization stream. This
+ // cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+ //
+ // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
+ // grow internal data structures by growing + copying every time the number of objects doubles.
+ private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
+
+ private def getPartition(key: K): Int = {
+ if (shouldPartition) partitioner.get.getPartition(key) else 0
+ }
+
+ // Data structures to store in-memory objects before we spill. Depending on whether we have an
+ // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
+ // store them in an array buffer.
+ private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+ private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
+
+ // Number of pairs read from input since last spill; note that we count them even if a value is
+ // merged with a previous key in case we're doing something like groupBy where the result grows
+ private var elementsRead = 0L
+
+ // What threshold of elementsRead we start estimating map size at.
+ private val trackMemoryThreshold = 1000
+
+ // Spilling statistics
+ private var spillCount = 0
+ private var _memoryBytesSpilled = 0L
+ private var _diskBytesSpilled = 0L
+
+ // Collective memory threshold shared across all running tasks
+ private val maxMemoryThreshold = {
+ val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
+ val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
+ (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+ }
+
+ // How much of the shared memory pool this collection has claimed
+ private var myMemoryThreshold = 0L
+
+ // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
+ // Can be a partial ordering by hash code if a total ordering is not provided by the
+ // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
+ // non-equal keys also have this, so we need to do a later pass to find truly equal keys).
+ // Note that we ignore this if no aggregator and no ordering are given.
+ private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
+ override def compare(a: K, b: K): Int = {
+ val h1 = if (a == null) 0 else a.hashCode()
+ val h2 = if (b == null) 0 else b.hashCode()
+ h1 - h2
+ }
+ })
+
+ // A comparator for (Int, K) elements that orders them by partition and then possibly by key
+ private val partitionKeyComparator: Comparator[(Int, K)] = {
+ if (ordering.isDefined || aggregator.isDefined) {
+ // Sort by partition ID then key comparator
+ new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ val partitionDiff = a._1 - b._1
+ if (partitionDiff != 0) {
+ partitionDiff
+ } else {
+ keyComparator.compare(a._2, b._2)
+ }
+ }
+ }
+ } else {
+ // Just sort it by partition ID
+ new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ a._1 - b._1
+ }
+ }
+ }
+ }
+
+ // Information about a spilled file. Includes sizes in bytes of "batches" written by the
+ // serializer as we periodically reset its stream, as well as number of elements in each
+ // partition, used to efficiently keep track of partitions when merging.
+ private[this] case class SpilledFile(
+ file: File,
+ blockId: BlockId,
+ serializerBatchSizes: Array[Long],
+ elementsPerPartition: Array[Long])
+ private val spills = new ArrayBuffer[SpilledFile]
+
+ def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+ // TODO: stop combining if we find that the reduction factor isn't high
+ val shouldCombine = aggregator.isDefined
+
+ if (shouldCombine) {
+ // Combine values in-memory first using our AppendOnlyMap
+ val mergeValue = aggregator.get.mergeValue
+ val createCombiner = aggregator.get.createCombiner
+ var kv: Product2[K, V] = null
+ val update = (hadValue: Boolean, oldValue: C) => {
+ if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+ }
+ while (records.hasNext) {
+ elementsRead += 1
+ kv = records.next()
+ map.changeValue((getPartition(kv._1), kv._1), update)
+ maybeSpill(usingMap = true)
+ }
+ } else {
+ // Stick values into our buffer
+ while (records.hasNext) {
+ elementsRead += 1
+ val kv = records.next()
+ buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+ maybeSpill(usingMap = false)
+ }
+ }
+ }
+
+ /**
+ * Spill the current in-memory collection to disk if needed.
+ *
+ * @param usingMap whether we're using a map or buffer as our current in-memory collection
+ */
+ private def maybeSpill(usingMap: Boolean): Unit = {
+ if (!spillingEnabled) {
+ return
+ }
+
+ val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+
+ // TODO: factor this out of both here and ExternalAppendOnlyMap
+ if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
+ collection.estimateSize() >= myMemoryThreshold)
+ {
+ // TODO: This logic doesn't work if there are two external collections being used in the same
+ // task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711]
+
+ val currentSize = collection.estimateSize()
+ var shouldSpill = false
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+
+ // Atomically check whether there is sufficient memory in the global pool for
+ // us to double our threshold
+ shuffleMemoryMap.synchronized {
+ val threadId = Thread.currentThread().getId
+ val previouslyClaimedMemory = shuffleMemoryMap.get(threadId)
+ val availableMemory = maxMemoryThreshold -
+ (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L))
+
+ // Try to allocate at least 2x more memory, otherwise spill
+ shouldSpill = availableMemory < currentSize * 2
+ if (!shouldSpill) {
+ shuffleMemoryMap(threadId) = currentSize * 2
+ myMemoryThreshold = currentSize * 2
+ }
+ }
+ // Do not hold lock during spills
+ if (shouldSpill) {
+ spill(currentSize, usingMap)
+ }
+ }
+ }
+
+ /**
+ * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
+ *
+ * @param usingMap whether we're using a map or buffer as our current in-memory collection
+ */
+ private def spill(memorySize: Long, usingMap: Boolean): Unit = {
+ val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+ val memorySize = collection.estimateSize()
+
+ spillCount += 1
+ val threadId = Thread.currentThread().getId
+ logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
+ .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+ val (blockId, file) = diskBlockManager.createTempBlock()
+ var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+ var objectsWritten = 0 // Objects written since the last flush
+
+ // List of batch sizes (bytes) in the order they are written to disk
+ val batchSizes = new ArrayBuffer[Long]
+
+ // How many elements we have in each partition
+ val elementsPerPartition = new Array[Long](numPartitions)
+
+ // Flush the disk writer's contents to disk, and update relevant variables
+ def flush() = {
+ writer.commit()
+ val bytesWritten = writer.bytesWritten
+ batchSizes.append(bytesWritten)
+ _diskBytesSpilled += bytesWritten
+ objectsWritten = 0
+ }
+
+ try {
+ val it = collection.destructiveSortedIterator(partitionKeyComparator)
+ while (it.hasNext) {
+ val elem = it.next()
+ val partitionId = elem._1._1
+ val key = elem._1._2
+ val value = elem._2
+ writer.write(key)
+ writer.write(value)
+ elementsPerPartition(partitionId) += 1
+ objectsWritten += 1
+
+ if (objectsWritten == serializerBatchSize) {
+ flush()
+ writer.close()
+ writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+ }
+ }
+ if (objectsWritten > 0) {
+ flush()
+ }
+ writer.close()
+ } catch {
+ case e: Exception =>
+ writer.close()
+ file.delete()
+ throw e
+ }
+
+ if (usingMap) {
+ map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+ } else {
+ buffer = new SizeTrackingPairBuffer[(Int, K), C]
+ }
+
+ // Reset the amount of shuffle memory used by this map in the global pool
+ val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+ shuffleMemoryMap.synchronized {
+ shuffleMemoryMap(Thread.currentThread().getId) = 0
+ }
+ myMemoryThreshold = 0
+
+ spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+ _memoryBytesSpilled += memorySize
+ }
+
+ /**
+ * Merge a sequence of sorted files, giving an iterator over partitions and then over elements
+ * inside each partition. This can be used to either write out a new file or return data to
+ * the user.
+ *
+ * Returns an iterator over all the data written to this object, grouped by partition. For each
+ * partition we then have an iterator over its contents, and these are expected to be accessed
+ * in order (you can't "skip ahead" to one partition without reading the previous one).
+ * Guaranteed to return a key-value pair for each partition, in order of partition ID.
+ */
+ private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
+ : Iterator[(Int, Iterator[Product2[K, C]])] = {
+ val readers = spills.map(new SpillReader(_))
+ val inMemBuffered = inMemory.buffered
+ (0 until numPartitions).iterator.map { p =>
+ val inMemIterator = new IteratorForPartition(p, inMemBuffered)
+ val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
+ if (aggregator.isDefined) {
+ // Perform partial aggregation across partitions
+ (p, mergeWithAggregation(
+ iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
+ } else if (ordering.isDefined) {
+ // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
+ // sort the elements without trying to merge them
+ (p, mergeSort(iterators, ordering.get))
+ } else {
+ (p, iterators.iterator.flatten)
+ }
+ }
+ }
+
+ /**
+ * Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys.
+ */
+ private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
+ : Iterator[Product2[K, C]] =
+ {
+ val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
+ type Iter = BufferedIterator[Product2[K, C]]
+ val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
+ // Use the reverse of comparator.compare because PriorityQueue dequeues the max
+ override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+ })
+ heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true
+ new Iterator[Product2[K, C]] {
+ override def hasNext: Boolean = !heap.isEmpty
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val firstBuf = heap.dequeue()
+ val firstPair = firstBuf.next()
+ if (firstBuf.hasNext) {
+ heap.enqueue(firstBuf)
+ }
+ firstPair
+ }
+ }
+ }
+
+ /**
+ * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
+ * iterator is sorted by key with a given comparator. If the comparator is not a total ordering
+ * (e.g. when we sort objects by hash code and different keys may compare as equal although
+ * they're not), we still merge them by doing equality tests for all keys that compare as equal.
+ */
+ private def mergeWithAggregation(
+ iterators: Seq[Iterator[Product2[K, C]]],
+ mergeCombiners: (C, C) => C,
+ comparator: Comparator[K],
+ totalOrder: Boolean)
+ : Iterator[Product2[K, C]] =
+ {
+ if (!totalOrder) {
+ // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
+ // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
+ // need to read all keys considered equal by the ordering at once and compare them.
+ new Iterator[Iterator[Product2[K, C]]] {
+ val sorted = mergeSort(iterators, comparator).buffered
+
+ // Buffers reused across elements to decrease memory allocation
+ val keys = new ArrayBuffer[K]
+ val combiners = new ArrayBuffer[C]
+
+ override def hasNext: Boolean = sorted.hasNext
+
+ override def next(): Iterator[Product2[K, C]] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ keys.clear()
+ combiners.clear()
+ val firstPair = sorted.next()
+ keys += firstPair._1
+ combiners += firstPair._2
+ val key = firstPair._1
+ while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
+ val pair = sorted.next()
+ var i = 0
+ var foundKey = false
+ while (i < keys.size && !foundKey) {
+ if (keys(i) == pair._1) {
+ combiners(i) = mergeCombiners(combiners(i), pair._2)
+ foundKey = true
+ }
+ i += 1
+ }
+ if (!foundKey) {
+ keys += pair._1
+ combiners += pair._2
+ }
+ }
+
+ // Note that we return an iterator of elements since we could've had many keys marked
+ // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
+ keys.iterator.zip(combiners.iterator)
+ }
+ }.flatMap(i => i)
+ } else {
+ // We have a total ordering, so the objects with the same key are sequential.
+ new Iterator[Product2[K, C]] {
+ val sorted = mergeSort(iterators, comparator).buffered
+
+ override def hasNext: Boolean = sorted.hasNext
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val elem = sorted.next()
+ val k = elem._1
+ var c = elem._2
+          while (sorted.hasNext && sorted.head._1 == k) {
+            // Consume the next equal-key element so the loop advances
+            val pair = sorted.next()
+            c = mergeCombiners(c, pair._2)
+          }
+ (k, c)
+ }
+ }
+ }
+ }
+
+ /**
+ * An internal class for reading a spilled file partition by partition. Expects all the
+ * partitions to be requested in order.
+ */
+ private[this] class SpillReader(spill: SpilledFile) {
+ val fileStream = new FileInputStream(spill.file)
+ val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+
+ // Track which partition and which batch stream we're in. These will be the indices of
+ // the next element we will read. We'll also store the last partition read so that
+ // readNextPartition() can figure out what partition that was from.
+ var partitionId = 0
+ var indexInPartition = 0L
+ var batchStreamsRead = 0
+ var indexInBatch = 0
+ var lastPartitionId = 0
+
+ skipToNextPartition()
+
+ // An intermediate stream that reads from exactly one batch
+ // This guards against pre-fetching and other arbitrary behavior of higher level streams
+ var batchStream = nextBatchStream()
+ var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+ var deserStream = serInstance.deserializeStream(compressedStream)
+ var nextItem: (K, C) = null
+ var finished = false
+
+ /** Construct a stream that only reads from the next batch */
+ def nextBatchStream(): InputStream = {
+ if (batchStreamsRead < spill.serializerBatchSizes.length) {
+ batchStreamsRead += 1
+ ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
+ } else {
+ // No more batches left; give an empty stream
+ bufferedStream
+ }
+ }
+
+ /**
+ * Update partitionId if we have reached the end of our current partition, possibly skipping
+ * empty partitions on the way.
+ */
+ private def skipToNextPartition() {
+ while (partitionId < numPartitions &&
+ indexInPartition == spill.elementsPerPartition(partitionId)) {
+ partitionId += 1
+ indexInPartition = 0L
+ }
+ }
+
+ /**
+ * Return the next (K, C) pair from the deserialization stream and update partitionId,
+ * indexInPartition, indexInBatch and such to match its location.
+ *
+ * If the current batch is drained, construct a stream for the next batch and read from it.
+ * If no more pairs are left, return null.
+ */
+ private def readNextItem(): (K, C) = {
+ if (finished) {
+ return null
+ }
+ val k = deserStream.readObject().asInstanceOf[K]
+ val c = deserStream.readObject().asInstanceOf[C]
+ lastPartitionId = partitionId
+ // Start reading the next batch if we're done with this one
+ indexInBatch += 1
+ if (indexInBatch == serializerBatchSize) {
+ batchStream = nextBatchStream()
+ compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+ deserStream = serInstance.deserializeStream(compressedStream)
+ indexInBatch = 0
+ }
+ // Update the partition location of the element we're reading
+ indexInPartition += 1
+ skipToNextPartition()
+ // If we've finished reading the last partition, remember that we're done
+ if (partitionId == numPartitions) {
+ finished = true
+ deserStream.close()
+ }
+ (k, c)
+ }
+
+ var nextPartitionToRead = 0
+
+ def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
+ val myPartition = nextPartitionToRead
+ nextPartitionToRead += 1
+
+ override def hasNext: Boolean = {
+ if (nextItem == null) {
+ nextItem = readNextItem()
+ if (nextItem == null) {
+ return false
+ }
+ }
+ assert(lastPartitionId >= myPartition)
+ // Check that we're still in the right partition; note that readNextItem will have returned
+ // null at EOF above so we would've returned false there
+ lastPartitionId == myPartition
+ }
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val item = nextItem
+ nextItem = null
+ item
+ }
+ }
+ }
+
+ /**
+ * Return an iterator over all the data written to this object, grouped by partition and
+ * aggregated by the requested aggregator. For each partition we then have an iterator over its
+ * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
+   * partition without reading the previous one). Guaranteed to return a (partition ID, iterator)
+   * pair for each partition, in order of partition ID.
+ *
+   * For now, we just merge all the spilled files in one pass, but this can be modified to
+ * support hierarchical merging.
+ */
+ def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+ val usingMap = aggregator.isDefined
+ val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+ if (spills.isEmpty) {
+ // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
+ // we don't even need to sort by anything other than partition ID
+ if (!ordering.isDefined) {
+        // The user hasn't requested sorted keys, so only sort by partition ID, not key
+ val partitionComparator = new Comparator[(Int, K)] {
+ override def compare(a: (Int, K), b: (Int, K)): Int = {
+ a._1 - b._1
+ }
+ }
+ groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+ } else {
+ // We do need to sort by both partition ID and key
+ groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
+ }
+ } else {
+ // General case: merge spilled and in-memory data
+ merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
+ }
+ }
+
+ /**
+ * Return an iterator over all the data written to this object, aggregated by our aggregator.
+ */
+ def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
+
+ def stop(): Unit = {
+ spills.foreach(s => s.file.delete())
+ spills.clear()
+ }
+
+ def memoryBytesSpilled: Long = _memoryBytesSpilled
+
+ def diskBytesSpilled: Long = _diskBytesSpilled
+
+ /**
+ * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
+ * group together the pairs for each partition into a sub-iterator.
+ *
+ * @param data an iterator of elements, assumed to already be sorted by partition ID
+ */
+ private def groupByPartition(data: Iterator[((Int, K), C)])
+ : Iterator[(Int, Iterator[Product2[K, C]])] =
+ {
+ val buffered = data.buffered
+ (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
+ }
+
+ /**
+ * An iterator that reads only the elements for a given partition ID from an underlying buffered
+ * stream, assuming this partition is the next one to be read. Used to make it easier to return
+ * partitioned iterators from our in-memory collection.
+ */
+ private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)])
+ extends Iterator[Product2[K, C]]
+ {
+ override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId
+
+ override def next(): Product2[K, C] = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val elem = data.next()
+ (elem._1._2, elem._2)
+ }
+ }
+}
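
The ExternalSorter added above is exercised end to end by ExternalSorterSuite later in this diff. As a quick orientation, here is a minimal usage sketch (not part of the patch) modeled on those tests; it assumes the caller lives in the org.apache.spark.util.collection package and that a SparkContext is running so that SparkEnv's block manager and shuffle memory pool are available.

    package org.apache.spark.util.collection

    import org.apache.spark.{Aggregator, HashPartitioner, SparkConf, SparkContext}

    // Minimal sketch: combine values by summing, hash-partition into 3 partitions,
    // and sort keys within each partition, mirroring ExternalSorterSuite.
    object ExternalSorterSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(false)
          .set("spark.shuffle.memoryFraction", "0.001") // tiny budget to force spilling
          .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
        val sc = new SparkContext("local", "sorter-sketch", conf)

        val agg = new Aggregator[Int, Int, Int](i => i, _ + _, _ + _)
        val sorter = new ExternalSorter[Int, Int, Int](
          Some(agg), Some(new HashPartitioner(3)), Some(implicitly[Ordering[Int]]), None)

        sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))

        // Partitions must be consumed in order; each inner iterator covers one partition.
        sorter.partitionedIterator.foreach { case (p, elems) =>
          println(s"partition $p has ${elems.size} aggregated keys")
        }
        sorter.stop() // delete intermediate spill files
        sc.stop()
      }
    }
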
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
index de61e1d17f..eb4de41386 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -20,8 +20,9 @@ package org.apache.spark.util.collection
/**
* An append-only map that keeps track of its estimated size in bytes.
*/
-private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] with SizeTracker {
-
+private[spark] class SizeTrackingAppendOnlyMap[K, V]
+ extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V]
+{
override def update(key: K, value: V): Unit = {
super.update(key, value)
super.afterUpdate()
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
new file mode 100644
index 0000000000..9e9c16c5a2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes.
+ */
+private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
+ extends SizeTracker with SizeTrackingPairCollection[K, V]
+{
+ require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+ require(initialCapacity >= 1, "Invalid initial capacity")
+
+ // Basic growable array data structure. We use a single array of AnyRef to hold both the keys
+ // and the values, so that we can sort them efficiently with KVArraySortDataFormat.
+ private var capacity = initialCapacity
+ private var curSize = 0
+ private var data = new Array[AnyRef](2 * initialCapacity)
+
+ /** Add an element into the buffer */
+ def insert(key: K, value: V): Unit = {
+ if (curSize == capacity) {
+ growArray()
+ }
+ data(2 * curSize) = key.asInstanceOf[AnyRef]
+ data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
+ curSize += 1
+ afterUpdate()
+ }
+
+ /** Total number of elements in buffer */
+ override def size: Int = curSize
+
+ /** Iterate over the elements of the buffer */
+ override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
+ var pos = 0
+
+ override def hasNext: Boolean = pos < curSize
+
+ override def next(): (K, V) = {
+ if (!hasNext) {
+ throw new NoSuchElementException
+ }
+ val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+ pos += 1
+ pair
+ }
+ }
+
+ /** Double the size of the array because we've reached capacity */
+ private def growArray(): Unit = {
+ if (capacity == (1 << 29)) {
+ // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
+ throw new Exception("Can't grow buffer beyond 2^29 elements")
+ }
+ val newCapacity = capacity * 2
+ val newArray = new Array[AnyRef](2 * newCapacity)
+ System.arraycopy(data, 0, newArray, 0, 2 * capacity)
+ data = newArray
+ capacity = newCapacity
+ resetSamples()
+ }
+
+  /** Iterate through the data in a given key order. The sort happens in place, so for this class it is not actually destructive. */
+ override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
+ new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator)
+ iterator
+ }
+}
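
A small standalone sketch (not part of the patch) of how this buffer is used: insert pairs, read the size estimate that SizeTracker maintains, then drain the contents in key order. It assumes the code runs in the same package, since the class is private[spark].

    package org.apache.spark.util.collection

    import java.util.Comparator

    object PairBufferSketch {
      def main(args: Array[String]): Unit = {
        val buf = new SizeTrackingPairBuffer[Int, String]()
        (1 to 1000).reverse.foreach(i => buf.insert(i, s"value-$i"))

        // estimateSize() comes from SizeTracker and is based on periodic sampling
        println(s"estimated size in bytes: ${buf.estimateSize()}")

        val byKey = new Comparator[Int] {
          override def compare(a: Int, b: Int): Int = java.lang.Integer.compare(a, b)
        }
        // Sorts the backing array in place; for this class the buffer stays usable afterwards
        val sorted = buf.destructiveSortedIterator(byKey)
        println(sorted.take(3).toList) // the three smallest keys
      }
    }
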
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
new file mode 100644
index 0000000000..faa4e2b12d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * A common interface for our size-tracking collections of key-value pairs, which are used in
+ * external operations. These all support estimating the size and obtaining a memory-efficient
+ * sorted iterator.
+ */
+// TODO: should extend Iterable[Product2[K, V]] instead of (K, V)
+private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] {
+ /** Estimate the collection's current memory usage in bytes. */
+ def estimateSize(): Long
+
+ /** Iterate through the data in a given key order. This may destroy the underlying collection. */
+ def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)]
+}
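
For illustration only, a toy (hypothetical) implementation of this trait; the real implementations in this patch are SizeTrackingAppendOnlyMap and SizeTrackingPairBuffer, and the size estimate below is a deliberate stand-in for SizeTracker's sampling.

    package org.apache.spark.util.collection

    import java.util.Comparator
    import scala.collection.mutable.ArrayBuffer

    // Toy implementation showing the minimal contract: a size estimate plus a sorted iterator.
    private[spark] class TinyPairCollection[K, V] extends SizeTrackingPairCollection[K, V] {
      private val pairs = new ArrayBuffer[(K, V)]

      def insert(key: K, value: V): Unit = pairs += ((key, value))

      override def iterator: Iterator[(K, V)] = pairs.iterator

      // Crude per-entry guess, purely for illustration; the real collections sample actual sizes
      override def estimateSize(): Long = pairs.size * 64L

      override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
        val ord = Ordering.comparatorToOrdering(keyComparator).on[(K, V)](_._1)
        pairs.sorted(ord).iterator
      }
    }
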
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d1cb2d9d3a..a41914a1a9 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("ShuffledRDD") {
testRDD(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
- new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
+ new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
})
}
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index ad20f9b937..4bc4346c0a 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -19,9 +19,6 @@ package org.apache.spark
import java.lang.ref.WeakReference
-import org.apache.spark.broadcast.Broadcast
-
-import scala.collection.mutable
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.language.existentials
import scala.language.postfixOps
@@ -34,15 +31,28 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
-
-class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
-
+import org.apache.spark.storage._
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.storage.BroadcastBlockId
+import org.apache.spark.storage.RDDBlockId
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.ShuffleIndexBlockId
+
+/**
+ * An abstract base class for context cleaner tests, which sets up a context with a config
+ * suitable for cleaner tests and provides some utility functions. Subclasses can use different
+ * config options, in particular a different shuffle manager class.
+ */
+abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager])
+ extends FunSuite with BeforeAndAfter with LocalSparkContext
+{
implicit val defaultTimeout = timeout(10000 millis)
val conf = new SparkConf()
.setMaster("local[2]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.shuffle.manager", shuffleManager.getName)
before {
sc = new SparkContext(conf)
@@ -55,6 +65,59 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
}
+ //------ Helper functions ------
+
+ protected def newRDD() = sc.makeRDD(1 to 10)
+ protected def newPairRDD() = newRDD().map(_ -> 1)
+ protected def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
+ protected def newBroadcast() = sc.broadcast(1 to 100)
+
+ protected def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+ def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
+ rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
+ getAllDependencies(dep.rdd)
+ }
+ }
+ val rdd = newShuffleRDD()
+
+ // Get all the shuffle dependencies
+ val shuffleDeps = getAllDependencies(rdd)
+ .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
+ (rdd, shuffleDeps)
+ }
+
+ protected def randomRdd() = {
+ val rdd: RDD[_] = Random.nextInt(3) match {
+ case 0 => newRDD()
+ case 1 => newShuffleRDD()
+ case 2 => newPairRDD.join(newPairRDD())
+ }
+ if (Random.nextBoolean()) rdd.persist()
+ rdd.count()
+ rdd
+ }
+
+ /** Run GC and make sure it actually has run */
+ protected def runGC() {
+ val weakRef = new WeakReference(new Object())
+ val startTime = System.currentTimeMillis
+ System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+ // Wait until a weak reference object has been GCed
+ while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+ System.gc()
+ Thread.sleep(200)
+ }
+ }
+
+ protected def cleaner = sc.cleaner.get
+}
+
+
+/**
+ * Basic ContextCleanerSuite, which uses the default hash-based shuffle
+ */
+class ContextCleanerSuite extends ContextCleanerSuiteBase {
test("cleanup RDD") {
val rdd = newRDD().persist()
val collected = rdd.collect().toList
@@ -147,7 +210,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = broadcastBuffer.map(_.id)
@@ -180,12 +243,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
.setMaster("local-cluster[2, 1, 512]")
.setAppName("ContextCleanerSuite")
.set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.shuffle.manager", shuffleManager.getName)
sc = new SparkContext(conf2)
val numRdds = 10
val numBroadcasts = 4 // Broadcasts are more costly
val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = broadcastBuffer.map(_.id)
@@ -210,57 +274,82 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
case _ => false
}, askSlaves = true).isEmpty)
}
+}
- //------ Helper functions ------
- private def newRDD() = sc.makeRDD(1 to 10)
- private def newPairRDD() = newRDD().map(_ -> 1)
- private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
- private def newBroadcast() = sc.broadcast(1 to 100)
+/**
+ * A copy of the shuffle tests for sort-based shuffle
+ */
+class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[SortShuffleManager]) {
+ test("cleanup shuffle") {
+ val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
+ val collected = rdd.collect().toList
+ val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
- private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
- def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
- rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
- getAllDependencies(dep.rdd)
- }
- }
- val rdd = newShuffleRDD()
+ // Explicit cleanup
+ shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
+ tester.assertCleanup()
- // Get all the shuffle dependencies
- val shuffleDeps = getAllDependencies(rdd)
- .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
- .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
- (rdd, shuffleDeps)
+ // Verify that shuffles can be re-executed after cleaning up
+ assert(rdd.collect().toList.equals(collected))
}
- private def randomRdd() = {
- val rdd: RDD[_] = Random.nextInt(3) match {
- case 0 => newRDD()
- case 1 => newShuffleRDD()
- case 2 => newPairRDD.join(newPairRDD())
- }
- if (Random.nextBoolean()) rdd.persist()
+ test("automatically cleanup shuffle") {
+ var rdd = newShuffleRDD()
rdd.count()
- rdd
- }
- private def randomBroadcast() = {
- sc.broadcast(Random.nextInt(Int.MaxValue))
+ // Test that GC does not cause shuffle cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes shuffle cleanup after dereferencing the RDD
+ val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope
+ runGC()
+ postGCTester.assertCleanup()
}
- /** Run GC and make sure it actually has run */
- private def runGC() {
- val weakRef = new WeakReference(new Object())
- val startTime = System.currentTimeMillis
- System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
- // Wait until a weak reference object has been GCed
- while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
- System.gc()
- Thread.sleep(200)
+ test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
+ sc.stop()
+
+ val conf2 = new SparkConf()
+ .setMaster("local-cluster[2, 1, 512]")
+ .setAppName("ContextCleanerSuite")
+ .set("spark.cleaner.referenceTracking.blocking", "true")
+ .set("spark.shuffle.manager", shuffleManager.getName)
+ sc = new SparkContext(conf2)
+
+ val numRdds = 10
+ val numBroadcasts = 4 // Broadcasts are more costly
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast).toBuffer
+ val rddIds = sc.persistentRdds.keys.toSeq
+ val shuffleIds = 0 until sc.newShuffleId()
+ val broadcastIds = broadcastBuffer.map(_.id)
+
+ val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
}
- }
- private def cleaner = sc.cleaner.get
+    // Test that GC triggers the cleanup of all variables after dereferencing them
+ val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ broadcastBuffer.clear()
+ rddBuffer.clear()
+ runGC()
+ postGCTester.assertCleanup()
+
+ // Make sure the broadcasted task closure no longer exists after GC.
+ val taskClosureBroadcastId = broadcastIds.max + 1
+ assert(sc.env.blockManager.master.getMatchingBlockIds({
+ case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+ case _ => false
+ }, askSlaves = true).isEmpty)
+ }
}
@@ -418,6 +507,7 @@ class CleanerTester(
private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
blockManager.master.getMatchingBlockIds( _ match {
case ShuffleBlockId(`shuffleId`, _, _) => true
+ case ShuffleIndexBlockId(`shuffleId`, _, _) => true
case _ => false
}, askSlaves = true)
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index 47df00050c..d7b2d2e1e3 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -28,6 +28,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
}
override def afterAll() {
- System.setProperty("spark.shuffle.use.netty", "false")
+ System.clearProperty("spark.shuffle.use.netty")
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index eae67c7747..b13ddf96bc 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -58,8 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int,
NonJavaSerializableClass,
- NonJavaSerializableClass,
- (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS))
+ NonJavaSerializableClass](b, new HashPartitioner(NUM_BLOCKS))
c.setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -83,8 +82,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int,
NonJavaSerializableClass,
- NonJavaSerializableClass,
- (Int, NonJavaSerializableClass)](b, new HashPartitioner(3))
+ NonJavaSerializableClass](b, new HashPartitioner(3))
c.setSerializer(new KryoSerializer(conf))
assert(c.count === 10)
}
@@ -100,7 +98,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// NOTE: The default Java serializer doesn't create zero-sized blocks.
// So, use Kryo
- val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -126,7 +124,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val b = a.map(x => (x, x*2))
// NOTE: The default Java serializer should create zero-sized blocks
- val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
+ val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
@@ -141,19 +139,19 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
- test("shuffle using mutable pairs") {
+ test("shuffle on mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test")
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
- val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs,
+ val results = new ShuffledRDD[Int, Int, Int](pairs,
new HashPartitioner(2)).collect()
- data.foreach { pair => results should contain (pair) }
+ data.foreach { pair => results should contain ((pair._1, pair._2)) }
}
- test("sorting using mutable pairs") {
+ test("sorting on mutable pairs") {
// This is not in SortingSuite because of the local cluster setup.
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test")
@@ -162,10 +160,10 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs)
.sortByKey().collect()
- results(0) should be (p(1, 11))
- results(1) should be (p(2, 22))
- results(2) should be (p(3, 33))
- results(3) should be (p(100, 100))
+ results(0) should be ((1, 11))
+ results(1) should be ((2, 22))
+ results(2) should be ((3, 33))
+ results(3) should be ((100, 100))
}
test("cogroup using mutable pairs") {
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
new file mode 100644
index 0000000000..5c02c00586
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.BeforeAndAfterAll
+
+class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+ // This test suite should run all tests in ShuffleSuite with sort-based shuffle.
+
+ override def beforeAll() {
+ System.setProperty("spark.shuffle.manager",
+ "org.apache.spark.shuffle.sort.SortShuffleManager")
+ }
+
+ override def afterAll() {
+ System.clearProperty("spark.shuffle.manager")
+ }
+}
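
The suite above flips the shuffle implementation through a system property; in application code the same switch is a SparkConf setting. A minimal sketch (not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._

    object SortShuffleExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[2]")
          .setAppName("sort-shuffle-example")
          .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
        val sc = new SparkContext(conf)

        // Any shuffle now goes through SortShuffleManager, which writes one sorted data file
        // plus one index file per map task instead of one file per reducer.
        val counts = sc.parallelize(1 to 100000).map(i => (i % 100, 1)).reduceByKey(_ + _).collect()
        assert(counts.length == 100)
        sc.stop()
      }
    }
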
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 4953d565ae..8966eedd80 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -270,7 +270,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// we can optionally shuffle to keep the upstream parallel
val coalesced5 = data.coalesce(1, shuffle = true)
val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd.
- asInstanceOf[ShuffledRDD[_, _, _, _]] != null
+ asInstanceOf[ShuffledRDD[_, _, _]] != null
assert(isEquals)
// when shuffling, we can increase the number of partitions
@@ -730,9 +730,9 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// Any ancestors before the shuffle are not considered
assert(ancestors4.size === 0)
- assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0)
+ assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 0)
assert(ancestors5.size === 3)
- assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1)
+ assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0)
assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 0b7ad184a4..7de5df6e1c 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -208,11 +208,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val resultA = rddA.reduceByKey(math.max).collect()
assert(resultA.length == 50000)
resultA.foreach { case(k, v) =>
- k match {
- case 0 => assert(v == 1)
- case 25000 => assert(v == 50001)
- case 49999 => assert(v == 99999)
- case _ =>
+ if (v != k * 2 + 1) {
+ fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
}
}
@@ -221,11 +218,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
val resultB = rddB.groupByKey().collect()
assert(resultB.length == 25000)
resultB.foreach { case(i, seq) =>
- i match {
- case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
- case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
- case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
- case _ =>
+ val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+ if (seq.toSet != expected) {
+ fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
}
}
@@ -239,6 +234,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
case 0 =>
assert(seq1.toSet == Set[Int](0))
assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ case 1 =>
+ assert(seq1.toSet == Set[Int](1))
+ assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
case 5000 =>
assert(seq1.toSet == Set[Int](5000))
assert(seq2.toSet == Set[Int]())
@@ -369,10 +367,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
}
}
-
-/**
- * A dummy class that always returns the same hash code, to easily test hash collisions
- */
-case class FixedHashObject(v: Int, h: Int) extends Serializable {
- override def hashCode(): Int = h
-}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
new file mode 100644
index 0000000000..ddb5df4036
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -0,0 +1,566 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+ test("empty data stream") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+
+ // Both aggregator and ordering
+ val sorter = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+ assert(sorter.iterator.toSeq === Seq())
+ sorter.stop()
+
+ // Only aggregator
+ val sorter2 = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(3)), None, None)
+ assert(sorter2.iterator.toSeq === Seq())
+ sorter2.stop()
+
+ // Only ordering
+ val sorter3 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), Some(ord), None)
+ assert(sorter3.iterator.toSeq === Seq())
+ sorter3.stop()
+
+ // Neither aggregator nor ordering
+ val sorter4 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), None, None)
+ assert(sorter4.iterator.toSeq === Seq())
+ sorter4.stop()
+ }
+
+ test("few elements per partition") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+ val elements = Set((1, 1), (2, 2), (5, 5))
+ val expected = Set(
+ (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()),
+ (5, Set((5, 5))), (6, Set()))
+
+ // Both aggregator and ordering
+ val sorter = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+ sorter.write(elements.iterator)
+ assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter.stop()
+
+ // Only aggregator
+ val sorter2 = new ExternalSorter[Int, Int, Int](
+ Some(agg), Some(new HashPartitioner(7)), None, None)
+ sorter2.write(elements.iterator)
+ assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter2.stop()
+
+ // Only ordering
+ val sorter3 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(7)), Some(ord), None)
+ sorter3.write(elements.iterator)
+ assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter3.stop()
+
+ // Neither aggregator nor ordering
+ val sorter4 = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(7)), None, None)
+ sorter4.write(elements.iterator)
+ assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+ sorter4.stop()
+ }
+
+ test("empty partitions with spilling") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+ val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+
+ val sorter = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(7)), None, None)
+ sorter.write(elements)
+ assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
+ val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+ assert(iter.next() === (0, Nil))
+ assert(iter.next() === (1, List((1, 1))))
+ assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
+ assert(iter.next() === (3, Nil))
+ assert(iter.next() === (4, Nil))
+ assert(iter.next() === (5, List((5, 5))))
+ assert(iter.next() === (6, Nil))
+ sorter.stop()
+ }
+
+ test("spilling in local cluster") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ // reduceByKey - should spill ~8 times
+ val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
+ val resultA = rddA.reduceByKey(math.max).collect()
+ assert(resultA.length == 50000)
+ resultA.foreach { case(k, v) =>
+ if (v != k * 2 + 1) {
+ fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
+ }
+ }
+
+ // groupByKey - should spill ~17 times
+ val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
+ val resultB = rddB.groupByKey().collect()
+ assert(resultB.length == 25000)
+ resultB.foreach { case(i, seq) =>
+ val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+ if (seq.toSet != expected) {
+ fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
+ }
+ }
+
+ // cogroup - should spill ~7 times
+ val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
+ val resultC = rddC1.cogroup(rddC2).collect()
+ assert(resultC.length == 10000)
+ resultC.foreach { case(i, (seq1, seq2)) =>
+ i match {
+ case 0 =>
+ assert(seq1.toSet == Set[Int](0))
+ assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ case 1 =>
+ assert(seq1.toSet == Set[Int](1))
+ assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
+ case 5000 =>
+ assert(seq1.toSet == Set[Int](5000))
+ assert(seq2.toSet == Set[Int]())
+ case 9999 =>
+ assert(seq1.toSet == Set[Int](9999))
+ assert(seq2.toSet == Set[Int]())
+ case _ =>
+ }
+ }
+
+ // larger cogroup - should spill ~7 times
+ val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val resultD = rddD1.cogroup(rddD2).collect()
+ assert(resultD.length == 5000)
+ resultD.foreach { case(i, (seq1, seq2)) =>
+ val expected = Set(i * 2, i * 2 + 1)
+ if (seq1.toSet != expected) {
+ fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+ }
+ if (seq2.toSet != expected) {
+ fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+ }
+ }
+ }
+
+ test("spilling in local cluster with many reduce tasks") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+
+ // reduceByKey - should spill ~4 times per executor
+ val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
+ val resultA = rddA.reduceByKey(math.max _, 100).collect()
+ assert(resultA.length == 50000)
+ resultA.foreach { case(k, v) =>
+ if (v != k * 2 + 1) {
+ fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
+ }
+ }
+
+ // groupByKey - should spill ~8 times per executor
+ val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
+ val resultB = rddB.groupByKey(100).collect()
+ assert(resultB.length == 25000)
+ resultB.foreach { case(i, seq) =>
+ val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+ if (seq.toSet != expected) {
+ fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
+ }
+ }
+
+ // cogroup - should spill ~4 times per executor
+ val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+ val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
+ val resultC = rddC1.cogroup(rddC2, 100).collect()
+ assert(resultC.length == 10000)
+ resultC.foreach { case(i, (seq1, seq2)) =>
+ i match {
+ case 0 =>
+ assert(seq1.toSet == Set[Int](0))
+ assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+ case 1 =>
+ assert(seq1.toSet == Set[Int](1))
+ assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
+ case 5000 =>
+ assert(seq1.toSet == Set[Int](5000))
+ assert(seq2.toSet == Set[Int]())
+ case 9999 =>
+ assert(seq1.toSet == Set[Int](9999))
+ assert(seq2.toSet == Set[Int]())
+ case _ =>
+ }
+ }
+
+ // larger cogroup - should spill ~4 times per executor
+ val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+ val resultD = rddD1.cogroup(rddD2).collect()
+ assert(resultD.length == 5000)
+ resultD.foreach { case(i, (seq1, seq2)) =>
+ val expected = Set(i * 2, i * 2 + 1)
+ if (seq1.toSet != expected) {
+ fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+ }
+ if (seq2.toSet != expected) {
+ fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+ }
+ }
+ }
+
+ test("cleanup of intermediate files in sorter") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100000).iterator.map(i => (i, i)))
+ assert(diskBlockManager.getAllFiles().length > 0)
+ sorter.stop()
+ assert(diskBlockManager.getAllBlocks().length === 0)
+
+ val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ sorter2.write((0 until 100000).iterator.map(i => (i, i)))
+ assert(diskBlockManager.getAllFiles().length > 0)
+ assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
+ sorter2.stop()
+ assert(diskBlockManager.getAllBlocks().length === 0)
+ }
+
+ test("cleanup of intermediate files in sorter if there are errors") {
+ val conf = new SparkConf(true) // Load defaults, otherwise SPARK_HOME is not found
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ intercept[SparkException] {
+ sorter.write((0 until 100000).iterator.map(i => {
+ if (i == 99990) {
+ throw new SparkException("Intentional failure")
+ }
+ (i, i)
+ }))
+ }
+ assert(diskBlockManager.getAllFiles().length > 0)
+ sorter.stop()
+ assert(diskBlockManager.getAllBlocks().length === 0)
+ }
+
+ test("cleanup of intermediate files in shuffle") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val data = sc.parallelize(0 until 100000, 2).map(i => (i, i))
+ assert(data.reduceByKey(_ + _).count() === 100000)
+
+ // After the shuffle, there should be only 4 files on disk: our two map output files and
+ // their index files. All other intermediate files should've been deleted.
+ assert(diskBlockManager.getAllFiles().length === 4)
+ }
+
+ test("cleanup of intermediate files in shuffle with errors") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+ val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+ val data = sc.parallelize(0 until 100000, 2).map(i => {
+ if (i == 99990) {
+ throw new Exception("Intentional failure")
+ }
+ (i, i)
+ })
+ intercept[SparkException] {
+ data.reduceByKey(_ + _).count()
+ }
+
+ // After the shuffle, there should be only 2 files on disk: the output of task 1 and its index.
+ // All other files (map 2's output and intermediate merge files) should've been deleted.
+ assert(diskBlockManager.getAllFiles().length === 2)
+ }
+
+ test("no partial aggregation or sorting") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100000).iterator.map(i => (i / 4, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("partial aggregation without spill") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100).iterator.map(i => (i / 2, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("partial aggregation with spill, no ordering") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+ sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("partial aggregation with spill, with ordering") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+ val ord = implicitly[Ordering[Int]]
+ val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+ sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+ }).toSet
+ assert(results === expected)
+ }
+
+ test("sorting without aggregation, no spill") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val ord = implicitly[Ordering[Int]]
+ val sorter = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), Some(ord), None)
+ sorter.write((0 until 100).iterator.map(i => (i, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+ }).toSeq
+ assert(results === expected)
+ }
+
+ test("sorting without aggregation, with spill") {
+ val conf = new SparkConf(false)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+ sc = new SparkContext("local", "test", conf)
+
+ val ord = implicitly[Ordering[Int]]
+ val sorter = new ExternalSorter[Int, Int, Int](
+ None, Some(new HashPartitioner(3)), Some(ord), None)
+ sorter.write((0 until 100000).iterator.map(i => (i, i)))
+ val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+ val expected = (0 until 3).map(p => {
+ (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+ }).toSeq
+ assert(results === expected)
+ }
+
+ test("spilling with hash collisions") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ def createCombiner(i: String) = ArrayBuffer[String](i)
+ def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+ def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) =
+ buffer1 ++= buffer2
+
+ val agg = new Aggregator[String, String, ArrayBuffer[String]](
+ createCombiner _, mergeValue _, mergeCombiners _)
+
+ val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+ Some(agg), None, None, None)
+
+ val collisionPairs = Seq(
+ ("Aa", "BB"), // 2112
+ ("to", "v1"), // 3707
+ ("variants", "gelato"), // -1249574770
+ ("Teheran", "Siblings"), // 231609873
+ ("misused", "horsemints"), // 1069518484
+ ("isohel", "epistolaries"), // -1179291542
+ ("righto", "buzzards"), // -931102253
+ ("hierarch", "crinolines"), // -1732884796
+ ("inwork", "hypercatalexes"), // -1183663690
+ ("wainages", "presentencing"), // 240183619
+ ("trichothecenes", "locular"), // 339006536
+ ("pomatoes", "eructation") // 568647356
+ )
+
+ collisionPairs.foreach { case (w1, w2) =>
+ // String.hashCode is documented to use a specific algorithm, but check just in case
+ assert(w1.hashCode === w2.hashCode)
+ }
+
+ val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
+ collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
+
+ sorter.write(toInsert)
+
+ // A map of collision pairs in both directions
+ val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
+
+ // Avoid map.size or map.iterator.length because this destructively sorts the underlying map
+ var count = 0
+
+ val it = sorter.iterator
+ while (it.hasNext) {
+ val kv = it.next()
+ val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1))
+ assert(kv._2.equals(expectedValue))
+ count += 1
+ }
+ assert(count === 100000 + collisionPairs.size * 2)
+ }
+
+ test("spilling with many hash collisions") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.0001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
+ val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
+
+ // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
+ // problems if the map fails to group together the objects with the same code (SPARK-2043).
+ val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
+ sorter.write(toInsert.iterator)
+
+ val it = sorter.iterator
+ var count = 0
+ while (it.hasNext) {
+ val kv = it.next()
+ assert(kv._2 === 10)
+ count += 1
+ }
+ assert(count === 10000)
+ }
+
+ test("spilling with hash collisions using the Int.MaxValue key") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ def createCombiner(i: Int) = ArrayBuffer[Int](i)
+ def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i
+ def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2
+
+ val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
+ val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
+
+ sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
+
+ val it = sorter.iterator
+ while (it.hasNext) {
+ // Should not throw NoSuchElementException
+ it.next()
+ }
+ }
+
+ test("spilling with null keys and values") {
+ val conf = new SparkConf(true)
+ conf.set("spark.shuffle.memoryFraction", "0.001")
+ sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+ def createCombiner(i: String) = ArrayBuffer[String](i)
+ def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+ def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2
+
+ val agg = new Aggregator[String, String, ArrayBuffer[String]](
+ createCombiner, mergeValue, mergeCombiners)
+
+ val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+ Some(agg), None, None, None)
+
+ sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
+ (null.asInstanceOf[String], "1"),
+ ("1", null.asInstanceOf[String]),
+ (null.asInstanceOf[String], null.asInstanceOf[String])
+ ))
+
+ val it = sorter.iterator
+ while (it.hasNext) {
+ // Should not throw NullPointerException
+ it.next()
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
new file mode 100644
index 0000000000..c787b5f066
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/**
+ * A dummy class whose hash code is fixed to a supplied value, making it easy to construct hash collisions in tests
+ */
+case class FixedHashObject(v: Int, h: Int) extends Serializable {
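+ // Case-class equality still compares both v and h, so two instances with the same h but
+ // different v have equal hash codes without being equal.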
+ override def hashCode(): Int = h
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
index 5318b8da64..714f3b81c9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.{ShuffledRDD, RDD}
private[graphx]
class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
- val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner)
+ val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner)
// Set a custom serializer if the data is of int or double type.
if (classTag[VD] == ClassTag.Int) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index a565d3b28b..b27485953f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -33,7 +33,7 @@ private[graphx]
class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
/** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
- new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
+ new ShuffledRDD[VertexId, Int, Int](
self, partitioner).setSerializer(new RoutingTableMessageSerializer)
}
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 672343fbbe..a8bbd55861 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -295,6 +295,7 @@ object Unidoc {
.map(_.filterNot(_.getCanonicalPath.contains("akka")))
.map(_.filterNot(_.getCanonicalPath.contains("deploy")))
.map(_.filterNot(_.getCanonicalPath.contains("network")))
+ .map(_.filterNot(_.getCanonicalPath.contains("shuffle")))
.map(_.filterNot(_.getCanonicalPath.contains("executor")))
.map(_.filterNot(_.getCanonicalPath.contains("python")))
.map(_.filterNot(_.getCanonicalPath.contains("collection")))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 392a7f3be3..30712f03ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -49,7 +49,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
iter.map(r => mutablePair.update(hashExpressions(r), r))
}
val part = new HashPartitioner(numPartitions)
- val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part)
+ val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)
@@ -62,7 +62,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
iter.map(row => mutablePair.update(row, null))
}
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
- val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part)
+ val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._1)
@@ -73,7 +73,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
iter.map(r => mutablePair.update(null, r))
}
val partitioner = new HashPartitioner(1)
- val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner)
+ val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.map(_._2)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 174eda8f1a..0027f3cf1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -148,7 +148,7 @@ case class Limit(limit: Int, child: SparkPlan)
iter.take(limit).map(row => mutablePair.update(false, row))
}
val part = new HashPartitioner(1)
- val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part)
+ val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
shuffled.mapPartitions(_.take(limit).map(_._2))
}