diff options
author | Matei Zaharia <matei@databricks.com> | 2014-06-11 20:45:29 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-06-11 20:45:29 -0700 |
commit | 508fd371d6dbb826fd8a00787d347235b549e189 (patch) | |
tree | 2cc77c43ab7b9f30a9b1afc6f55d9438779e46ab /core | |
parent | d9203350b06a9c737421ba244162b1365402c01b (diff) | |
download | spark-508fd371d6dbb826fd8a00787d347235b549e189.tar.gz spark-508fd371d6dbb826fd8a00787d347235b549e189.tar.bz2 spark-508fd371d6dbb826fd8a00787d347235b549e189.zip |
[SPARK-2044] Pluggable interface for shuffles
This is a first cut at moving shuffle logic behind a pluggable interface, as described at https://issues.apache.org/jira/browse/SPARK-2044, to let us more easily experiment with new shuffle implementations. It moves the existing shuffle code to a class HashShuffleManager behind a general ShuffleManager interface.
Two things are still missing to make this complete:
* MapOutputTracker needs to be hidden behind the ShuffleManager interface; this will also require adding methods to ShuffleManager that will let the DAGScheduler interact with it as it does with the MapOutputTracker today
* The code to do map-sides and reduce-side combine in ShuffledRDD, PairRDDFunctions, etc needs to be moved into the ShuffleManager's readers and writers
However, some of these may also be done later after we merge the current interface.
Author: Matei Zaharia <matei@databricks.com>
Closes #1009 from mateiz/pluggable-shuffle and squashes the following commits:
7a09862 [Matei Zaharia] review comments
be33d3f [Matei Zaharia] review comments
1513d4e [Matei Zaharia] Add ASF header
ac56831 [Matei Zaharia] Bug fix and better error message
4f681ba [Matei Zaharia] Move write part of ShuffleMapTask to ShuffleManager
f6f011d [Matei Zaharia] Move hash shuffle reader behind ShuffleManager interface
55c7717 [Matei Zaharia] Changed RDD code to use ShuffleReader
75cc044 [Matei Zaharia] Partial work to move hash shuffle in
Diffstat (limited to 'core')
22 files changed, 459 insertions, 130 deletions
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index e2d2250982..bf3c3a6ceb 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a ShuffleDependency for cleanup when it is garbage collected. */ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 2c31cc2021..c8c194a111 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -20,6 +20,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle /** * :: DeveloperApi :: @@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * Represents a dependency on the output of a shuffle stage. * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output - * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null, + * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will * be used. */ @DeveloperApi -class ShuffleDependency[K, V]( +class ShuffleDependency[K, V, C]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, - val serializer: Serializer = null) + val serializer: Option[Serializer] = None, + val keyOrdering: Option[Ordering[K]] = None, + val aggregator: Option[Aggregator[K, V, C]] = None) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( + shuffleId, rdd.partitions.size, this) + rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 720151a6b0..8dfa8cc4b5 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -56,7 +57,7 @@ class SparkEnv ( val closureSerializer: Serializer, val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, + val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, @@ -80,7 +81,7 @@ class SparkEnv ( pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() - shuffleFetcher.stop() + shuffleManager.stop() broadcastManager.stop() blockManager.stop() blockManager.master.stop() @@ -163,13 +164,20 @@ object SparkEnv extends Logging { def instantiateClass[T](propertyName: String, defaultClassName: String): T = { val name = conf.get(propertyName, defaultClassName) val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader) - // First try with the constructor that takes SparkConf. If we can't find one, - // use a no-arg constructor instead. + // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just + // SparkConf, then one taking no arguments try { - cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) + .newInstance(conf, new java.lang.Boolean(isDriver)) + .asInstanceOf[T] } catch { case _: NoSuchMethodException => - cls.getConstructor().newInstance().asInstanceOf[T] + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } } } @@ -219,9 +227,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val shuffleFetcher = instantiateClass[ShuffleFetcher]( - "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) @@ -242,6 +247,9 @@ object SparkEnv extends Logging { "." } + val shuffleManager = instantiateClass[ShuffleManager]( + "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! Specify storage " + @@ -255,7 +263,7 @@ object SparkEnv extends Logging { closureSerializer, cacheManager, mapOutputTracker, - shuffleFetcher, + shuffleManager, broadcastManager, blockManager, connectionManager, diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9ff76892ae..5951865e56 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep +private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { @@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: private type CoGroupValue = (Any, Int) // Int is dependency number private type CoGroupCombiner = Seq[CoGroup] - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) } } } @@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(shuffleId) => + case ShuffleCoGroupSplitDep(handle) => // Read map outputs of shuffle - val fetcher = SparkEnv.get.shuffleFetcher - val ser = Serializer.getSerializer(serializer) - val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + val it = SparkEnv.get.shuffleManager + .getReader(handle, split.index, split.index + 1, context) + .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 802b0bdfb2..bb108ef163 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( part: Partitioner) extends RDD[P](prev.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - val ser = Serializer.getSerializer(serializer) - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser) + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[P]] } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 9a09c05bbc..ed24ea22a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val ser = Serializer.getSerializer(serializer) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(shuffleId) => - val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context, ser) + case ShuffleCoGroupSplitDep(handle) => + val iter = SparkEnv.get.shuffleManager + .getReader(handle, partition.index, partition.index + 1, context) + .read() iter.foreach(op) } // the first dep is rdd1; add all values to the map diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e09a4221e8..3c85b5a2ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -190,7 +190,7 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -210,7 +210,7 @@ class DAGScheduler( private def newStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_,_]], + shuffleDep: Option[ShuffleDependency[_, _, _]], jobId: Int, callSite: Option[String] = None) : Stage = @@ -233,7 +233,7 @@ class DAGScheduler( private def newOrUsedStage( rdd: RDD[_], numTasks: Int, - shuffleDep: ShuffleDependency[_,_], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, callSite: Option[String] = None) : Stage = @@ -269,7 +269,7 @@ class DAGScheduler( // we can't do it in its constructor because # of partitions is unknown for (dep <- r.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => visit(dep.rdd) @@ -290,7 +290,7 @@ class DAGScheduler( if (getCacheLocs(rdd).contains(Nil)) { for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { missing += mapStage @@ -1088,7 +1088,7 @@ class DAGScheduler( visitedRdds += rdd for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ed0f56f1ab..0098b5a59d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ +import org.apache.spark.shuffle.ShuffleWriter private[spark] object ShuffleMapTask { @@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. private val serializedInfoCache = new HashMap[Int, Array[Byte]] - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { + def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { @@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask { } } - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] (rdd, dep) } @@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask { private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], - var dep: ShuffleDependency[_,_], + var dep: ShuffleDependency[_, _, _], _partitionId: Int, @transient private var locs: Seq[TaskLocation]) extends Task[MapStatus](stageId, _partitionId) @@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask( } override def runTask(context: TaskContext): MapStatus = { - val numOutputSplits = dep.partitioner.numPartitions metrics = Some(context.taskMetrics) - - val blockManager = SparkEnv.get.blockManager - val shuffleBlockManager = blockManager.shuffleBlockManager - var shuffle: ShuffleWriterGroup = null - var success = false - + var writer: ShuffleWriter[Any, Any] = null try { - // Obtain all the block writers for shuffle blocks. - val ser = Serializer.getSerializer(dep.serializer) - shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) - - // Write the map output to its associated buckets. + val manager = SparkEnv.get.shuffleManager + writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) for (elem <- rdd.iterator(split, context)) { - val pair = elem.asInstanceOf[Product2[Any, Any]] - val bucketId = dep.partitioner.getPartition(pair._1) - shuffle.writers(bucketId).write(pair) - } - - // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - var totalTime = 0L - val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() - val size = writer.fileSegment().length - totalBytes += size - totalTime += writer.timeWriting() - MapOutputTracker.compressSize(size) + writer.write(elem.asInstanceOf[Product2[Any, Any]]) } - - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - - success = true - new MapStatus(blockManager.blockManagerId, compressedSizes) - } catch { case e: Exception => - // If there is an exception from running the task, revert the partial writes - // and throw the exception upstream to Spark. - if (shuffle != null && shuffle.writers != null) { - for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + return writer.stop(success = true).get + } catch { + case e: Exception => + if (writer != null) { + writer.stop(success = false) } - } - throw e + throw e } finally { - // Release the writers back to the shuffle block manager. - if (shuffle != null && shuffle.writers != null) { - try { - shuffle.releaseWriters(success) - } catch { - case e: Exception => logError("Failed to release shuffle writers", e) - } - } - // Execute the callbacks on task completion. context.executeOnCompleteCallbacks() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5c1fc30e4a..3bf9713f72 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -40,7 +40,7 @@ private[spark] class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage + val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, callSite: Option[String]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 99d305b36a..df59f444b7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val loader = Thread.currentThread.getContextClassLoader taskSetManager.abort("ClassNotFound with classloader: " + loader) case ex: Exception => - taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) + logError("Exception while getting task result", ex) + taskSetManager.abort("Exception while getting task result: %s".format(ex)) } } }) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index ee26970a3d..f2f5cea469 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -52,6 +52,10 @@ object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer } + + def getSerializer(serializer: Option[Serializer]): Serializer = { + serializer.getOrElse(SparkEnv.get.serializer) + } } diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index a4f69b6b22..b36c457d6d 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -15,22 +15,16 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle +import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner} import org.apache.spark.serializer.Serializer -private[spark] abstract class ShuffleFetcher { - - /** - * Fetch the shuffle outputs for a given ShuffleDependency. - * @return An iterator over the elements of the fetched shuffle outputs. - */ - def fetch[T]( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - serializer: Serializer = SparkEnv.get.serializer): Iterator[T] - - /** Stop the fetcher */ - def stop() {} -} +/** + * A basic ShuffleHandle implementation that just captures registerShuffle's parameters. + */ +private[spark] class BaseShuffleHandle[K, V, C]( + shuffleId: Int, + val numMaps: Int, + val dependency: ShuffleDependency[K, V, C]) + extends ShuffleHandle(shuffleId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala new file mode 100644 index 0000000000..13c7115f88 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks. + * + * @param shuffleId ID of the shuffle + */ +private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala new file mode 100644 index 0000000000..9c859b8b4a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.{TaskContext, ShuffleDependency} + +/** + * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the + * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles + * with it, and executors (or tasks running locally in the driver) can ask to read and write data. + * + * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and + * boolean isDriver as parameters. + */ +private[spark] trait ShuffleManager { + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle + + /** Get a writer for a given partition. Called on executors by map tasks. */ + def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] + + /** Remove a shuffle's metadata from the ShuffleManager. */ + def unregisterShuffle(shuffleId: Int) + + /** Shut down this ShuffleManager. */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala new file mode 100644 index 0000000000..b30e366d06 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * Obtained inside a reduce task to read combined records from the mappers. + */ +private[spark] trait ShuffleReader[K, C] { + /** Read the combined key-values for this reduce task */ + def read(): Iterator[Product2[K, C]] + + /** Close this reader */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala new file mode 100644 index 0000000000..ead3ebd652 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.scheduler.MapStatus + +/** + * Obtained inside a map task to write out records to the shuffle system. + */ +private[spark] trait ShuffleWriter[K, V] { + /** Write a record to this task's output */ + def write(record: Product2[K, V]): Unit + + /** Close this writer, passing along whether the map completed */ + def stop(success: Boolean): Option[MapStatus] +} diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index a67392441e..b05b6ea345 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator +import org.apache.spark._ -private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - - override def fetch[T]( +private[hash] object BlockStoreShuffleFetcher extends Logging { + def fetch[T]( shuffleId: Int, reduceId: Int, context: TaskContext, serializer: Serializer) : Iterator[T] = { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala new file mode 100644 index 0000000000..5b0940ecce --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark._ +import org.apache.spark.shuffle._ + +/** + * A ShuffleManager using hashing, that creates one output file per reduce partition on each + * mapper (possibly reusing these across waves of tasks). + */ +class HashShuffleManager(conf: SparkConf) extends ShuffleManager { + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + new HashShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) + : ShuffleWriter[K, V] = { + new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Unit = {} + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = {} +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala new file mode 100644 index 0000000000..f6a790309a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.TaskContext + +class HashShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext) + extends ShuffleReader[K, C] +{ + require(endPartition == startPartition + 1, + "Hash shuffle currently only supports fetching one partition") + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, + Serializer.getSerializer(handle.dependency.serializer)) + } + + /** Close this reader */ + override def stop(): Unit = ??? +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala new file mode 100644 index 0000000000..4c6749098c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.storage.{BlockObjectWriter} +import org.apache.spark.serializer.Serializer +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.scheduler.MapStatus + +class HashShuffleWriter[K, V]( + handle: BaseShuffleHandle[K, V, _], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency + private val numOutputSplits = dep.partitioner.numPartitions + private val metrics = context.taskMetrics + private var stopping = false + + private val blockManager = SparkEnv.get.blockManager + private val shuffleBlockManager = blockManager.shuffleBlockManager + private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) + private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) + + /** Write a record to this task's output */ + override def write(record: Product2[K, V]): Unit = { + val pair = record.asInstanceOf[Product2[Any, Any]] + val bucketId = dep.partitioner.getPartition(pair._1) + shuffle.writers(bucketId).write(pair) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + try { + return Some(commitWritesAndBuildStatus()) + } catch { + case e: Exception => + revertWrites() + throw e + } + } else { + revertWrites() + return None + } + } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && shuffle.writers != null) { + try { + shuffle.releaseWriters(success) + } catch { + case e: Exception => logError("Failed to release shuffle writers", e) + } + } + } + } + + private def commitWritesAndBuildStatus(): MapStatus = { + // Commit the writes. Get the size of each bucket block (total block size). + var totalBytes = 0L + var totalTime = 0L + val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.fileSegment().length + totalBytes += size + totalTime += writer.timeWriting() + MapOutputTracker.compressSize(size) + } + + // Update shuffle metrics. + val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + metrics.shuffleWriteMetrics = Some(shuffleMetrics) + + new MapStatus(blockManager.blockManagerId, compressedSizes) + } + + private def revertWrites(): Unit = { + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index dc2db66df6..13b415cccb 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo def newPairRDD = newRDD.map(_ -> 1) def newShuffleRDD = newPairRDD.reduceByKey(_ + _) def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { + def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => getAllDependencies(dep.rdd) @@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Get all the shuffle dependencies val shuffleDeps = getAllDependencies(rdd) - .filter(_.isInstanceOf[ShuffleDependency[_, _]]) - .map(_.asInstanceOf[ShuffleDependency[_, _]]) + .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) (rdd, shuffleDeps) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 7b0607dd3e..47112ce66d 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 10) @@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => @@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // NOTE: The default Java serializer should create zero-sized blocks val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => |