 core/src/main/scala/org/apache/spark/ContextCleaner.scala                                          |   2
 core/src/main/scala/org/apache/spark/Dependency.scala                                              |  12
 core/src/main/scala/org/apache/spark/SparkEnv.scala                                                |  28
 core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala                                        |  22
 core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala                                         |  12
 core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala                                       |  17
 core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala                                  |  12
 core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala                                |  73
 core/src/main/scala/org/apache/spark/scheduler/Stage.scala                                         |   2
 core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala                              |   3
 core/src/main/scala/org/apache/spark/serializer/Serializer.scala                                   |   4
 core/src/main/scala/org/apache/spark/{ShuffleFetcher.scala => shuffle/BaseShuffleHandle.scala}     |  26
 core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala                                   |  25
 core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala                                  |  57
 core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala                                   |  29
 core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala                                   |  31
 core/src/main/scala/org/apache/spark/{BlockStoreShuffleFetcher.scala => shuffle/hash/BlockStoreShuffleFetcher.scala} |   9
 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala                         |  60
 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala                          |  42
 core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala                          | 111
 core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala                                     |   6
 core/src/test/scala/org/apache/spark/ShuffleSuite.scala                                            |   6
 22 files changed, 459 insertions(+), 130 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index e2d2250982..bf3c3a6ceb 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
/** Register a ShuffleDependency for cleanup when it is garbage collected. */
- def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
+ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 2c31cc2021..c8c194a111 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
/**
* :: DeveloperApi ::
@@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
- * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
+ * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
-class ShuffleDependency[K, V](
+class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
- val serializer: Serializer = null)
+ val serializer: Option[Serializer] = None,
+ val keyOrdering: Option[Ordering[K]] = None,
+ val aggregator: Option[Aggregator[K, V, C]] = None)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
+ val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+ shuffleId, rdd.partitions.size, this)
+
rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
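
For reference, ShuffleDependency now carries three type parameters (key, value, combiner) and Option-valued serializer, keyOrdering and aggregator fields. A minimal sketch of constructing one directly; pairRdd and the partitioner here are purely illustrative, and most code builds these indirectly through ShuffledRDD or CoGroupedRDD as in the hunks below:

    import org.apache.spark.{HashPartitioner, ShuffleDependency}

    // Sketch only: pairRdd: RDD[(String, Int)] is assumed to exist.
    val dep = new ShuffleDependency[String, Int, Int](
      pairRdd,
      new HashPartitioner(8),
      serializer = None,     // None => use the default spark.serializer
      keyOrdering = None,    // no key ordering required
      aggregator = None)     // no map-side combine
    // The constructor registers the shuffle with the ShuffleManager and exposes dep.shuffleHandle,
    // which map and reduce tasks later use to obtain writers and readers.
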
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 720151a6b0..8dfa8cc4b5 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -56,7 +57,7 @@ class SparkEnv (
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
- val shuffleFetcher: ShuffleFetcher,
+ val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
@@ -80,7 +81,7 @@ class SparkEnv (
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
- shuffleFetcher.stop()
+ shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
@@ -163,13 +164,20 @@ object SparkEnv extends Logging {
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
- // First try with the constructor that takes SparkConf. If we can't find one,
- // use a no-arg constructor instead.
+ // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
+ // SparkConf, then one taking no arguments
try {
- cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
+ .newInstance(conf, new java.lang.Boolean(isDriver))
+ .asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
- cls.getConstructor().newInstance().asInstanceOf[T]
+ try {
+ cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ } catch {
+ case _: NoSuchMethodException =>
+ cls.getConstructor().newInstance().asInstanceOf[T]
+ }
}
}
@@ -219,9 +227,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- val shuffleFetcher = instantiateClass[ShuffleFetcher](
- "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
-
val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
@@ -242,6 +247,9 @@ object SparkEnv extends Logging {
"."
}
+ val shuffleManager = instantiateClass[ShuffleManager](
+ "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+
// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -255,7 +263,7 @@ object SparkEnv extends Logging {
closureSerializer,
cacheManager,
mapOutputTracker,
- shuffleFetcher,
+ shuffleManager,
broadcastManager,
blockManager,
connectionManager,
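
With SparkEnv now building a ShuffleManager instead of a ShuffleFetcher, swapping in a different implementation is a configuration change plus a compatible constructor: SparkEnv tries (SparkConf, Boolean isDriver), then (SparkConf), then a no-arg constructor. A hedged sketch, with a hypothetical class name:

    import org.apache.spark.SparkConf

    // Any class implementing org.apache.spark.shuffle.ShuffleManager with one of the
    // constructor shapes above can be named here; the default is the hash-based manager.
    val conf = new SparkConf()
      .set("spark.shuffle.manager", "com.example.shuffle.MyShuffleManager")
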
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 9ff76892ae..5951865e56 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep(
}
}
-private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
@@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[Any, Any](rdd, part, serializer)
+ new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
}
}
}
@@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
// Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
- case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId)
+ case s: ShuffleDependency[_, _, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
- case ShuffleCoGroupSplitDep(shuffleId) =>
+ case ShuffleCoGroupSplitDep(handle) =>
// Read map outputs of shuffle
- val fetcher = SparkEnv.get.shuffleFetcher
- val ser = Serializer.getSerializer(serializer)
- val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+ val it = SparkEnv.get.shuffleManager
+ .getReader(handle, split.index, split.index + 1, context)
+ .read()
rddIterators += ((it, depNum))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 802b0bdfb2..bb108ef163 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
}
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
- val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- val ser = Serializer.getSerializer(serializer)
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
+ val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
+ SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+ .read()
+ .asInstanceOf[Iterator[P]]
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 9a09c05bbc..ed24ea22a6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
// Each CoGroupPartition will depend on rdd1 and rdd2
array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
dependencies(j) match {
- case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId)
+ case s: ShuffleDependency[_, _, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
- case ShuffleCoGroupSplitDep(shuffleId) =>
- val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context, ser)
+ case ShuffleCoGroupSplitDep(handle) =>
+ val iter = SparkEnv.get.shuffleManager
+ .getReader(handle, partition.index, partition.index + 1, context)
+ .read()
iter.foreach(op)
}
// the first dep is rdd1; add all values to the map
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index e09a4221e8..3c85b5a2ae 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -190,7 +190,7 @@ class DAGScheduler(
* The jobId value passed in will be used if the stage doesn't already exist with
* a lower jobId (jobId always increases across jobs.)
*/
- private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -210,7 +210,7 @@ class DAGScheduler(
private def newStage(
rdd: RDD[_],
numTasks: Int,
- shuffleDep: Option[ShuffleDependency[_,_]],
+ shuffleDep: Option[ShuffleDependency[_, _, _]],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -233,7 +233,7 @@ class DAGScheduler(
private def newOrUsedStage(
rdd: RDD[_],
numTasks: Int,
- shuffleDep: ShuffleDependency[_,_],
+ shuffleDep: ShuffleDependency[_, _, _],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -269,7 +269,7 @@ class DAGScheduler(
// we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
visit(dep.rdd)
@@ -290,7 +290,7 @@ class DAGScheduler(
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
missing += mapStage
@@ -1088,7 +1088,7 @@ class DAGScheduler(
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index ed0f56f1ab..0098b5a59d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
+import org.apache.spark.shuffle.ShuffleWriter
private[spark] object ShuffleMapTask {
@@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
private val serializedInfoCache = new HashMap[Int, Array[Byte]]
- def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
@@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask {
}
}
- def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = {
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
- val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
(rdd, dep)
}
@@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask {
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
- var dep: ShuffleDependency[_,_],
+ var dep: ShuffleDependency[_, _, _],
_partitionId: Int,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, _partitionId)
@@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask(
}
override def runTask(context: TaskContext): MapStatus = {
- val numOutputSplits = dep.partitioner.numPartitions
metrics = Some(context.taskMetrics)
-
- val blockManager = SparkEnv.get.blockManager
- val shuffleBlockManager = blockManager.shuffleBlockManager
- var shuffle: ShuffleWriterGroup = null
- var success = false
-
+ var writer: ShuffleWriter[Any, Any] = null
try {
- // Obtain all the block writers for shuffle blocks.
- val ser = Serializer.getSerializer(dep.serializer)
- shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
-
- // Write the map output to its associated buckets.
+ val manager = SparkEnv.get.shuffleManager
+ writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
for (elem <- rdd.iterator(split, context)) {
- val pair = elem.asInstanceOf[Product2[Any, Any]]
- val bucketId = dep.partitioner.getPartition(pair._1)
- shuffle.writers(bucketId).write(pair)
- }
-
- // Commit the writes. Get the size of each bucket block (total block size).
- var totalBytes = 0L
- var totalTime = 0L
- val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
- writer.commit()
- writer.close()
- val size = writer.fileSegment().length
- totalBytes += size
- totalTime += writer.timeWriting()
- MapOutputTracker.compressSize(size)
+ writer.write(elem.asInstanceOf[Product2[Any, Any]])
}
-
- // Update shuffle metrics.
- val shuffleMetrics = new ShuffleWriteMetrics
- shuffleMetrics.shuffleBytesWritten = totalBytes
- shuffleMetrics.shuffleWriteTime = totalTime
- metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
-
- success = true
- new MapStatus(blockManager.blockManagerId, compressedSizes)
- } catch { case e: Exception =>
- // If there is an exception from running the task, revert the partial writes
- // and throw the exception upstream to Spark.
- if (shuffle != null && shuffle.writers != null) {
- for (writer <- shuffle.writers) {
- writer.revertPartialWrites()
- writer.close()
+ return writer.stop(success = true).get
+ } catch {
+ case e: Exception =>
+ if (writer != null) {
+ writer.stop(success = false)
}
- }
- throw e
+ throw e
} finally {
- // Release the writers back to the shuffle block manager.
- if (shuffle != null && shuffle.writers != null) {
- try {
- shuffle.releaseWriters(success)
- } catch {
- case e: Exception => logError("Failed to release shuffle writers", e)
- }
- }
- // Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
}
}
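
The interleaved hunk above is easier to follow assembled; reconstructed from the added lines, the post-patch runTask reduces to:

    override def runTask(context: TaskContext): MapStatus = {
      metrics = Some(context.taskMetrics)
      var writer: ShuffleWriter[Any, Any] = null
      try {
        val manager = SparkEnv.get.shuffleManager
        writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
        for (elem <- rdd.iterator(split, context)) {
          writer.write(elem.asInstanceOf[Product2[Any, Any]])
        }
        return writer.stop(success = true).get
      } catch {
        case e: Exception =>
          // Let the writer revert partial output, then rethrow to Spark.
          if (writer != null) {
            writer.stop(success = false)
          }
          throw e
      } finally {
        context.executeOnCompleteCallbacks()
      }
    }
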
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 5c1fc30e4a..3bf9713f72 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -40,7 +40,7 @@ private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
val numTasks: Int,
- val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
+ val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
callSite: Option[String])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 99d305b36a..df59f444b7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
case ex: Exception =>
- taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+ logError("Exception while getting task result", ex)
+ taskSetManager.abort("Exception while getting task result: %s".format(ex))
}
}
})
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index ee26970a3d..f2f5cea469 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -52,6 +52,10 @@ object Serializer {
def getSerializer(serializer: Serializer): Serializer = {
if (serializer == null) SparkEnv.get.serializer else serializer
}
+
+ def getSerializer(serializer: Option[Serializer]): Serializer = {
+ serializer.getOrElse(SparkEnv.get.serializer)
+ }
}
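
The new overload mirrors the null-based one, so call sites can pass the dependency's Option-valued serializer field straight through, as HashShuffleReader does below. A small sketch:

    // Falls back to the default spark.serializer when the Option is None.
    val ser = Serializer.getSerializer(dep.serializer)   // dep.serializer: Option[Serializer]
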
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
index a4f69b6b22..b36c457d6d 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
@@ -15,22 +15,16 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.shuffle
+import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner}
import org.apache.spark.serializer.Serializer
-private[spark] abstract class ShuffleFetcher {
-
- /**
- * Fetch the shuffle outputs for a given ShuffleDependency.
- * @return An iterator over the elements of the fetched shuffle outputs.
- */
- def fetch[T](
- shuffleId: Int,
- reduceId: Int,
- context: TaskContext,
- serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
-
- /** Stop the fetcher */
- def stop() {}
-}
+/**
+ * A basic ShuffleHandle implementation that just captures registerShuffle's parameters.
+ */
+private[spark] class BaseShuffleHandle[K, V, C](
+ shuffleId: Int,
+ val numMaps: Int,
+ val dependency: ShuffleDependency[K, V, C])
+ extends ShuffleHandle(shuffleId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
new file mode 100644
index 0000000000..13c7115f88
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks.
+ *
+ * @param shuffleId ID of the shuffle
+ */
+private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {}
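
Since a handle is just an opaque, serializable token, a shuffle manager can subclass it to ship whatever per-shuffle metadata its tasks need. A hypothetical sketch, assuming it lives alongside these classes in org.apache.spark.shuffle:

    import org.apache.spark.ShuffleDependency

    // Hypothetical: a handle carrying one extra flag for its manager's tasks.
    private[spark] class SortedShuffleHandle[K, V, C](
        shuffleId: Int,
        numMaps: Int,
        dependency: ShuffleDependency[K, V, C],
        val requiresSort: Boolean)
      extends BaseShuffleHandle[K, V, C](shuffleId, numMaps, dependency)
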
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
new file mode 100644
index 0000000000..9c859b8b4a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.{TaskContext, ShuffleDependency}
+
+/**
+ * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
+ * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
+ * with it, and executors (or tasks running locally in the driver) can ask to read and write data.
+ *
+ * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
+ * boolean isDriver as parameters.
+ */
+private[spark] trait ShuffleManager {
+ /**
+ * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ */
+ def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C]
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ def unregisterShuffle(shuffleId: Int)
+
+ /** Shut down this ShuffleManager. */
+ def stop(): Unit
+}
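
Pulling the pieces together, the rest of this patch drives the interface as follows; a condensed sketch with illustrative variable names, mirroring Dependency.scala, ShuffleMapTask.scala and ShuffledRDD.scala above:

    // Driver: the ShuffleDependency constructor registers the shuffle and keeps the handle.
    val handle = SparkEnv.get.shuffleManager.registerShuffle(shuffleId, numMaps, dep)

    // Map task: write this task's output, then obtain a MapStatus from stop(success = true).
    val writer = SparkEnv.get.shuffleManager.getWriter[Any, Any](handle, mapId, context)
    records.foreach(writer.write)
    val status = writer.stop(success = true)

    // Reduce task: read one partition's records; the range [p, p + 1) means exactly partition p.
    val iter = SparkEnv.get.shuffleManager
      .getReader[Any, Any](handle, reduceId, reduceId + 1, context)
      .read()
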
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
new file mode 100644
index 0000000000..b30e366d06
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * Obtained inside a reduce task to read combined records from the mappers.
+ */
+private[spark] trait ShuffleReader[K, C] {
+ /** Read the combined key-values for this reduce task */
+ def read(): Iterator[Product2[K, C]]
+
+ /** Close this reader */
+ def stop(): Unit
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
new file mode 100644
index 0000000000..ead3ebd652
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+
+/**
+ * Obtained inside a map task to write out records to the shuffle system.
+ */
+private[spark] trait ShuffleWriter[K, V] {
+ /** Write a record to this task's output */
+ def write(record: Product2[K, V]): Unit
+
+ /** Close this writer, passing along whether the map completed */
+ def stop(success: Boolean): Option[MapStatus]
+}
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index a67392441e..b05b6ea345 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
+import org.apache.spark._
-private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
-
- override def fetch[T](
+private[hash] object BlockStoreShuffleFetcher extends Logging {
+ def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
-
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
new file mode 100644
index 0000000000..5b0940ecce
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark._
+import org.apache.spark.shuffle._
+
+/**
+ * A ShuffleManager using hashing, that creates one output file per reduce partition on each
+ * mapper (possibly reusing these across waves of tasks).
+ */
+class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+ /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C] = {
+ new HashShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+ : ShuffleWriter[K, V] = {
+ new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {}
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
new file mode 100644
index 0000000000..f6a790309a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.TaskContext
+
+class HashShuffleReader[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext)
+ extends ShuffleReader[K, C]
+{
+ require(endPartition == startPartition + 1,
+ "Hash shuffle currently only supports fetching one partition")
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
+ Serializer.getSerializer(handle.dependency.serializer))
+ }
+
+ /** Close this reader */
+ override def stop(): Unit = ???
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
new file mode 100644
index 0000000000..4c6749098c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.storage.{BlockObjectWriter}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+
+class HashShuffleWriter[K, V](
+ handle: BaseShuffleHandle[K, V, _],
+ mapId: Int,
+ context: TaskContext)
+ extends ShuffleWriter[K, V] with Logging {
+
+ private val dep = handle.dependency
+ private val numOutputSplits = dep.partitioner.numPartitions
+ private val metrics = context.taskMetrics
+ private var stopping = false
+
+ private val blockManager = SparkEnv.get.blockManager
+ private val shuffleBlockManager = blockManager.shuffleBlockManager
+ private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
+ private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+
+ /** Write a record to this task's output */
+ override def write(record: Product2[K, V]): Unit = {
+ val pair = record.asInstanceOf[Product2[Any, Any]]
+ val bucketId = dep.partitioner.getPartition(pair._1)
+ shuffle.writers(bucketId).write(pair)
+ }
+
+ /** Close this writer, passing along whether the map completed */
+ override def stop(success: Boolean): Option[MapStatus] = {
+ try {
+ if (stopping) {
+ return None
+ }
+ stopping = true
+ if (success) {
+ try {
+ return Some(commitWritesAndBuildStatus())
+ } catch {
+ case e: Exception =>
+ revertWrites()
+ throw e
+ }
+ } else {
+ revertWrites()
+ return None
+ }
+ } finally {
+ // Release the writers back to the shuffle block manager.
+ if (shuffle != null && shuffle.writers != null) {
+ try {
+ shuffle.releaseWriters(success)
+ } catch {
+ case e: Exception => logError("Failed to release shuffle writers", e)
+ }
+ }
+ }
+ }
+
+ private def commitWritesAndBuildStatus(): MapStatus = {
+ // Commit the writes. Get the size of each bucket block (total block size).
+ var totalBytes = 0L
+ var totalTime = 0L
+ val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
+ writer.commit()
+ writer.close()
+ val size = writer.fileSegment().length
+ totalBytes += size
+ totalTime += writer.timeWriting()
+ MapOutputTracker.compressSize(size)
+ }
+
+ // Update shuffle metrics.
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
+ metrics.shuffleWriteMetrics = Some(shuffleMetrics)
+
+ new MapStatus(blockManager.blockManagerId, compressedSizes)
+ }
+
+ private def revertWrites(): Unit = {
+ if (shuffle != null && shuffle.writers != null) {
+ for (writer <- shuffle.writers) {
+ writer.revertPartialWrites()
+ writer.close()
+ }
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index dc2db66df6..13b415cccb 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
def newPairRDD = newRDD.map(_ -> 1)
def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
def newBroadcast = sc.broadcast(1 to 100)
- def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+ def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
@@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
- .filter(_.isInstanceOf[ShuffleDependency[_, _]])
- .map(_.asInstanceOf[ShuffleDependency[_, _]])
+ .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
(rdd, shuffleDeps)
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 7b0607dd3e..47112ce66d 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 10)
@@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
@@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>