 core/src/main/scala/org/apache/spark/Dependency.scala                       |  28
 core/src/main/scala/org/apache/spark/SparkContext.scala                     |   2
 core/src/main/scala/org/apache/spark/rdd/RDD.scala                          |  11
 core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala            |   9
 core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala           |  87
 core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala             | 118
 core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala         | 129
 core/src/main/scala/org/apache/spark/util/Utils.scala                       |   2
 core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala              |  71
 core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala                     |   8
 core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala       |  24
 core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala |  11
 12 files changed, 198 insertions(+), 302 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 09a6057123..3935c87722 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
* Base class for dependencies.
*/
@DeveloperApi
-abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
+abstract class Dependency[T] extends Serializable {
+ def rdd: RDD[T]
+}
/**
@@ -36,20 +38,24 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
@DeveloperApi
-abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
+abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
/**
* Get the parent partitions for a child partition.
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(partitionId: Int): Seq[Int]
+
+ override def rdd: RDD[T] = _rdd
}
/**
* :: DeveloperApi ::
- * Represents a dependency on the output of a shuffle stage.
- * @param rdd the parent RDD
+ * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
+ * the RDD is transient since we don't need it on the executor side.
+ *
+ * @param _rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
@@ -57,20 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
- @transient rdd: RDD[_ <: Product2[K, V]],
+ @transient _rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
- extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
+ extends Dependency[Product2[K, V]] {
+
+ override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]
- val shuffleId: Int = rdd.context.newShuffleId()
+ val shuffleId: Int = _rdd.context.newShuffleId()
- val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
- shuffleId, rdd.partitions.size, this)
+ val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
+ shuffleId, _rdd.partitions.size, this)
- rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+ _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
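
The refactoring above turns `rdd` into an abstract method so that ShuffleDependency can keep its constructor argument @transient and avoid serializing the parent RDD along with the dependency. A minimal standalone sketch of the pattern, using simplified stand-in types rather than Spark's actual classes:

// Sketch only: Dep, NarrowDep, ShuffleDep and Seq[T] stand in for the real
// Dependency, NarrowDependency, ShuffleDependency and RDD[T].
abstract class Dep[T] extends Serializable {
  def rdd: Seq[T]
}

class NarrowDep[T](_rdd: Seq[T]) extends Dep[T] {
  // Ordinary reference: serialized together with the dependency, as before.
  override def rdd: Seq[T] = _rdd
}

class ShuffleDep[T](@transient _rdd: Seq[T]) extends Dep[T] {
  // Driver-side only: after deserialization on an executor this returns null,
  // which is acceptable because the executor never needs the parent RDD here.
  override def rdd: Seq[T] = _rdd
}
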
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3e6addeaf0..fb4c86716b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
- ShuffleMapTask.clearCache()
- ResultTask.clearCache()
listenerBus.stop()
eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a6abc49c53..726b3f2bbe 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
@@ -1206,16 +1207,12 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD has been checkpointed or not
*/
- def isCheckpointed: Boolean = {
- checkpointData.map(_.isCheckpointed).getOrElse(false)
- }
+ def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
/**
* Gets the name of the file to which this RDD was checkpointed
*/
- def getCheckpointFile: Option[String] = {
- checkpointData.flatMap(_.getCheckpointFile)
- }
+ def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
// =======================================================================
// Other internal methods and fields
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index c3b2a33fb5..f67e5f1857 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
- RDDCheckpointData.clearTaskCaches()
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}
@@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}
-private[spark] object RDDCheckpointData {
- def clearTaskCaches() {
- ShuffleMapTask.clearCache()
- ResultTask.clearCache()
- }
-}
+// Used for synchronization
+private[spark] object RDDCheckpointData
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dc6142ab79..50186d097a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler
-import java.io.{NotSerializableException, PrintWriter, StringWriter}
+import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
@@ -35,6 +35,7 @@ import akka.pattern.ask
import akka.util.Timeout
import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
@@ -114,6 +115,10 @@ class DAGScheduler(
private val dagSchedulerActorSupervisor =
env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
+ // A closure serializer that we reuse.
+ // This is only safe because DAGScheduler runs in a single thread.
+ private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+
private[scheduler] var eventProcessActor: ActorRef = _
private def initializeEventProcessActor() {
@@ -361,9 +366,6 @@ class DAGScheduler(
// data structures based on StageId
stageIdToStage -= stageId
- ShuffleMapTask.removeStage(stageId)
- ResultTask.removeStage(stageId)
-
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -691,49 +693,83 @@ class DAGScheduler(
}
}
-
/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
var tasks = ArrayBuffer[Task[_]]()
+
+ val properties = if (jobIdToActiveJob.contains(jobId)) {
+ jobIdToActiveJob(stage.jobId).properties
+ } else {
+ // this stage will be assigned to "default" pool
+ null
+ }
+
+ runningStages += stage
+ // SparkListenerStageSubmitted should be posted before testing whether tasks are
+ // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
+ // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
+ // event.
+ listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
+
+ // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
+ // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
+ // the serialized copy of the RDD and for each task we will deserialize it, which means each
+ // task gets a different copy of the RDD. This provides stronger isolation between tasks that
+ // might modify state of objects referenced in their closures. This is necessary in Hadoop
+ // where the JobConf/Configuration object is not thread-safe.
+ var taskBinary: Broadcast[Array[Byte]] = null
+ try {
+ // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
+ // For ResultTask, serialize and broadcast (rdd, func).
+ val taskBinaryBytes: Array[Byte] =
+ if (stage.isShuffleMap) {
+ closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array()
+ } else {
+ closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array()
+ }
+ taskBinary = sc.broadcast(taskBinaryBytes)
+ } catch {
+ // In the case of a failure during serialization, abort the stage.
+ case e: NotSerializableException =>
+ abortStage(stage, "Task not serializable: " + e.toString)
+ runningStages -= stage
+ return
+ case NonFatal(e) =>
+ abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
+ runningStages -= stage
+ return
+ }
+
if (stage.isShuffleMap) {
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
- tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
+ val part = stage.rdd.partitions(p)
+ tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
}
} else {
// This is a final stage; figure out its job's missing partitions
val job = stage.resultOfJob.get
for (id <- 0 until job.numPartitions if !job.finished(id)) {
- val partition = job.partitions(id)
- val locs = getPreferredLocs(stage.rdd, partition)
- tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
+ val p: Int = job.partitions(id)
+ val part = stage.rdd.partitions(p)
+ val locs = getPreferredLocs(stage.rdd, p)
+ tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
}
}
- val properties = if (jobIdToActiveJob.contains(jobId)) {
- jobIdToActiveJob(stage.jobId).properties
- } else {
- // this stage will be assigned to "default" pool
- null
- }
-
if (tasks.size > 0) {
- runningStages += stage
- // SparkListenerStageSubmitted should be posted before testing whether tasks are
- // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
- // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
- // event.
- listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
-
// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
// down the road, where we have several different implementations for local scheduler and
// cluster schedulers.
+ //
+ // We've already serialized RDDs and closures in taskBinary, but here we check for all other
+ // objects such as Partition.
try {
- SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
+ closureSerializer.serialize(tasks.head)
} catch {
case e: NotSerializableException =>
abortStage(stage, "Task not serializable: " + e.toString)
@@ -752,6 +788,9 @@ class DAGScheduler(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
stage.info.submissionTime = Some(clock.getTime())
} else {
+ // Because we posted SparkListenerStageSubmitted earlier, we should post
+ // SparkListenerStageCompleted here in case there are no tasks to run.
+ listenerBus.post(SparkListenerStageCompleted(stage.info))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
runningStages -= stage
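
The new submitMissingTasks path serializes the stage's RDD together with its shuffle dependency or result function once, broadcasts the bytes, and hands each task only the broadcast handle plus its Partition, instead of shipping the serialized RDD along with every task as the old per-stage cache did. Outside the scheduler, the same round trip looks roughly like the sketch below; the helper names are invented for this note and `sc` is assumed to be a live SparkContext.

import java.nio.ByteBuffer

import org.apache.spark.{SparkContext, SparkEnv, TaskContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

// Driver side: serialize (rdd, func) once and broadcast the bytes, so each task
// carries only the broadcast handle rather than its own serialized copy of the RDD.
def buildTaskBinary[T, U](sc: SparkContext, rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U): Broadcast[Array[Byte]] = {
  val ser = SparkEnv.get.closureSerializer.newInstance()
  sc.broadcast(ser.serialize((rdd, func): AnyRef).array())
}

// Executor side: every task deserializes its own copy of the pair, which provides
// the per-task isolation described in the submitMissingTasks comment above.
def readTaskBinary[T, U](
    taskBinary: Broadcast[Array[Byte]]): (RDD[T], (TaskContext, Iterator[T]) => U) = {
  val ser = SparkEnv.get.closureSerializer.newInstance()
  ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
}
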
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index bbf9f7388b..d09fd7aa57 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,134 +17,56 @@
package org.apache.spark.scheduler
-import scala.language.existentials
+import java.nio.ByteBuffer
import java.io._
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-
-import scala.collection.mutable.HashMap
import org.apache.spark._
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
-
-private[spark] object ResultTask {
-
- // A simple map between the stage id to the serialized byte array of a task.
- // Served as a cache for task serialization because serialization can be
- // expensive on the master node if it needs to launch thousands of tasks.
- private val serializedInfoCache = new HashMap[Int, Array[Byte]]
-
- def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
- {
- synchronized {
- val old = serializedInfoCache.get(stageId).orNull
- if (old != null) {
- old
- } else {
- val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objOut = ser.serializeStream(new GZIPOutputStream(out))
- objOut.writeObject(rdd)
- objOut.writeObject(func)
- objOut.close()
- val bytes = out.toByteArray
- serializedInfoCache.put(stageId, bytes)
- bytes
- }
- }
- }
-
- def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
- {
- val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objIn = ser.deserializeStream(in)
- val rdd = objIn.readObject().asInstanceOf[RDD[_]]
- val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
- (rdd, func)
- }
-
- def removeStage(stageId: Int) {
- serializedInfoCache.remove(stageId)
- }
-
- def clearCache() {
- synchronized {
- serializedInfoCache.clear()
- }
- }
-}
-
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
/**
* A task that sends back the output to the driver application.
*
- * See [[org.apache.spark.scheduler.Task]] for more information.
+ * See [[Task]] for more information.
*
* @param stageId id of the stage this task belongs to
- * @param rdd input to func
- * @param func a function to apply on a partition of the RDD
- * @param _partitionId index of the number in the RDD
+ * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
+ * partition of the given RDD. Once deserialized, the type should be
+ * (RDD[T], (TaskContext, Iterator[T]) => U).
+ * @param partition partition of the RDD this task is associated with
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
*/
private[spark] class ResultTask[T, U](
stageId: Int,
- var rdd: RDD[T],
- var func: (TaskContext, Iterator[T]) => U,
- _partitionId: Int,
+ taskBinary: Broadcast[Array[Byte]],
+ partition: Partition,
@transient locs: Seq[TaskLocation],
- var outputId: Int)
- extends Task[U](stageId, _partitionId) with Externalizable {
+ val outputId: Int)
+ extends Task[U](stageId, partition.index) with Serializable {
- def this() = this(0, null, null, 0, null, 0)
-
- var split = if (rdd == null) null else rdd.partitions(partitionId)
-
- @transient private val preferredLocs: Seq[TaskLocation] = {
+ @transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
override def runTask(context: TaskContext): U = {
+ // Deserialize the RDD and the func using the broadcast variable.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
+ ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
+
metrics = Some(context.taskMetrics)
try {
- func(context, rdd.iterator(split, context))
+ func(context, rdd.iterator(partition, context))
} finally {
context.executeOnCompleteCallbacks()
}
}
+ // This is only callable on the driver side.
override def preferredLocations: Seq[TaskLocation] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
-
- override def writeExternal(out: ObjectOutput) {
- RDDCheckpointData.synchronized {
- split = rdd.partitions(partitionId)
- out.writeInt(stageId)
- val bytes = ResultTask.serializeInfo(
- stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
- out.writeInt(bytes.length)
- out.write(bytes)
- out.writeInt(partitionId)
- out.writeInt(outputId)
- out.writeLong(epoch)
- out.writeObject(split)
- }
- }
-
- override def readExternal(in: ObjectInput) {
- val stageId = in.readInt()
- val numBytes = in.readInt()
- val bytes = new Array[Byte](numBytes)
- in.readFully(bytes)
- val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
- rdd = rdd_.asInstanceOf[RDD[T]]
- func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
- partitionId = in.readInt()
- outputId = in.readInt()
- epoch = in.readLong()
- split = in.readObject().asInstanceOf[Partition]
- }
}
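
With Externalizable gone, a ResultTask is constructed the same way the updated TaskContextSuite further down does it: serialize (rdd, func) with the closure serializer, broadcast the bytes, and pass the Partition object directly. A driver-side sketch follows; the names are local to this example, `sc` is assumed to be a live SparkContext, and because ResultTask is private[spark] code like this only compiles inside Spark's own sources, as in that test suite.

import org.apache.spark.{SparkContext, SparkEnv, TaskContext}
import org.apache.spark.scheduler.ResultTask

// Build a ResultTask for partition 0 of a small RDD.
def makeResultTask(sc: SparkContext): ResultTask[String, String] = {
  val rdd = sc.parallelize(Seq("a", "b"), numSlices = 1)
  val func = (context: TaskContext, it: Iterator[String]) => it.next()
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val taskBinary = sc.broadcast(ser.serialize((rdd, func): AnyRef).array())
  // stageId = 0, partition 0 of the RDD, no locality preference, outputId = 0
  new ResultTask[String, String](0, taskBinary, rdd.partitions(0), Seq.empty, 0)
}
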
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index fdaf1de83f..11255c0746 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,134 +17,55 @@
package org.apache.spark.scheduler
-import scala.language.existentials
-
-import java.io._
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+import java.nio.ByteBuffer
-import scala.collection.mutable.HashMap
+import scala.language.existentials
import org.apache.spark._
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter
-private[spark] object ShuffleMapTask {
-
- // A simple map between the stage id to the serialized byte array of a task.
- // Served as a cache for task serialization because serialization can be
- // expensive on the master node if it needs to launch thousands of tasks.
- private val serializedInfoCache = new HashMap[Int, Array[Byte]]
-
- def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
- synchronized {
- val old = serializedInfoCache.get(stageId).orNull
- if (old != null) {
- return old
- } else {
- val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objOut = ser.serializeStream(new GZIPOutputStream(out))
- objOut.writeObject(rdd)
- objOut.writeObject(dep)
- objOut.close()
- val bytes = out.toByteArray
- serializedInfoCache.put(stageId, bytes)
- bytes
- }
- }
- }
-
- def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
- val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objIn = ser.deserializeStream(in)
- val rdd = objIn.readObject().asInstanceOf[RDD[_]]
- val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
- (rdd, dep)
- }
-
- // Since both the JarSet and FileSet have the same format this is used for both.
- def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
- val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val objIn = new ObjectInputStream(in)
- val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
- HashMap(set.toSeq: _*)
- }
-
- def removeStage(stageId: Int) {
- serializedInfoCache.remove(stageId)
- }
-
- def clearCache() {
- synchronized {
- serializedInfoCache.clear()
- }
- }
-}
-
/**
- * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
- * specified in the ShuffleDependency).
- *
- * See [[org.apache.spark.scheduler.Task]] for more information.
- *
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
* @param stageId id of the stage this task belongs to
- * @param rdd the final RDD in this stage
- * @param dep the ShuffleDependency
- * @param _partitionId index of the number in the RDD
+ * @param taskBinary broadcasted version of the RDD and the ShuffleDependency. Once deserialized,
+ * the type should be (RDD[_], ShuffleDependency[_, _, _]).
+ * @param partition partition of the RDD this task is associated with
* @param locs preferred task execution locations for locality scheduling
*/
private[spark] class ShuffleMapTask(
stageId: Int,
- var rdd: RDD[_],
- var dep: ShuffleDependency[_, _, _],
- _partitionId: Int,
+ taskBinary: Broadcast[Array[Byte]],
+ partition: Partition,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, _partitionId)
- with Externalizable
- with Logging {
+ extends Task[MapStatus](stageId, partition.index) with Logging {
- protected def this() = this(0, null, null, 0, null)
+ /** A constructor used only in test suites. This does not require passing in an RDD. */
+ def this(partitionId: Int) {
+ this(0, null, new Partition { override def index = 0 }, null)
+ }
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
- var split = if (rdd == null) null else rdd.partitions(partitionId)
-
- override def writeExternal(out: ObjectOutput) {
- RDDCheckpointData.synchronized {
- split = rdd.partitions(partitionId)
- out.writeInt(stageId)
- val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
- out.writeInt(bytes.length)
- out.write(bytes)
- out.writeInt(partitionId)
- out.writeLong(epoch)
- out.writeObject(split)
- }
- }
-
- override def readExternal(in: ObjectInput) {
- val stageId = in.readInt()
- val numBytes = in.readInt()
- val bytes = new Array[Byte](numBytes)
- in.readFully(bytes)
- val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
- rdd = rdd_
- dep = dep_
- partitionId = in.readInt()
- epoch = in.readLong()
- split = in.readObject().asInstanceOf[Partition]
- }
-
override def runTask(context: TaskContext): MapStatus = {
+ // Deserialize the RDD and the ShuffleDependency using the broadcast variable.
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
+ ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
+
metrics = Some(context.taskMetrics)
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
- writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+ writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
return writer.stop(success = true).get
} catch {
case e: Exception =>
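
The executor-side half mirrors ResultTask: runTask deserializes (rdd, dep) from the broadcast and computes its partition. One point worth spelling out is how this interacts with the Dependency change at the top of the commit: the RDD the task iterates over comes out of taskBinary, while the ShuffleDependency's own `rdd` reference was declared @transient and is therefore null once deserialized on the executor, which is fine because runTask never touches it. A sketch with names local to this example:

import java.nio.ByteBuffer

import scala.language.existentials

import org.apache.spark.{ShuffleDependency, SparkEnv}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

// What an executor sees after deserializing a ShuffleMapTask's binary.
def inspectTaskBinary(taskBinary: Broadcast[Array[Byte]]): Unit = {
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  println(s"rdd from broadcast: ${rdd.id}")   // usable: a fully deserialized copy
  println(s"dep.rdd on executor: ${dep.rdd}") // null: the field was @transient
}
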
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 69f65b4bdc..f8fbb3ad6d 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -38,7 +38,7 @@ import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.ExecutorUncaughtExceptionHandler
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 13b415cccb..ad20f9b937 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark
import java.lang.ref.WeakReference
+import org.apache.spark.broadcast.Broadcast
+
+import scala.collection.mutable
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.language.existentials
import scala.language.postfixOps
@@ -52,9 +55,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
}
-
test("cleanup RDD") {
- val rdd = newRDD.persist()
+ val rdd = newRDD().persist()
val collected = rdd.collect().toList
val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
@@ -67,7 +69,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("cleanup shuffle") {
- val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
+ val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
val collected = rdd.collect().toList
val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
@@ -80,7 +82,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("cleanup broadcast") {
- val broadcast = newBroadcast
+ val broadcast = newBroadcast()
val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
// Explicit cleanup
@@ -89,7 +91,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("automatically cleanup RDD") {
- var rdd = newRDD.persist()
+ var rdd = newRDD().persist()
rdd.count()
// Test that GC does not cause RDD cleanup due to a strong reference
@@ -107,7 +109,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("automatically cleanup shuffle") {
- var rdd = newShuffleRDD
+ var rdd = newShuffleRDD()
rdd.count()
// Test that GC does not cause shuffle cleanup due to a strong reference
@@ -125,7 +127,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("automatically cleanup broadcast") {
- var broadcast = newBroadcast
+ var broadcast = newBroadcast()
// Test that GC does not cause broadcast cleanup due to a strong reference
val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
@@ -144,11 +146,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
- val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+ val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
- val broadcastIds = 0L until numBroadcasts
+ val broadcastIds = broadcastBuffer.map(_.id)
val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
runGC()
@@ -162,6 +164,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
rddBuffer.clear()
runGC()
postGCTester.assertCleanup()
+
+ // Make sure the broadcasted task closure no longer exists after GC.
+ val taskClosureBroadcastId = broadcastIds.max + 1
+ assert(sc.env.blockManager.master.getMatchingBlockIds({
+ case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+ case _ => false
+ }, askSlaves = true).isEmpty)
}
test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
@@ -175,11 +184,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
val numRdds = 10
val numBroadcasts = 4 // Broadcasts are more costly
- val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+ val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
- val broadcastIds = 0L until numBroadcasts
+ val broadcastIds = broadcastBuffer.map(_.id)
val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
runGC()
@@ -193,21 +202,29 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
rddBuffer.clear()
runGC()
postGCTester.assertCleanup()
+
+ // Make sure the broadcasted task closure no longer exists after GC.
+ val taskClosureBroadcastId = broadcastIds.max + 1
+ assert(sc.env.blockManager.master.getMatchingBlockIds({
+ case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+ case _ => false
+ }, askSlaves = true).isEmpty)
}
//------ Helper functions ------
- def newRDD = sc.makeRDD(1 to 10)
- def newPairRDD = newRDD.map(_ -> 1)
- def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
- def newBroadcast = sc.broadcast(1 to 100)
- def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+ private def newRDD() = sc.makeRDD(1 to 10)
+ private def newPairRDD() = newRDD().map(_ -> 1)
+ private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
+ private def newBroadcast() = sc.broadcast(1 to 100)
+
+ private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
}
}
- val rdd = newShuffleRDD
+ val rdd = newShuffleRDD()
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
@@ -216,34 +233,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
(rdd, shuffleDeps)
}
- def randomRdd = {
+ private def randomRdd() = {
val rdd: RDD[_] = Random.nextInt(3) match {
- case 0 => newRDD
- case 1 => newShuffleRDD
- case 2 => newPairRDD.join(newPairRDD)
+ case 0 => newRDD()
+ case 1 => newShuffleRDD()
+ case 2 => newPairRDD().join(newPairRDD())
}
if (Random.nextBoolean()) rdd.persist()
rdd.count()
rdd
}
- def randomBroadcast = {
+ private def randomBroadcast() = {
sc.broadcast(Random.nextInt(Int.MaxValue))
}
/** Run GC and make sure it actually has run */
- def runGC() {
+ private def runGC() {
val weakRef = new WeakReference(new Object())
val startTime = System.currentTimeMillis
System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
// Wait until a weak reference object has been GCed
- while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+ while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
System.gc()
Thread.sleep(200)
}
}
- def cleaner = sc.cleaner.get
+ private def cleaner = sc.cleaner.get
}
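
The suite's GC-driven assertions rest on one idiom, visible in runGC() above: allocate a sentinel object, keep only a WeakReference to it, and loop on System.gc() until the referent disappears, which is a portable way to confirm that a collection actually ran. In isolation the helper looks like the following sketch, with names local to this example:

import java.lang.ref.WeakReference

// Returns true once the sentinel has been collected, i.e. a GC cycle that
// reclaimed unreachable objects ran within the timeout.
def forceGC(timeoutMs: Long = 10000L): Boolean = {
  val sentinel = new WeakReference(new Object())
  val start = System.currentTimeMillis
  while (System.currentTimeMillis - start < timeoutMs && sentinel.get != null) {
    System.gc()          // best effort; System.gc() is only a hint to the JVM
    Thread.sleep(200)
  }
  sentinel.get == null
}
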
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index fdc83bc0a5..4953d565ae 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -155,19 +155,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {
override def getPartitions: Array[Partition] = Array(onlySplit)
override val getDependencies = List[Dependency[_]]()
override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
- if (shouldFail) {
- throw new Exception("injected failure")
- } else {
- Array(1, 2, 3, 4).iterator
- }
+ throw new Exception("injected failure")
}
}.cache()
val thrown = intercept[Exception]{
rdd.collect()
}
assert(thrown.getMessage.contains("injected failure"))
- shouldFail = false
- assert(rdd.collect().toList === List(1, 2, 3, 4))
}
test("empty RDD") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 8bb5317cd2..270f7e6610 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -20,31 +20,35 @@ package org.apache.spark.scheduler
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
-import org.apache.spark.LocalSparkContext
-import org.apache.spark.Partition
-import org.apache.spark.SparkContext
-import org.apache.spark.TaskContext
+import org.apache.spark._
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
test("Calls executeOnCompleteCallbacks after failure") {
- var completed = false
+ TaskContextSuite.completed = false
sc = new SparkContext("local", "test")
val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))
override def compute(split: Partition, context: TaskContext) = {
- context.addOnCompleteCallback(() => completed = true)
+ context.addOnCompleteCallback(() => TaskContextSuite.completed = true)
sys.error("failed")
}
}
- val func = (c: TaskContext, i: Iterator[String]) => i.next
- val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0)
+ val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+ val func = (c: TaskContext, i: Iterator[String]) => i.next()
+ val task = new ResultTask[String, String](
+ 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
task.run(0)
}
- assert(completed === true)
+ assert(TaskContextSuite.completed === true)
}
+}
- case class StubPartition(val index: Int) extends Partition
+private object TaskContextSuite {
+ @volatile var completed = false
}
+
+private case class StubPartition(index: Int) extends Partition
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index b52f81877d..86a271eb67 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers {
+
test("test LRU eviction of stages") {
val conf = new SparkConf()
conf.set("spark.ui.retainedStages", 5.toString)
@@ -66,7 +67,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
- var task = new ShuffleMapTask(0, null, null, 0, null)
+ var task = new ShuffleMapTask(0)
val taskType = Utils.getFormattedClassName(task)
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail())
@@ -76,14 +77,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
taskInfo =
new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true)
taskInfo.finishTime = 1
- task = new ShuffleMapTask(0, null, null, 0, null)
+ task = new ShuffleMapTask(0)
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
assert(listener.stageIdToData.size === 1)
// finish this task, should get updated duration
taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
- task = new ShuffleMapTask(0, null, null, 0, null)
+ task = new ShuffleMapTask(0)
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail())
.shuffleRead === 2000)
@@ -91,7 +92,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
// finish this task, should get updated duration
taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
- task = new ShuffleMapTask(0, null, null, 0, null)
+ task = new ShuffleMapTask(0)
listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics))
assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail())
.shuffleRead === 1000)
@@ -103,7 +104,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val metrics = new TaskMetrics()
val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
- val task = new ShuffleMapTask(0, null, null, 0, null)
+ val task = new ShuffleMapTask(0)
val taskType = Utils.getFormattedClassName(task)
// Go through all the failure cases to make sure we are counting them as failures.