author    Reynold Xin <rxin@apache.org>  2014-07-19 16:56:22 -0700
committer Reynold Xin <rxin@apache.org>  2014-07-19 16:56:22 -0700
commit    1efb3698b6cf39a80683b37124d2736ebf3c9d9a (patch)
tree      3e4648ec06a8f395a15cd0ce363bf7f8bfc1d667
parent    2a732110d46712c535b75dd4f5a73761b6463aa8 (diff)
Revert "[SPARK-2521] Broadcast RDD object (instead of sending it along with every task)."
This reverts commit 7b8cd175254d42c8e82f0aa8eb4b7f3508d8fde2.
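For context, the change being reverted had shipped each task's RDD as a broadcast of its serialized bytes; the code restored below instead keeps one gzip-compressed serialized blob per stage in a driver-side cache. A minimal, self-contained sketch of that per-stage caching pattern follows; StageInfoCache is a hypothetical stand-in for the ResultTask/ShuffleMapTask companion objects, plain Java serialization replaces Spark's closure serializer, and any java.io.Serializable payload replaces the (RDD, func) or (RDD, dep) pair.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap

object StageInfoCache {
  // One serialized blob per stage id, so launching many tasks for the same stage
  // pays the serialization cost only once (mirrors serializedInfoCache below).
  private val cache = new HashMap[Int, Array[Byte]]

  def serialize(stageId: Int, payload: java.io.Serializable): Array[Byte] = synchronized {
    cache.getOrElseUpdate(stageId, {
      val out = new ByteArrayOutputStream
      val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
      objOut.writeObject(payload)
      objOut.close()
      out.toByteArray
    })
  }

  def deserialize(bytes: Array[Byte]): AnyRef = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    objIn.readObject()
  }

  def removeStage(stageId: Int): Unit = synchronized { cache.remove(stageId) }

  def clearCache(): Unit = synchronized { cache.clear() }
}

Each task launched for a stage then ships the cached bytes and deserializes them on the executor; removeStage and clearCache correspond to the cleanup hooks this revert adds back in DAGScheduler and SparkContext.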
-rw-r--r--  core/src/main/scala/org/apache/spark/Dependency.scala              |  28
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala            |   2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                 |  30
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala   |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala    | 128
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala| 125
-rw-r--r--  core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala     |  62
8 files changed, 251 insertions(+), 137 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 3935c87722..09a6057123 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,9 +27,7 @@ import org.apache.spark.shuffle.ShuffleHandle
* Base class for dependencies.
*/
@DeveloperApi
-abstract class Dependency[T] extends Serializable {
- def rdd: RDD[T]
-}
+abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
/**
@@ -38,24 +36,20 @@ abstract class Dependency[T] extends Serializable {
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
@DeveloperApi
-abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
+abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
* Get the parent partitions for a child partition.
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(partitionId: Int): Seq[Int]
-
- override def rdd: RDD[T] = _rdd
}
/**
* :: DeveloperApi ::
- * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
- * the RDD is transient since we don't need it on the executor side.
- *
- * @param _rdd the parent RDD
+ * Represents a dependency on the output of a shuffle stage.
+ * @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
@@ -63,22 +57,20 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
- @transient _rdd: RDD[_ <: Product2[K, V]],
+ @transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
- extends Dependency[Product2[K, V]] {
-
- override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]
+ extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
- val shuffleId: Int = _rdd.context.newShuffleId()
+ val shuffleId: Int = rdd.context.newShuffleId()
- val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
- shuffleId, _rdd.partitions.size, this)
+ val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+ shuffleId, rdd.partitions.size, this)
- _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+ rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
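With the signatures restored above, the parent RDD is a constructor value on Dependency itself rather than an abstract def that each subclass overrides. As a rough illustration, compiled against a Spark build with these reverted signatures, a hypothetical custom narrow dependency would be written as follows (IdentityDependency is an invented name; Spark's own OneToOneDependency plays the same role):

import org.apache.spark.NarrowDependency
import org.apache.spark.rdd.RDD

// Partition i of the child depends only on partition i of the parent; the parent
// RDD is passed straight up to the superclass constructor instead of overriding rdd.
class IdentityDependency[T](parent: RDD[T]) extends NarrowDependency[T](parent) {
  override def getParents(partitionId: Int): Seq[Int] = List(partitionId)
}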
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 48a09657fd..8052499ab7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,6 +997,8 @@ class SparkContext(config: SparkConf) extends Logging {
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
listenerBus.stop()
eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 2ee9a8f1a8..88a918aebf 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,13 +35,12 @@ import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
@@ -1196,36 +1195,21 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD has been checkpointed or not
*/
- def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
+ def isCheckpointed: Boolean = {
+ checkpointData.map(_.isCheckpointed).getOrElse(false)
+ }
/**
* Gets the name of the file to which this RDD was checkpointed
*/
- def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
+ def getCheckpointFile: Option[String] = {
+ checkpointData.flatMap(_.getCheckpointFile)
+ }
// =======================================================================
// Other internal methods and fields
// =======================================================================
- /**
- * Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast
- * the serialized copy of the RDD and for each task we will deserialize it, which means each
- * task gets a different copy of the RDD. This provides stronger isolation between tasks that
- * might modify state of objects referenced in their closures. This is necessary in Hadoop
- * where the JobConf/Configuration object is not thread-safe.
- */
- @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val bytes = ser.serialize(this).array()
- val size = Utils.bytesToString(bytes.length)
- if (bytes.length > (1L << 20)) {
- logWarning(s"Broadcasting RDD $id ($size), which contains large objects")
- } else {
- logDebug(s"Broadcasting RDD $id ($size)")
- }
- sc.broadcast(bytes)
- }
-
private var storageLevel: StorageLevel = StorageLevel.NONE
/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index f67e5f1857..c3b2a33fb5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,6 +106,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
+ RDDCheckpointData.clearTaskCaches()
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}
@@ -130,5 +131,9 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}
-// Used for synchronization
-private[spark] object RDDCheckpointData
+private[spark] object RDDCheckpointData {
+ def clearTaskCaches() {
+ ShuffleMapTask.clearCache()
+ ResultTask.clearCache()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 88cb5feaaf..ede3c7d9f0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -376,6 +376,9 @@ class DAGScheduler(
stageIdToStage -= stageId
stageIdToJobIds -= stageId
+ ShuffleMapTask.removeStage(stageId)
+ ResultTask.removeStage(stageId)
+
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -720,6 +723,7 @@ class DAGScheduler(
}
}
+
/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 62beb0d02a..bbf9f7388b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,68 +17,134 @@
package org.apache.spark.scheduler
-import java.nio.ByteBuffer
+import scala.language.existentials
import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+
+private[spark] object ResultTask {
+
+ // A simple map between the stage id to the serialized byte array of a task.
+ // Served as a cache for task serialization because serialization can be
+ // expensive on the master node if it needs to launch thousands of tasks.
+ private val serializedInfoCache = new HashMap[Int, Array[Byte]]
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
+ {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(func)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
+ {
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+ (rdd, func)
+ }
+
+ def removeStage(stageId: Int) {
+ serializedInfoCache.remove(stageId)
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
/**
* A task that sends back the output to the driver application.
*
- * See [[Task]] for more information.
+ * See [[org.apache.spark.scheduler.Task]] for more information.
*
* @param stageId id of the stage this task belongs to
- * @param rddBinary broadcast version of of the serialized RDD
+ * @param rdd input to func
* @param func a function to apply on a partition of the RDD
- * @param partition partition of the RDD this task is associated with
+ * @param _partitionId index of the number in the RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
*/
private[spark] class ResultTask[T, U](
stageId: Int,
- val rddBinary: Broadcast[Array[Byte]],
- val func: (TaskContext, Iterator[T]) => U,
- val partition: Partition,
+ var rdd: RDD[T],
+ var func: (TaskContext, Iterator[T]) => U,
+ _partitionId: Int,
@transient locs: Seq[TaskLocation],
- val outputId: Int)
- extends Task[U](stageId, partition.index) with Serializable {
-
- // TODO: Should we also broadcast func? For that we would need a place to
- // keep a reference to it (perhaps in DAGScheduler's job object).
-
- def this(
- stageId: Int,
- rdd: RDD[T],
- func: (TaskContext, Iterator[T]) => U,
- partitionId: Int,
- locs: Seq[TaskLocation],
- outputId: Int) = {
- this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
- }
+ var outputId: Int)
+ extends Task[U](stageId, _partitionId) with Externalizable {
+
+ def this() = this(0, null, null, 0, null, 0)
+
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
- @transient private[this] val preferredLocs: Seq[TaskLocation] = {
+ @transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
override def runTask(context: TaskContext): U = {
- // Deserialize the RDD using the broadcast variable.
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
- Thread.currentThread.getContextClassLoader)
metrics = Some(context.taskMetrics)
try {
- func(context, rdd.iterator(partition, context))
+ func(context, rdd.iterator(split, context))
} finally {
context.executeOnCompleteCallbacks()
}
}
- // This is only callable on the driver side.
override def preferredLocations: Seq[TaskLocation] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.partitions(partitionId)
+ out.writeInt(stageId)
+ val bytes = ResultTask.serializeInfo(
+ stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partitionId)
+ out.writeInt(outputId)
+ out.writeLong(epoch)
+ out.writeObject(split)
+ }
+ }
+
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_.asInstanceOf[RDD[T]]
+ func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+ partitionId = in.readInt()
+ outputId = in.readInt()
+ epoch = in.readLong()
+ split = in.readObject().asInstanceOf[Partition]
+ }
}
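The restored ResultTask above (and ShuffleMapTask below) implements Externalizable rather than relying on default Java serialization, so it controls exactly which fields cross the wire and rebuilds the rest from the stage cache on the receiving side. A stripped-down, self-contained sketch of that pattern, using a hypothetical StagePayload class with two fields and no Spark types:

import java.io.{Externalizable, ObjectInput, ObjectOutput}

// Mutable fields plus a no-arg constructor: Java serialization instantiates the
// object via the no-arg constructor and then calls readExternal to fill it in.
class StagePayload(var stageId: Int, var partitionId: Int) extends Externalizable {
  def this() = this(0, 0)

  override def writeExternal(out: ObjectOutput): Unit = {
    out.writeInt(stageId)
    out.writeInt(partitionId)
  }

  override def readExternal(in: ObjectInput): Unit = {
    stageId = in.readInt()
    partitionId = in.readInt()
  }
}

This is why both restored task classes gain a no-argument auxiliary constructor, e.g. def this() = this(0, null, null, 0, null, 0).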
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 033c6e5286..fdaf1de83f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,13 +17,71 @@
package org.apache.spark.scheduler
-import java.nio.ByteBuffer
+import scala.language.existentials
+
+import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.shuffle.ShuffleWriter
+private[spark] object ShuffleMapTask {
+
+ // A simple map between the stage id to the serialized byte array of a task.
+ // Served as a cache for task serialization because serialization can be
+ // expensive on the master node if it needs to launch thousands of tasks.
+ private val serializedInfoCache = new HashMap[Int, Array[Byte]]
+
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
+ synchronized {
+ val old = serializedInfoCache.get(stageId).orNull
+ if (old != null) {
+ return old
+ } else {
+ val out = new ByteArrayOutputStream
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objOut = ser.serializeStream(new GZIPOutputStream(out))
+ objOut.writeObject(rdd)
+ objOut.writeObject(dep)
+ objOut.close()
+ val bytes = out.toByteArray
+ serializedInfoCache.put(stageId, bytes)
+ bytes
+ }
+ }
+ }
+
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val objIn = ser.deserializeStream(in)
+ val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
+ (rdd, dep)
+ }
+
+ // Since both the JarSet and FileSet have the same format this is used for both.
+ def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
+ val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+ val objIn = new ObjectInputStream(in)
+ val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+ HashMap(set.toSeq: _*)
+ }
+
+ def removeStage(stageId: Int) {
+ serializedInfoCache.remove(stageId)
+ }
+
+ def clearCache() {
+ synchronized {
+ serializedInfoCache.clear()
+ }
+ }
+}
+
/**
* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
* specified in the ShuffleDependency).
@@ -31,47 +89,62 @@ import org.apache.spark.shuffle.ShuffleWriter
* See [[org.apache.spark.scheduler.Task]] for more information.
*
* @param stageId id of the stage this task belongs to
- * @param rddBinary broadcast version of of the serialized RDD
+ * @param rdd the final RDD in this stage
* @param dep the ShuffleDependency
- * @param partition partition of the RDD this task is associated with
+ * @param _partitionId index of the number in the RDD
* @param locs preferred task execution locations for locality scheduling
*/
private[spark] class ShuffleMapTask(
stageId: Int,
- var rddBinary: Broadcast[Array[Byte]],
+ var rdd: RDD[_],
var dep: ShuffleDependency[_, _, _],
- partition: Partition,
+ _partitionId: Int,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, partition.index) with Logging {
-
- // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
- // keep a reference to it (perhaps in Stage).
-
- def this(
- stageId: Int,
- rdd: RDD[_],
- dep: ShuffleDependency[_, _, _],
- partitionId: Int,
- locs: Seq[TaskLocation]) = {
- this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs)
- }
+ extends Task[MapStatus](stageId, _partitionId)
+ with Externalizable
+ with Logging {
+
+ protected def this() = this(0, null, null, 0, null)
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
- override def runTask(context: TaskContext): MapStatus = {
- // Deserialize the RDD using the broadcast variable.
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val rdd = ser.deserialize[RDD[_]](ByteBuffer.wrap(rddBinary.value),
- Thread.currentThread.getContextClassLoader)
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
+
+ override def writeExternal(out: ObjectOutput) {
+ RDDCheckpointData.synchronized {
+ split = rdd.partitions(partitionId)
+ out.writeInt(stageId)
+ val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ out.writeInt(partitionId)
+ out.writeLong(epoch)
+ out.writeObject(split)
+ }
+ }
+ override def readExternal(in: ObjectInput) {
+ val stageId = in.readInt()
+ val numBytes = in.readInt()
+ val bytes = new Array[Byte](numBytes)
+ in.readFully(bytes)
+ val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
+ rdd = rdd_
+ dep = dep_
+ partitionId = in.readInt()
+ epoch = in.readLong()
+ split = in.readObject().asInstanceOf[Partition]
+ }
+
+ override def runTask(context: TaskContext): MapStatus = {
metrics = Some(context.taskMetrics)
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
- writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+ writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
return writer.stop(success = true).get
} catch {
case e: Exception =>
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 871f831531..13b415cccb 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -52,8 +52,9 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
}
+
test("cleanup RDD") {
- val rdd = newRDD().persist()
+ val rdd = newRDD.persist()
val collected = rdd.collect().toList
val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
@@ -66,7 +67,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("cleanup shuffle") {
- val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
+ val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
val collected = rdd.collect().toList
val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
@@ -79,7 +80,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("cleanup broadcast") {
- val broadcast = newBroadcast()
+ val broadcast = newBroadcast
val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
// Explicit cleanup
@@ -88,7 +89,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("automatically cleanup RDD") {
- var rdd = newRDD().persist()
+ var rdd = newRDD.persist()
rdd.count()
// Test that GC does not cause RDD cleanup due to a strong reference
@@ -106,7 +107,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("automatically cleanup shuffle") {
- var rdd = newShuffleRDD()
+ var rdd = newShuffleRDD
rdd.count()
// Test that GC does not cause shuffle cleanup due to a strong reference
@@ -124,7 +125,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
}
test("automatically cleanup broadcast") {
- var broadcast = newBroadcast()
+ var broadcast = newBroadcast
// Test that GC does not cause broadcast cleanup due to a strong reference
val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
@@ -140,23 +141,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
postGCTester.assertCleanup()
}
- test("automatically cleanup broadcast data for task dispatching") {
- var rdd = newRDDWithShuffleDependencies()._1
- rdd.count() // This triggers an action that broadcasts the RDDs.
-
- // Test that GC causes broadcast task data cleanup after dereferencing the RDD.
- val postGCTester = new CleanerTester(sc,
- broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id))
- rdd = null
- runGC()
- postGCTester.assertCleanup()
- }
-
test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
- val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = 0L until numBroadcasts
@@ -186,8 +175,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
val numRdds = 10
val numBroadcasts = 4 // Broadcasts are more costly
- val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
- val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
val rddIds = sc.persistentRdds.keys.toSeq
val shuffleIds = 0 until sc.newShuffleId
val broadcastIds = 0L until numBroadcasts
@@ -208,18 +197,17 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
//------ Helper functions ------
- private def newRDD() = sc.makeRDD(1 to 10)
- private def newPairRDD() = newRDD().map(_ -> 1)
- private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
- private def newBroadcast() = sc.broadcast(1 to 100)
-
- private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+ def newRDD = sc.makeRDD(1 to 10)
+ def newPairRDD = newRDD.map(_ -> 1)
+ def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
+ def newBroadcast = sc.broadcast(1 to 100)
+ def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
}
}
- val rdd = newShuffleRDD()
+ val rdd = newShuffleRDD
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
@@ -228,34 +216,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
(rdd, shuffleDeps)
}
- private def randomRdd() = {
+ def randomRdd = {
val rdd: RDD[_] = Random.nextInt(3) match {
- case 0 => newRDD()
- case 1 => newShuffleRDD()
- case 2 => newPairRDD.join(newPairRDD())
+ case 0 => newRDD
+ case 1 => newShuffleRDD
+ case 2 => newPairRDD.join(newPairRDD)
}
if (Random.nextBoolean()) rdd.persist()
rdd.count()
rdd
}
- private def randomBroadcast() = {
+ def randomBroadcast = {
sc.broadcast(Random.nextInt(Int.MaxValue))
}
/** Run GC and make sure it actually has run */
- private def runGC() {
+ def runGC() {
val weakRef = new WeakReference(new Object())
val startTime = System.currentTimeMillis
System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
// Wait until a weak reference object has been GCed
- while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+ while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
System.gc()
Thread.sleep(200)
}
}
- private def cleaner = sc.cleaner.get
+ def cleaner = sc.cleaner.get
}
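The runGC helper restored above is also worth reading on its own: it creates a weak reference, then repeatedly requests garbage collection until the reference has actually been cleared or ten seconds have passed. A standalone version of the same logic, with hypothetical names:

import java.lang.ref.WeakReference

object GCUtil {
  // Best-effort GC: poll a weak reference until the collector clears it or we time
  // out, mirroring the loop in ContextCleanerSuite.runGC.
  def forceGC(timeoutMs: Long = 10000): Unit = {
    val weakRef = new WeakReference(new Object())
    val startTime = System.currentTimeMillis
    System.gc()
    while (System.currentTimeMillis - startTime < timeoutMs && weakRef.get != null) {
      System.gc()
      Thread.sleep(200)
    }
  }
}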