Diffstat (limited to 'core')
22 files changed, 459 insertions, 130 deletions
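Before the patch body: this change removes the fixed ShuffleFetcher hook and routes all shuffle reads and writes through a pluggable ShuffleManager, registered per ShuffleDependency on the driver and consulted by map tasks (getWriter) and reduce tasks (getReader) on the executors. As a rough illustration of the extension point this adds, below is a minimal sketch of a third-party manager that only logs registrations and delegates to the built-in hash implementation. The class name, package segment "example", and the log message are hypothetical; because ShuffleManager is private[spark] in this patch, such a sketch would have to live under the org.apache.spark package tree, and it would be selected through the new spark.shuffle.manager setting. SparkEnv accepts a plain (SparkConf) constructor, so no isDriver flag is needed here.

    package org.apache.spark.shuffle.example  // hypothetical package, inside org.apache.spark

    import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
    import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleWriter}
    import org.apache.spark.shuffle.hash.HashShuffleManager

    // Sketch only: logs shuffle registrations and delegates the real work to the
    // hash-based manager introduced by this patch.
    class LoggingShuffleManager(conf: SparkConf) extends ShuffleManager {
      private val delegate = new HashShuffleManager(conf)

      override def registerShuffle[K, V, C](
          shuffleId: Int,
          numMaps: Int,
          dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
        println(s"Registering shuffle $shuffleId with $numMaps map outputs")
        delegate.registerShuffle(shuffleId, numMaps, dependency)
      }

      override def getWriter[K, V](
          handle: ShuffleHandle,
          mapId: Int,
          context: TaskContext): ShuffleWriter[K, V] =
        delegate.getWriter[K, V](handle, mapId, context)

      override def getReader[K, C](
          handle: ShuffleHandle,
          startPartition: Int,
          endPartition: Int,
          context: TaskContext): ShuffleReader[K, C] =
        delegate.getReader[K, C](handle, startPartition, endPartition, context)

      override def unregisterShuffle(shuffleId: Int): Unit = delegate.unregisterShuffle(shuffleId)

      override def stop(): Unit = delegate.stop()
    }

A manager like this would be enabled with spark.shuffle.manager=org.apache.spark.shuffle.example.LoggingShuffleManager; the default remains the hash-based manager that this patch wires up in SparkEnv.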
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index e2d2250982..bf3c3a6ceb 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } /** Register a ShuffleDependency for cleanup when it is garbage collected. */ - def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 2c31cc2021..c8c194a111 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -20,6 +20,7 @@ package org.apache.spark import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle /** * :: DeveloperApi :: @@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * Represents a dependency on the output of a shuffle stage. * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output - * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null, + * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will * be used. */ @DeveloperApi -class ShuffleDependency[K, V]( +class ShuffleDependency[K, V, C]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, - val serializer: Serializer = null) + val serializer: Option[Serializer] = None, + val keyOrdering: Option[Ordering[K]] = None, + val aggregator: Option[Aggregator[K, V, C]] = None) extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( + shuffleId, rdd.partitions.size, this) + rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 720151a6b0..8dfa8cc4b5 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.ConnectionManager import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -56,7 +57,7 @@ class SparkEnv ( val closureSerializer: Serializer, val cacheManager: CacheManager, val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, + val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, val blockManager: BlockManager, val connectionManager: ConnectionManager, @@ -80,7 +81,7 @@ class SparkEnv ( pythonWorkers.foreach { case(key, worker) => worker.stop() } httpFileServer.stop() mapOutputTracker.stop() - shuffleFetcher.stop() + shuffleManager.stop() 
broadcastManager.stop() blockManager.stop() blockManager.master.stop() @@ -163,13 +164,20 @@ object SparkEnv extends Logging { def instantiateClass[T](propertyName: String, defaultClassName: String): T = { val name = conf.get(propertyName, defaultClassName) val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader) - // First try with the constructor that takes SparkConf. If we can't find one, - // use a no-arg constructor instead. + // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just + // SparkConf, then one taking no arguments try { - cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE) + .newInstance(conf, new java.lang.Boolean(isDriver)) + .asInstanceOf[T] } catch { case _: NoSuchMethodException => - cls.getConstructor().newInstance().asInstanceOf[T] + try { + cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T] + } catch { + case _: NoSuchMethodException => + cls.getConstructor().newInstance().asInstanceOf[T] + } } } @@ -219,9 +227,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val shuffleFetcher = instantiateClass[ShuffleFetcher]( - "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) @@ -242,6 +247,9 @@ object SparkEnv extends Logging { "." } + val shuffleManager = instantiateClass[ShuffleManager]( + "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager") + // Warn about deprecated spark.cache.class property if (conf.contains("spark.cache.class")) { logWarning("The spark.cache.class property is no longer being used! 
Specify storage " + @@ -255,7 +263,7 @@ object SparkEnv extends Logging { closureSerializer, cacheManager, mapOutputTracker, - shuffleFetcher, + shuffleManager, broadcastManager, blockManager, connectionManager, diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9ff76892ae..5951865e56 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.ShuffleHandle private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep( } } -private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep +private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]) extends Partition with Serializable { @@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: private type CoGroupValue = (Any, Int) // Int is dependency number private type CoGroupCombiner = Seq[CoGroup] - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): CoGroupedRDD[K] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[Any, Any](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) } } } @@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) => // Assume each RDD contributed a single dependency, and get it dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]] rddIterators += ((it, depNum)) - case ShuffleCoGroupSplitDep(shuffleId) => + case ShuffleCoGroupSplitDep(handle) => // Read map outputs of shuffle - val fetcher = SparkEnv.get.shuffleFetcher - val ser = Serializer.getSerializer(serializer) - val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser) + val it = SparkEnv.get.shuffleManager + .getReader(handle, split.index, split.index + 1, context) + .read() rddIterators += ((it, depNum)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 802b0bdfb2..bb108ef163 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ 
b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( part: Partitioner) extends RDD[P](prev.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( } override def compute(split: Partition, context: TaskContext): Iterator[P] = { - val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - val ser = Serializer.getSerializer(serializer) - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser) + val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]] + SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context) + .read() + .asInstanceOf[Iterator[P]] } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 9a09c05bbc..ed24ea22a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( part: Partitioner) extends RDD[(K, V)](rdd1.context, Nil) { - private var serializer: Serializer = null + private var serializer: Option[Serializer] = None + /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */ def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = { - this.serializer = serializer + this.serializer = Option(serializer) this } @@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( // Each CoGroupPartition will depend on rdd1 and rdd2 array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) => dependencies(j) match { - case s: ShuffleDependency[_, _] => - new ShuffleCoGroupSplitDep(s.shuffleId) + case s: ShuffleDependency[_, _, _] => + new ShuffleCoGroupSplitDep(s.shuffleHandle) case _ => new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)) } @@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val ser = Serializer.getSerializer(serializer) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) @@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( case NarrowCoGroupSplitDep(rdd, _, itsSplit) => rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op) - case ShuffleCoGroupSplitDep(shuffleId) => - val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context, ser) + case ShuffleCoGroupSplitDep(handle) => + val iter = SparkEnv.get.shuffleManager + .getReader(handle, partition.index, partition.index + 1, context) + .read() iter.foreach(op) } // the first dep is rdd1; add all values to the map diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e09a4221e8..3c85b5a2ae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -190,7 +190,7 @@ class DAGScheduler( * The jobId value passed in will be used if the stage doesn't already exist with * a lower jobId (jobId always increases across jobs.) */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => @@ -210,7 +210,7 @@ class DAGScheduler( private def newStage( rdd: RDD[_], numTasks: Int, - shuffleDep: Option[ShuffleDependency[_,_]], + shuffleDep: Option[ShuffleDependency[_, _, _]], jobId: Int, callSite: Option[String] = None) : Stage = @@ -233,7 +233,7 @@ class DAGScheduler( private def newOrUsedStage( rdd: RDD[_], numTasks: Int, - shuffleDep: ShuffleDependency[_,_], + shuffleDep: ShuffleDependency[_, _, _], jobId: Int, callSite: Option[String] = None) : Stage = @@ -269,7 +269,7 @@ class DAGScheduler( // we can't do it in its constructor because # of partitions is unknown for (dep <- r.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => parents += getShuffleMapStage(shufDep, jobId) case _ => visit(dep.rdd) @@ -290,7 +290,7 @@ class DAGScheduler( if (getCacheLocs(rdd).contains(Nil)) { for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { missing += mapStage @@ -1088,7 +1088,7 @@ class DAGScheduler( visitedRdds += rdd for (dep <- rdd.dependencies) { dep match { - case shufDep: ShuffleDependency[_,_] => + case shufDep: ShuffleDependency[_, _, _] => val mapStage = getShuffleMapStage(shufDep, stage.jobId) if (!mapStage.isAvailable) { visitedStages += mapStage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index ed0f56f1ab..0098b5a59d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ +import org.apache.spark.shuffle.ShuffleWriter private[spark] object ShuffleMapTask { @@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. 
private val serializedInfoCache = new HashMap[Int, Array[Byte]] - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { + def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { @@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask { } } - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = { + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] (rdd, dep) } @@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask { private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], - var dep: ShuffleDependency[_,_], + var dep: ShuffleDependency[_, _, _], _partitionId: Int, @transient private var locs: Seq[TaskLocation]) extends Task[MapStatus](stageId, _partitionId) @@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask( } override def runTask(context: TaskContext): MapStatus = { - val numOutputSplits = dep.partitioner.numPartitions metrics = Some(context.taskMetrics) - - val blockManager = SparkEnv.get.blockManager - val shuffleBlockManager = blockManager.shuffleBlockManager - var shuffle: ShuffleWriterGroup = null - var success = false - + var writer: ShuffleWriter[Any, Any] = null try { - // Obtain all the block writers for shuffle blocks. - val ser = Serializer.getSerializer(dep.serializer) - shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) - - // Write the map output to its associated buckets. + val manager = SparkEnv.get.shuffleManager + writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) for (elem <- rdd.iterator(split, context)) { - val pair = elem.asInstanceOf[Product2[Any, Any]] - val bucketId = dep.partitioner.getPartition(pair._1) - shuffle.writers(bucketId).write(pair) - } - - // Commit the writes. Get the size of each bucket block (total block size). - var totalBytes = 0L - var totalTime = 0L - val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => - writer.commit() - writer.close() - val size = writer.fileSegment().length - totalBytes += size - totalTime += writer.timeWriting() - MapOutputTracker.compressSize(size) + writer.write(elem.asInstanceOf[Product2[Any, Any]]) } - - // Update shuffle metrics. - val shuffleMetrics = new ShuffleWriteMetrics - shuffleMetrics.shuffleBytesWritten = totalBytes - shuffleMetrics.shuffleWriteTime = totalTime - metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - - success = true - new MapStatus(blockManager.blockManagerId, compressedSizes) - } catch { case e: Exception => - // If there is an exception from running the task, revert the partial writes - // and throw the exception upstream to Spark. - if (shuffle != null && shuffle.writers != null) { - for (writer <- shuffle.writers) { - writer.revertPartialWrites() - writer.close() + return writer.stop(success = true).get + } catch { + case e: Exception => + if (writer != null) { + writer.stop(success = false) } - } - throw e + throw e } finally { - // Release the writers back to the shuffle block manager. 
- if (shuffle != null && shuffle.writers != null) { - try { - shuffle.releaseWriters(success) - } catch { - case e: Exception => logError("Failed to release shuffle writers", e) - } - } - // Execute the callbacks on task completion. context.executeOnCompleteCallbacks() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 5c1fc30e4a..3bf9713f72 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -40,7 +40,7 @@ private[spark] class Stage( val id: Int, val rdd: RDD[_], val numTasks: Int, - val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage + val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, callSite: Option[String]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 99d305b36a..df59f444b7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val loader = Thread.currentThread.getContextClassLoader taskSetManager.abort("ClassNotFound with classloader: " + loader) case ex: Exception => - taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) + logError("Exception while getting task result", ex) + taskSetManager.abort("Exception while getting task result: %s".format(ex)) } } }) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index ee26970a3d..f2f5cea469 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -52,6 +52,10 @@ object Serializer { def getSerializer(serializer: Serializer): Serializer = { if (serializer == null) SparkEnv.get.serializer else serializer } + + def getSerializer(serializer: Option[Serializer]): Serializer = { + serializer.getOrElse(SparkEnv.get.serializer) + } } diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index a4f69b6b22..b36c457d6d 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -15,22 +15,16 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle +import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner} import org.apache.spark.serializer.Serializer -private[spark] abstract class ShuffleFetcher { - - /** - * Fetch the shuffle outputs for a given ShuffleDependency. - * @return An iterator over the elements of the fetched shuffle outputs. - */ - def fetch[T]( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - serializer: Serializer = SparkEnv.get.serializer): Iterator[T] - - /** Stop the fetcher */ - def stop() {} -} +/** + * A basic ShuffleHandle implementation that just captures registerShuffle's parameters. 
+ */ +private[spark] class BaseShuffleHandle[K, V, C]( + shuffleId: Int, + val numMaps: Int, + val dependency: ShuffleDependency[K, V, C]) + extends ShuffleHandle(shuffleId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala new file mode 100644 index 0000000000..13c7115f88 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks. + * + * @param shuffleId ID of the shuffle + */ +private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala new file mode 100644 index 0000000000..9c859b8b4a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.{TaskContext, ShuffleDependency} + +/** + * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the + * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles + * with it, and executors (or tasks running locally in the driver) can ask to read and write data. + * + * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and + * boolean isDriver as parameters. + */ +private[spark] trait ShuffleManager { + /** + * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + */ + def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle + + /** Get a writer for a given partition. Called on executors by map tasks. 
*/ + def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V] + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. + */ + def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] + + /** Remove a shuffle's metadata from the ShuffleManager. */ + def unregisterShuffle(shuffleId: Int) + + /** Shut down this ShuffleManager. */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala new file mode 100644 index 0000000000..b30e366d06 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +/** + * Obtained inside a reduce task to read combined records from the mappers. + */ +private[spark] trait ShuffleReader[K, C] { + /** Read the combined key-values for this reduce task */ + def read(): Iterator[Product2[K, C]] + + /** Close this reader */ + def stop(): Unit +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala new file mode 100644 index 0000000000..ead3ebd652 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.scheduler.MapStatus + +/** + * Obtained inside a map task to write out records to the shuffle system. 
+ */ +private[spark] trait ShuffleWriter[K, V] { + /** Write a record to this task's output */ + def write(record: Product2[K, V]): Unit + + /** Close this writer, passing along whether the map completed */ + def stop(success: Boolean): Option[MapStatus] +} diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index a67392441e..b05b6ea345 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator +import org.apache.spark._ -private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - - override def fetch[T]( +private[hash] object BlockStoreShuffleFetcher extends Logging { + def fetch[T]( shuffleId: Int, reduceId: Int, context: TaskContext, serializer: Serializer) : Iterator[T] = { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala new file mode 100644 index 0000000000..5b0940ecce --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark._ +import org.apache.spark.shuffle._ + +/** + * A ShuffleManager using hashing, that creates one output file per reduce partition on each + * mapper (possibly reusing these across waves of tasks). + */ +class HashShuffleManager(conf: SparkConf) extends ShuffleManager { + /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */ + override def registerShuffle[K, V, C]( + shuffleId: Int, + numMaps: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). + * Called on executors by reduce tasks. 
+ */ + override def getReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext): ShuffleReader[K, C] = { + new HashShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) + } + + /** Get a writer for a given partition. Called on executors by map tasks. */ + override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) + : ShuffleWriter[K, V] = { + new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) + } + + /** Remove a shuffle's metadata from the ShuffleManager. */ + override def unregisterShuffle(shuffleId: Int): Unit = {} + + /** Shut down this ShuffleManager. */ + override def stop(): Unit = {} +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala new file mode 100644 index 0000000000..f6a790309a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.serializer.Serializer +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.TaskContext + +class HashShuffleReader[K, C]( + handle: BaseShuffleHandle[K, _, C], + startPartition: Int, + endPartition: Int, + context: TaskContext) + extends ShuffleReader[K, C] +{ + require(endPartition == startPartition + 1, + "Hash shuffle currently only supports fetching one partition") + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, + Serializer.getSerializer(handle.dependency.serializer)) + } + + /** Close this reader */ + override def stop(): Unit = ??? +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala new file mode 100644 index 0000000000..4c6749098c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} +import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext} +import org.apache.spark.storage.{BlockObjectWriter} +import org.apache.spark.serializer.Serializer +import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.scheduler.MapStatus + +class HashShuffleWriter[K, V]( + handle: BaseShuffleHandle[K, V, _], + mapId: Int, + context: TaskContext) + extends ShuffleWriter[K, V] with Logging { + + private val dep = handle.dependency + private val numOutputSplits = dep.partitioner.numPartitions + private val metrics = context.taskMetrics + private var stopping = false + + private val blockManager = SparkEnv.get.blockManager + private val shuffleBlockManager = blockManager.shuffleBlockManager + private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null)) + private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser) + + /** Write a record to this task's output */ + override def write(record: Product2[K, V]): Unit = { + val pair = record.asInstanceOf[Product2[Any, Any]] + val bucketId = dep.partitioner.getPartition(pair._1) + shuffle.writers(bucketId).write(pair) + } + + /** Close this writer, passing along whether the map completed */ + override def stop(success: Boolean): Option[MapStatus] = { + try { + if (stopping) { + return None + } + stopping = true + if (success) { + try { + return Some(commitWritesAndBuildStatus()) + } catch { + case e: Exception => + revertWrites() + throw e + } + } else { + revertWrites() + return None + } + } finally { + // Release the writers back to the shuffle block manager. + if (shuffle != null && shuffle.writers != null) { + try { + shuffle.releaseWriters(success) + } catch { + case e: Exception => logError("Failed to release shuffle writers", e) + } + } + } + } + + private def commitWritesAndBuildStatus(): MapStatus = { + // Commit the writes. Get the size of each bucket block (total block size). + var totalBytes = 0L + var totalTime = 0L + val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => + writer.commit() + writer.close() + val size = writer.fileSegment().length + totalBytes += size + totalTime += writer.timeWriting() + MapOutputTracker.compressSize(size) + } + + // Update shuffle metrics. 
+ val shuffleMetrics = new ShuffleWriteMetrics + shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime + metrics.shuffleWriteMetrics = Some(shuffleMetrics) + + new MapStatus(blockManager.blockManagerId, compressedSizes) + } + + private def revertWrites(): Unit = { + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index dc2db66df6..13b415cccb 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo def newPairRDD = newRDD.map(_ -> 1) def newShuffleRDD = newPairRDD.reduceByKey(_ + _) def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { + def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => getAllDependencies(dep.rdd) @@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Get all the shuffle dependencies val shuffleDeps = getAllDependencies(rdd) - .filter(_.isInstanceOf[ShuffleDependency[_, _]]) - .map(_.asInstanceOf[ShuffleDependency[_, _]]) + .filter(_.isInstanceOf[ShuffleDependency[_, _, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _, _]]) (rdd, shuffleDeps) } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 7b0607dd3e..47112ce66d 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // default Java serializer cannot handle the non serializable class. val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)]( b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 10) @@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) .setSerializer(new KryoSerializer(conf)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => @@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { // NOTE: The default Java serializer should create zero-sized blocks val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10)) - val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => |
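For completeness, a hedged usage sketch (the application name, master URL, demo object, and word-count job are illustrative, not part of this patch): selecting the shuffle implementation is now a matter of pointing the new spark.shuffle.manager key at a class name, and any shuffle-producing operation such as reduceByKey then goes through ShuffleManager.registerShuffle on the driver and the writer/reader paths in the tasks.

    import org.apache.spark.{SparkConf, SparkContext}

    object ShuffleManagerDemo {  // hypothetical demo object
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("shuffle-manager-demo")
          .setMaster("local[2]")
          // Explicitly pick the default hash-based manager added by this patch.
          .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")

        val sc = new SparkContext(conf)
        val counts = sc.parallelize(Seq("a", "b", "a"))
          .map(word => (word, 1))
          .reduceByKey(_ + _)  // creates a ShuffleDependency, which registers a ShuffleHandle
        counts.collect().foreach(println)
        sc.stop()
      }
    }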