author     Matei Zaharia <matei@databricks.com>    2014-06-11 20:45:29 -0700
committer  Reynold Xin <rxin@apache.org>           2014-06-11 20:45:29 -0700
commit     508fd371d6dbb826fd8a00787d347235b549e189 (patch)
tree       2cc77c43ab7b9f30a9b1afc6f55d9438779e46ab /core
parent     d9203350b06a9c737421ba244162b1365402c01b (diff)
[SPARK-2044] Pluggable interface for shuffles
This is a first cut at moving shuffle logic behind a pluggable interface, as described at https://issues.apache.org/jira/browse/SPARK-2044, to let us more easily experiment with new shuffle implementations. It moves the existing shuffle code to a class HashShuffleManager behind a general ShuffleManager interface.

Two things are still missing to make this complete:

* MapOutputTracker needs to be hidden behind the ShuffleManager interface; this will also require adding methods to ShuffleManager that will let the DAGScheduler interact with it as it does with the MapOutputTracker today
* The code to do map-side and reduce-side combine in ShuffledRDD, PairRDDFunctions, etc. needs to be moved into the ShuffleManager's readers and writers

However, some of these may also be done later after we merge the current interface.

Author: Matei Zaharia <matei@databricks.com>

Closes #1009 from mateiz/pluggable-shuffle and squashes the following commits:

7a09862 [Matei Zaharia] review comments
be33d3f [Matei Zaharia] review comments
1513d4e [Matei Zaharia] Add ASF header
ac56831 [Matei Zaharia] Bug fix and better error message
4f681ba [Matei Zaharia] Move write part of ShuffleMapTask to ShuffleManager
f6f011d [Matei Zaharia] Move hash shuffle reader behind ShuffleManager interface
55c7717 [Matei Zaharia] Changed RDD code to use ShuffleReader
75cc044 [Matei Zaharia] Partial work to move hash shuffle in
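For orientation, here is a minimal sketch of how the new pluggable interface is selected from an application. The spark.shuffle.manager key and the HashShuffleManager default come from the SparkEnv changes in this patch; the app name, local master, and word-count job are assumptions for illustration only.

    import org.apache.spark.{SparkConf, SparkContext}

    // Choose a shuffle implementation by class name. The default added in this
    // patch is the hash-based manager that preserves the previous behaviour.
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("pluggable-shuffle-demo")
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
    val sc = new SparkContext(conf)

    // Every shuffle now flows through the interface: ShuffleMapTask obtains a
    // ShuffleWriter from the manager on the map side, while ShuffledRDD,
    // CoGroupedRDD and SubtractedRDD obtain a ShuffleReader on the reduce side.
    val counts = sc.parallelize(Seq("a", "b", "a")).map(word => (word, 1)).reduceByKey(_ + _)
    counts.collect().foreach(println)
    sc.stop()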
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/ContextCleaner.scala  |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/Dependency.scala  |  12
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala  |  28
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala  |  22
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala  |  12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala  |  17
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  |  12
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala  |  73
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala  |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala  |   3
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala  |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala (renamed from core/src/main/scala/org/apache/spark/ShuffleFetcher.scala)  |  26
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala  |  25
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala  |  57
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala  |  29
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala  |  31
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala (renamed from core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala)  |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala  |  60
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala  |  42
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala  | 111
-rw-r--r--  core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala  |   6
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala  |   6
22 files changed, 459 insertions, 130 deletions
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index e2d2250982..bf3c3a6ceb 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
/** Register a ShuffleDependency for cleanup when it is garbage collected. */
- def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
+ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 2c31cc2021..c8c194a111 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
/**
* :: DeveloperApi ::
@@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
- * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
+ * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
-class ShuffleDependency[K, V](
+class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
- val serializer: Serializer = null)
+ val serializer: Option[Serializer] = None,
+ val keyOrdering: Option[Ordering[K]] = None,
+ val aggregator: Option[Aggregator[K, V, C]] = None)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
+ val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+ shuffleId, rdd.partitions.size, this)
+
rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 720151a6b0..8dfa8cc4b5 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -56,7 +57,7 @@ class SparkEnv (
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
- val shuffleFetcher: ShuffleFetcher,
+ val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
@@ -80,7 +81,7 @@ class SparkEnv (
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
- shuffleFetcher.stop()
+ shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
@@ -163,13 +164,20 @@ object SparkEnv extends Logging {
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
- // First try with the constructor that takes SparkConf. If we can't find one,
- // use a no-arg constructor instead.
+ // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
+ // SparkConf, then one taking no arguments
try {
- cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
+ .newInstance(conf, new java.lang.Boolean(isDriver))
+ .asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
- cls.getConstructor().newInstance().asInstanceOf[T]
+ try {
+ cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ } catch {
+ case _: NoSuchMethodException =>
+ cls.getConstructor().newInstance().asInstanceOf[T]
+ }
}
}
@@ -219,9 +227,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- val shuffleFetcher = instantiateClass[ShuffleFetcher](
- "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
-
val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
@@ -242,6 +247,9 @@ object SparkEnv extends Logging {
"."
}
+ val shuffleManager = instantiateClass[ShuffleManager](
+ "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+
// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -255,7 +263,7 @@ object SparkEnv extends Logging {
closureSerializer,
cacheManager,
mapOutputTracker,
- shuffleFetcher,
+ shuffleManager,
broadcastManager,
blockManager,
connectionManager,
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 9ff76892ae..5951865e56 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep(
}
}
-private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
@@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[Any, Any](rdd, part, serializer)
+ new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
}
}
}
@@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
// Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
- case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId)
+ case s: ShuffleDependency[_, _, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
- case ShuffleCoGroupSplitDep(shuffleId) =>
+ case ShuffleCoGroupSplitDep(handle) =>
// Read map outputs of shuffle
- val fetcher = SparkEnv.get.shuffleFetcher
- val ser = Serializer.getSerializer(serializer)
- val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+ val it = SparkEnv.get.shuffleManager
+ .getReader(handle, split.index, split.index + 1, context)
+ .read()
rddIterators += ((it, depNum))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 802b0bdfb2..bb108ef163 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
}
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
- val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- val ser = Serializer.getSerializer(serializer)
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
+ val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
+ SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+ .read()
+ .asInstanceOf[Iterator[P]]
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 9a09c05bbc..ed24ea22a6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
// Each CoGroupPartition will depend on rdd1 and rdd2
array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
dependencies(j) match {
- case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId)
+ case s: ShuffleDependency[_, _, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
- case ShuffleCoGroupSplitDep(shuffleId) =>
- val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context, ser)
+ case ShuffleCoGroupSplitDep(handle) =>
+ val iter = SparkEnv.get.shuffleManager
+ .getReader(handle, partition.index, partition.index + 1, context)
+ .read()
iter.foreach(op)
}
// the first dep is rdd1; add all values to the map
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index e09a4221e8..3c85b5a2ae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -190,7 +190,7 @@ class DAGScheduler(
* The jobId value passed in will be used if the stage doesn't already exist with
* a lower jobId (jobId always increases across jobs.)
*/
- private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -210,7 +210,7 @@ class DAGScheduler(
private def newStage(
rdd: RDD[_],
numTasks: Int,
- shuffleDep: Option[ShuffleDependency[_,_]],
+ shuffleDep: Option[ShuffleDependency[_, _, _]],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -233,7 +233,7 @@ class DAGScheduler(
private def newOrUsedStage(
rdd: RDD[_],
numTasks: Int,
- shuffleDep: ShuffleDependency[_,_],
+ shuffleDep: ShuffleDependency[_, _, _],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -269,7 +269,7 @@ class DAGScheduler(
// we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
visit(dep.rdd)
@@ -290,7 +290,7 @@ class DAGScheduler(
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
missing += mapStage
@@ -1088,7 +1088,7 @@ class DAGScheduler(
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index ed0f56f1ab..0098b5a59d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
+import org.apache.spark.shuffle.ShuffleWriter
private[spark] object ShuffleMapTask {
@@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
private val serializedInfoCache = new HashMap[Int, Array[Byte]]
- def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
@@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask {
}
}
- def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = {
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
- val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
(rdd, dep)
}
@@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask {
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
- var dep: ShuffleDependency[_,_],
+ var dep: ShuffleDependency[_, _, _],
_partitionId: Int,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, _partitionId)
@@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask(
}
override def runTask(context: TaskContext): MapStatus = {
- val numOutputSplits = dep.partitioner.numPartitions
metrics = Some(context.taskMetrics)
-
- val blockManager = SparkEnv.get.blockManager
- val shuffleBlockManager = blockManager.shuffleBlockManager
- var shuffle: ShuffleWriterGroup = null
- var success = false
-
+ var writer: ShuffleWriter[Any, Any] = null
try {
- // Obtain all the block writers for shuffle blocks.
- val ser = Serializer.getSerializer(dep.serializer)
- shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
-
- // Write the map output to its associated buckets.
+ val manager = SparkEnv.get.shuffleManager
+ writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
for (elem <- rdd.iterator(split, context)) {
- val pair = elem.asInstanceOf[Product2[Any, Any]]
- val bucketId = dep.partitioner.getPartition(pair._1)
- shuffle.writers(bucketId).write(pair)
- }
-
- // Commit the writes. Get the size of each bucket block (total block size).
- var totalBytes = 0L
- var totalTime = 0L
- val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
- writer.commit()
- writer.close()
- val size = writer.fileSegment().length
- totalBytes += size
- totalTime += writer.timeWriting()
- MapOutputTracker.compressSize(size)
+ writer.write(elem.asInstanceOf[Product2[Any, Any]])
}
-
- // Update shuffle metrics.
- val shuffleMetrics = new ShuffleWriteMetrics
- shuffleMetrics.shuffleBytesWritten = totalBytes
- shuffleMetrics.shuffleWriteTime = totalTime
- metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
-
- success = true
- new MapStatus(blockManager.blockManagerId, compressedSizes)
- } catch { case e: Exception =>
- // If there is an exception from running the task, revert the partial writes
- // and throw the exception upstream to Spark.
- if (shuffle != null && shuffle.writers != null) {
- for (writer <- shuffle.writers) {
- writer.revertPartialWrites()
- writer.close()
+ return writer.stop(success = true).get
+ } catch {
+ case e: Exception =>
+ if (writer != null) {
+ writer.stop(success = false)
}
- }
- throw e
+ throw e
} finally {
- // Release the writers back to the shuffle block manager.
- if (shuffle != null && shuffle.writers != null) {
- try {
- shuffle.releaseWriters(success)
- } catch {
- case e: Exception => logError("Failed to release shuffle writers", e)
- }
- }
- // Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 5c1fc30e4a..3bf9713f72 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -40,7 +40,7 @@ private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
val numTasks: Int,
- val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
+ val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
callSite: Option[String])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 99d305b36a..df59f444b7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
case ex: Exception =>
- taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+ logError("Exception while getting task result", ex)
+ taskSetManager.abort("Exception while getting task result: %s".format(ex))
}
}
})
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index ee26970a3d..f2f5cea469 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -52,6 +52,10 @@ object Serializer {
def getSerializer(serializer: Serializer): Serializer = {
if (serializer == null) SparkEnv.get.serializer else serializer
}
+
+ def getSerializer(serializer: Option[Serializer]): Serializer = {
+ serializer.getOrElse(SparkEnv.get.serializer)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
index a4f69b6b22..b36c457d6d 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
@@ -15,22 +15,16 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.shuffle
+import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner}
import org.apache.spark.serializer.Serializer
-private[spark] abstract class ShuffleFetcher {
-
- /**
- * Fetch the shuffle outputs for a given ShuffleDependency.
- * @return An iterator over the elements of the fetched shuffle outputs.
- */
- def fetch[T](
- shuffleId: Int,
- reduceId: Int,
- context: TaskContext,
- serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
-
- /** Stop the fetcher */
- def stop() {}
-}
+/**
+ * A basic ShuffleHandle implementation that just captures registerShuffle's parameters.
+ */
+private[spark] class BaseShuffleHandle[K, V, C](
+ shuffleId: Int,
+ val numMaps: Int,
+ val dependency: ShuffleDependency[K, V, C])
+ extends ShuffleHandle(shuffleId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
new file mode 100644
index 0000000000..13c7115f88
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks.
+ *
+ * @param shuffleId ID of the shuffle
+ */
+private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
new file mode 100644
index 0000000000..9c859b8b4a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.{TaskContext, ShuffleDependency}
+
+/**
+ * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
+ * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
+ * with it, and executors (or tasks running locally in the driver) can ask to read and write data.
+ *
+ * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
+ * boolean isDriver as parameters.
+ */
+private[spark] trait ShuffleManager {
+ /**
+ * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ */
+ def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C]
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ def unregisterShuffle(shuffleId: Int)
+
+ /** Shut down this ShuffleManager. */
+ def stop(): Unit
+}
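To make the contract above concrete, here is a minimal sketch of an alternative implementation. The class name SketchShuffleManager is hypothetical and not part of this patch; it simply reuses the BaseShuffleHandle and the hash-based reader/writer added later in this commit, and takes the optional (SparkConf, isDriver) constructor mentioned in the trait's note.

    package org.apache.spark.shuffle

    import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
    import org.apache.spark.shuffle.hash.{HashShuffleReader, HashShuffleWriter}

    // Hypothetical example only: a ShuffleManager that registers shuffles with
    // BaseShuffleHandle and delegates reading/writing to the hash-based classes.
    // It would be selected by setting spark.shuffle.manager to this class name.
    private[spark] class SketchShuffleManager(conf: SparkConf, isDriver: Boolean)
      extends ShuffleManager {

      override def registerShuffle[K, V, C](
          shuffleId: Int,
          numMaps: Int,
          dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
        new BaseShuffleHandle(shuffleId, numMaps, dependency)
      }

      override def getWriter[K, V](
          handle: ShuffleHandle,
          mapId: Int,
          context: TaskContext): ShuffleWriter[K, V] = {
        new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
      }

      override def getReader[K, C](
          handle: ShuffleHandle,
          startPartition: Int,
          endPartition: Int,
          context: TaskContext): ShuffleReader[K, C] = {
        new HashShuffleReader(
          handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
      }

      override def unregisterShuffle(shuffleId: Int): Unit = {}

      override def stop(): Unit = {}
    }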
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
new file mode 100644
index 0000000000..b30e366d06
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * Obtained inside a reduce task to read combined records from the mappers.
+ */
+private[spark] trait ShuffleReader[K, C] {
+ /** Read the combined key-values for this reduce task */
+ def read(): Iterator[Product2[K, C]]
+
+ /** Close this reader */
+ def stop(): Unit
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
new file mode 100644
index 0000000000..ead3ebd652
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+
+/**
+ * Obtained inside a map task to write out records to the shuffle system.
+ */
+private[spark] trait ShuffleWriter[K, V] {
+ /** Write a record to this task's output */
+ def write(record: Product2[K, V]): Unit
+
+ /** Close this writer, passing along whether the map completed */
+ def stop(success: Boolean): Option[MapStatus]
+}
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index a67392441e..b05b6ea345 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
+import org.apache.spark._
-private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
-
- override def fetch[T](
+private[hash] object BlockStoreShuffleFetcher extends Logging {
+ def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
-
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
new file mode 100644
index 0000000000..5b0940ecce
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark._
+import org.apache.spark.shuffle._
+
+/**
+ * A ShuffleManager using hashing, that creates one output file per reduce partition on each
+ * mapper (possibly reusing these across waves of tasks).
+ */
+class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+ /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C] = {
+ new HashShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+ : ShuffleWriter[K, V] = {
+ new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {}
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
new file mode 100644
index 0000000000..f6a790309a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.TaskContext
+
+class HashShuffleReader[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext)
+ extends ShuffleReader[K, C]
+{
+ require(endPartition == startPartition + 1,
+ "Hash shuffle currently only supports fetching one partition")
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
+ Serializer.getSerializer(handle.dependency.serializer))
+ }
+
+ /** Close this reader */
+ override def stop(): Unit = ???
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
new file mode 100644
index 0000000000..4c6749098c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.storage.{BlockObjectWriter}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+
+class HashShuffleWriter[K, V](
+ handle: BaseShuffleHandle[K, V, _],
+ mapId: Int,
+ context: TaskContext)
+ extends ShuffleWriter[K, V] with Logging {
+
+ private val dep = handle.dependency
+ private val numOutputSplits = dep.partitioner.numPartitions
+ private val metrics = context.taskMetrics
+ private var stopping = false
+
+ private val blockManager = SparkEnv.get.blockManager
+ private val shuffleBlockManager = blockManager.shuffleBlockManager
+ private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
+ private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+
+ /** Write a record to this task's output */
+ override def write(record: Product2[K, V]): Unit = {
+ val pair = record.asInstanceOf[Product2[Any, Any]]
+ val bucketId = dep.partitioner.getPartition(pair._1)
+ shuffle.writers(bucketId).write(pair)
+ }
+
+ /** Close this writer, passing along whether the map completed */
+ override def stop(success: Boolean): Option[MapStatus] = {
+ try {
+ if (stopping) {
+ return None
+ }
+ stopping = true
+ if (success) {
+ try {
+ return Some(commitWritesAndBuildStatus())
+ } catch {
+ case e: Exception =>
+ revertWrites()
+ throw e
+ }
+ } else {
+ revertWrites()
+ return None
+ }
+ } finally {
+ // Release the writers back to the shuffle block manager.
+ if (shuffle != null && shuffle.writers != null) {
+ try {
+ shuffle.releaseWriters(success)
+ } catch {
+ case e: Exception => logError("Failed to release shuffle writers", e)
+ }
+ }
+ }
+ }
+
+ private def commitWritesAndBuildStatus(): MapStatus = {
+ // Commit the writes. Get the size of each bucket block (total block size).
+ var totalBytes = 0L
+ var totalTime = 0L
+ val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
+ writer.commit()
+ writer.close()
+ val size = writer.fileSegment().length
+ totalBytes += size
+ totalTime += writer.timeWriting()
+ MapOutputTracker.compressSize(size)
+ }
+
+ // Update shuffle metrics.
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
+ metrics.shuffleWriteMetrics = Some(shuffleMetrics)
+
+ new MapStatus(blockManager.blockManagerId, compressedSizes)
+ }
+
+ private def revertWrites(): Unit = {
+ if (shuffle != null && shuffle.writers != null) {
+ for (writer <- shuffle.writers) {
+ writer.revertPartialWrites()
+ writer.close()
+ }
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index dc2db66df6..13b415cccb 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
def newPairRDD = newRDD.map(_ -> 1)
def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
def newBroadcast = sc.broadcast(1 to 100)
- def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+ def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
@@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
- .filter(_.isInstanceOf[ShuffleDependency[_, _]])
- .map(_.asInstanceOf[ShuffleDependency[_, _]])
+ .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
(rdd, shuffleDeps)
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 7b0607dd3e..47112ce66d 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 10)
@@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
@@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>