-rw-r--r--  core/src/main/scala/org/apache/spark/ContextCleaner.scala | 192
-rw-r--r--  core/src/main/scala/org/apache/spark/Dependency.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/MapOutputTracker.scala | 148
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkEnv.scala | 25
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala | 107
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala | 66
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala | 128
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala | 45
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala | 162
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala | 46
-rw-r--r--  core/src/main/scala/org/apache/spark/network/ConnectionManager.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 38
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 16
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockId.scala | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 67
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala | 84
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala | 107
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala | 60
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala | 19
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala | 109
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala | 170
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/BroadcastSuite.scala | 311
-rw-r--r--  core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala | 415
-rw-r--r--  core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala | 25
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala | 243
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala | 10
-rw-r--r--  core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala | 264
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala | 4
40 files changed, 2571 insertions(+), 469 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
new file mode 100644
index 0000000000..54e08d7866
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.lang.ref.{ReferenceQueue, WeakReference}
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+
+/**
+ * Classes that represent cleaning tasks.
+ */
+private sealed trait CleanupTask
+private case class CleanRDD(rddId: Int) extends CleanupTask
+private case class CleanShuffle(shuffleId: Int) extends CleanupTask
+private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
+
+/**
+ * A WeakReference associated with a CleanupTask.
+ *
+ * When the referent object becomes only weakly reachable, the corresponding
+ * CleanupTaskWeakReference is automatically added to the given reference queue.
+ */
+private class CleanupTaskWeakReference(
+ val task: CleanupTask,
+ referent: AnyRef,
+ referenceQueue: ReferenceQueue[AnyRef])
+ extends WeakReference(referent, referenceQueue)
+
+/**
+ * An asynchronous cleaner for RDD, shuffle, and broadcast state.
+ *
+ * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
+ * to be processed when the associated object goes out of scope of the application. Actual
+ * cleanup is performed in a separate daemon thread.
+ */
+private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
+
+ private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
+ with SynchronizedBuffer[CleanupTaskWeakReference]
+
+ private val referenceQueue = new ReferenceQueue[AnyRef]
+
+ private val listeners = new ArrayBuffer[CleanerListener]
+ with SynchronizedBuffer[CleanerListener]
+
+ private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
+
+ /**
+ * Whether the cleaning thread will block on cleanup tasks.
+ * This is set to true only for tests.
+ */
+ private val blockOnCleanupTasks = sc.conf.getBoolean(
+ "spark.cleaner.referenceTracking.blocking", false)
+
+ @volatile private var stopped = false
+
+ /** Attach a listener object to get information of when objects are cleaned. */
+ def attachListener(listener: CleanerListener) {
+ listeners += listener
+ }
+
+ /** Start the cleaner. */
+ def start() {
+ cleaningThread.setDaemon(true)
+ cleaningThread.setName("Spark Context Cleaner")
+ cleaningThread.start()
+ }
+
+ /** Stop the cleaner. */
+ def stop() {
+ stopped = true
+ }
+
+ /** Register a RDD for cleanup when it is garbage collected. */
+ def registerRDDForCleanup(rdd: RDD[_]) {
+ registerForCleanup(rdd, CleanRDD(rdd.id))
+ }
+
+ /** Register a ShuffleDependency for cleanup when it is garbage collected. */
+ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
+ registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
+ }
+
+ /** Register a Broadcast for cleanup when it is garbage collected. */
+ def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+ registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
+ }
+
+ /** Register an object for cleanup. */
+ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
+ referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
+ }
+
+ /** Keep cleaning RDD, shuffle, and broadcast state. */
+ private def keepCleaning() {
+ while (!stopped) {
+ try {
+ val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
+ .map(_.asInstanceOf[CleanupTaskWeakReference])
+ reference.map(_.task).foreach { task =>
+ logDebug("Got cleaning task " + task)
+ referenceBuffer -= reference.get
+ task match {
+ case CleanRDD(rddId) =>
+ doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
+ case CleanShuffle(shuffleId) =>
+ doCleanupShuffle(shuffleId, blocking = blockOnCleanupTasks)
+ case CleanBroadcast(broadcastId) =>
+ doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
+ }
+ }
+ } catch {
+ case t: Throwable => logError("Error in cleaning thread", t)
+ }
+ }
+ }
+
+ /** Perform RDD cleanup. */
+ def doCleanupRDD(rddId: Int, blocking: Boolean) {
+ try {
+ logDebug("Cleaning RDD " + rddId)
+ sc.unpersistRDD(rddId, blocking)
+ listeners.foreach(_.rddCleaned(rddId))
+ logInfo("Cleaned RDD " + rddId)
+ } catch {
+ case t: Throwable => logError("Error cleaning RDD " + rddId, t)
+ }
+ }
+
+ /** Perform shuffle cleanup, asynchronously. */
+ def doCleanupShuffle(shuffleId: Int, blocking: Boolean) {
+ try {
+ logDebug("Cleaning shuffle " + shuffleId)
+ mapOutputTrackerMaster.unregisterShuffle(shuffleId)
+ blockManagerMaster.removeShuffle(shuffleId, blocking)
+ listeners.foreach(_.shuffleCleaned(shuffleId))
+ logInfo("Cleaned shuffle " + shuffleId)
+ } catch {
+ case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t)
+ }
+ }
+
+ /** Perform broadcast cleanup. */
+ def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
+ try {
+ logDebug("Cleaning broadcast " + broadcastId)
+ broadcastManager.unbroadcast(broadcastId, true, blocking)
+ listeners.foreach(_.broadcastCleaned(broadcastId))
+ logInfo("Cleaned broadcast " + broadcastId)
+ } catch {
+ case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
+ }
+ }
+
+ private def blockManagerMaster = sc.env.blockManager.master
+ private def broadcastManager = sc.env.broadcastManager
+ private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+
+ // Used for testing. These methods explicitly block until cleanup is completed
+ // to ensure more reliable testing.
+}
+
+private object ContextCleaner {
+ private val REF_QUEUE_POLL_TIMEOUT = 100
+}
+
+/**
+ * Listener class used for testing when any item has been cleaned by the Cleaner class.
+ */
+private[spark] trait CleanerListener {
+ def rddCleaned(rddId: Int)
+ def shuffleCleaned(shuffleId: Int)
+ def broadcastCleaned(broadcastId: Long)
+}
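The cleaner above hinges on the JVM's WeakReference/ReferenceQueue machinery: each registered object gets a weak reference tied to a shared queue, and once the object becomes only weakly reachable the garbage collector enqueues that reference, which the daemon thread then drains into cleanup tasks. A minimal standalone sketch of that pattern, using illustrative names (CleanupTask, TrackedRef, pollAndClean) rather than Spark's API:

    import java.lang.ref.{ReferenceQueue, WeakReference}

    object WeakCleanupSketch {
      // Payload describing what to clean once the referent is gone (illustrative).
      case class CleanupTask(id: Int)

      // Weak reference that carries its cleanup task and registers with a queue.
      class TrackedRef(val task: CleanupTask, referent: AnyRef, queue: ReferenceQueue[AnyRef])
        extends WeakReference[AnyRef](referent, queue)

      private val queue = new ReferenceQueue[AnyRef]
      private var refs = List.empty[TrackedRef]   // keep the references strongly reachable

      def register(obj: AnyRef, task: CleanupTask): Unit = synchronized {
        refs ::= new TrackedRef(task, obj, queue)
      }

      /** Drain the queue once; returns the tasks whose referents were collected. */
      def pollAndClean(): Seq[CleanupTask] = {
        Iterator.continually(queue.poll())
          .takeWhile(_ != null)
          .map(_.asInstanceOf[TrackedRef].task)
          .toSeq
      }

      def main(args: Array[String]): Unit = {
        var data: AnyRef = new Array[Byte](1 << 20)
        register(data, CleanupTask(1))
        data = null          // drop the only strong reference
        System.gc()          // a hint only; collection timing is not guaranteed
        Thread.sleep(100)
        println(pollAndClean())   // likely List(CleanupTask(1)) once collected
      }
    }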
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 3132dcf745..1cd629c15b 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -55,6 +55,8 @@ class ShuffleDependency[K, V](
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
+
+ rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 80cbf951cb..ee82d9fa78 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,21 +20,21 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{HashSet, HashMap, Map}
import scala.concurrent.Await
import akka.actor._
import akka.pattern.ask
-
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.util._
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
+/** Actor class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
extends Actor with Logging {
val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
@@ -65,26 +65,41 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
}
}
-private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
-
+/**
+ * Class that keeps track of the location of the map output of
+ * a stage. This is abstract because different versions of MapOutputTracker
+ * (driver and worker) use different HashMaps to store their metadata.
+ */
+private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
private val timeout = AkkaUtils.askTimeout(conf)
- // Set to the MapOutputTrackerActor living on the driver
+ /** Set to the MapOutputTrackerActor living on the driver. */
var trackerActor: ActorRef = _
- protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ /**
+ * This HashMap has different behavior for the master and the workers.
+ *
+ * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks.
+ * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the
+ * master's corresponding HashMap.
+ */
+ protected val mapStatuses: Map[Int, Array[MapStatus]]
- // Incremented every time a fetch fails so that client nodes know to clear
- // their cache of map output locations if this happens.
+ /**
+ * Incremented every time a fetch fails so that client nodes know to clear
+ * their cache of map output locations if this happens.
+ */
protected var epoch: Long = 0
- protected val epochLock = new java.lang.Object
+ protected val epochLock = new AnyRef
- private val metadataCleaner =
- new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
+ /** Remembers which map output locations are currently being fetched on a worker. */
+ private val fetching = new HashSet[Int]
- // Send a message to the trackerActor and get its result within a default timeout, or
- // throw a SparkException if this fails.
- private def askTracker(message: Any): Any = {
+ /**
+ * Send a message to the trackerActor and get its result within a default timeout, or
+ * throw a SparkException if this fails.
+ */
+ protected def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
Await.result(future, timeout)
@@ -94,17 +109,17 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
- // Send a one-way message to the trackerActor, to which we expect it to reply with true.
- private def communicate(message: Any) {
+ /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
+ protected def sendTracker(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- // Remembers which map output locations are currently being fetched on a worker
- private val fetching = new HashSet[Int]
-
- // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
+ /**
+ * Called from executors to get the server URIs and output sizes of the map outputs of
+ * a given shuffle.
+ */
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
@@ -152,8 +167,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
- }
- else {
+ } else {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
@@ -164,27 +178,18 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
- protected def cleanup(cleanupTime: Long) {
- mapStatuses.clearOldValues(cleanupTime)
- }
-
- def stop() {
- communicate(StopMapOutputTracker)
- mapStatuses.clear()
- metadataCleaner.cancel()
- trackerActor = null
- }
-
- // Called to get current epoch number
+ /** Called to get current epoch number. */
def getEpoch: Long = {
epochLock.synchronized {
return epoch
}
}
- // Called on workers to update the epoch number, potentially clearing old outputs
- // because of a fetch failure. (Each worker task calls this with the latest epoch
- // number on the master at the time it was created.)
+ /**
+ * Called from executors to update the epoch number, potentially clearing old outputs
+ * because of a fetch failure. Each worker task calls this with the latest epoch
+ * number on the master at the time it was created.
+ */
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
if (newEpoch > epoch) {
@@ -194,17 +199,40 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging {
}
}
}
+
+ /** Unregister shuffle data. */
+ def unregisterShuffle(shuffleId: Int) {
+ mapStatuses.remove(shuffleId)
+ }
+
+ /** Stop the tracker. */
+ def stop() { }
}
+/**
+ * MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map
+ * output information, which allows old output information based on a TTL.
+ * output information, which allows old output information to be cleaned up based on a TTL.
private[spark] class MapOutputTrackerMaster(conf: SparkConf)
extends MapOutputTracker(conf) {
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
private var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ /**
+ * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
+ * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
+ * Other than these two scenarios, nothing should be dropped from this HashMap.
+ */
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()
+
+ // For cleaning up TimeStampedHashMaps
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
@@ -216,6 +244,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ /** Register multiple map output information for the given shuffle */
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
@@ -223,6 +252,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ /** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
val arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
@@ -238,6 +268,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
}
+ /** Unregister shuffle data */
+ override def unregisterShuffle(shuffleId: Int) {
+ mapStatuses.remove(shuffleId)
+ cachedSerializedStatuses.remove(shuffleId)
+ }
+
+ /** Check if the given shuffle is being tracked */
+ def containsShuffle(shuffleId: Int): Boolean = {
+ cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
+ }
+
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
@@ -274,23 +315,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
bytes
}
- protected override def cleanup(cleanupTime: Long) {
- super.cleanup(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
- }
-
override def stop() {
- super.stop()
+ sendTracker(StopMapOutputTracker)
+ mapStatuses.clear()
+ trackerActor = null
+ metadataCleaner.cancel()
cachedSerializedStatuses.clear()
}
- override def updateEpoch(newEpoch: Long) {
- // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ private def cleanup(cleanupTime: Long) {
+ mapStatuses.clearOldValues(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
+}
- def has(shuffleId: Int): Boolean = {
- cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
- }
+/**
+ * MapOutputTracker for the workers, which fetches map output information from the driver's
+ * MapOutputTrackerMaster.
+ */
+private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
+ protected val mapStatuses = new HashMap[Int, Array[MapStatus]]
}
private[spark] object MapOutputTracker {
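On the worker side, mapStatuses is a plain cache: a lookup miss triggers a fetch from the driver, while the `fetching` set prevents concurrent tasks from issuing duplicate requests for the same shuffle. A simplified, self-contained sketch of that cache-miss pattern (Status and fetchFromDriver are stand-ins, not the actual tracker API):

    import scala.collection.mutable

    object StatusCacheSketch {
      case class Status(location: String)               // stand-in for MapStatus

      private val cache = new mutable.HashMap[Int, Array[Status]]
      private val fetching = new mutable.HashSet[Int]   // shuffles currently being fetched

      // Placeholder for the round trip to the driver-side tracker.
      private def fetchFromDriver(shuffleId: Int): Array[Status] =
        Array(Status(s"driver-provided statuses for shuffle $shuffleId"))

      def getStatuses(shuffleId: Int): Array[Status] = {
        val cached = cache.synchronized {
          var hit = cache.get(shuffleId)
          if (hit.isEmpty) {
            // Another thread may already be fetching this shuffle: wait for it.
            while (fetching.contains(shuffleId)) { cache.wait() }
            hit = cache.get(shuffleId)
            if (hit.isEmpty) {
              fetching += shuffleId                      // we do the fetch ourselves
            }
          }
          hit
        }
        cached.getOrElse {
          val statuses = fetchFromDriver(shuffleId)      // round trip outside the lock
          cache.synchronized {
            cache.put(shuffleId, statuses)
            fetching -= shuffleId
            cache.notifyAll()
          }
          statuses
        }
      }
    }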
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e5ebd350ee..d7124616d3 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -45,7 +45,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -157,7 +157,7 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
- private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
+ private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]]
private[spark] val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
@@ -233,6 +233,15 @@ class SparkContext(
@volatile private[spark] var dagScheduler = new DAGScheduler(this)
dagScheduler.start()
+ private[spark] val cleaner: Option[ContextCleaner] = {
+ if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
+ Some(new ContextCleaner(this))
+ } else {
+ None
+ }
+ }
+ cleaner.foreach(_.start())
+
postEnvironmentUpdate()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
@@ -679,7 +688,11 @@ class SparkContext(
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal)
+ def broadcast[T](value: T): Broadcast[T] = {
+ val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+ cleaner.foreach(_.registerBroadcastForCleanup(bc))
+ bc
+ }
/**
* Add a file to be downloaded with this Spark job on every node.
@@ -789,8 +802,7 @@ class SparkContext(
/**
* Unpersist an RDD from memory and/or disk storage
*/
- private[spark] def unpersistRDD(rdd: RDD[_], blocking: Boolean = true) {
- val rddId = rdd.id
+ private[spark] def unpersistRDD(rddId: Int, blocking: Boolean = true) {
env.blockManager.master.removeRdd(rddId, blocking)
persistentRdds.remove(rddId)
listenerBus.post(SparkListenerUnpersistRDD(rddId))
@@ -869,6 +881,7 @@ class SparkContext(
dagScheduler = null
if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
+ cleaner.foreach(_.stop())
dagSchedulerCopy.stop()
listenerBus.stop()
taskScheduler = null
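In practice the cleaner is driven purely by configuration: spark.cleaner.referenceTracking toggles it (enabled by default), and spark.cleaner.referenceTracking.blocking makes the cleanup calls block (used by the tests). A small usage sketch, assuming a local master purely for illustration:

    import org.apache.spark.{SparkConf, SparkContext}

    object CleanerConfigExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[2]")                                 // illustrative master
          .setAppName("cleaner-example")
          .set("spark.cleaner.referenceTracking", "true")        // default: enabled
          .set("spark.cleaner.referenceTracking.blocking", "false")
        val sc = new SparkContext(conf)

        // Both the broadcast and the cached RDD are registered with the ContextCleaner,
        // so their remote state is removed once they become unreachable and are GC'd.
        val bc = sc.broadcast(Array(1, 2, 3))
        val cached = sc.parallelize(1 to 100).cache()
        cached.count()
        println(bc.value.sum)

        sc.stop()
      }
    }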
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 5ceac28fe7..9ea123f174 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -180,12 +180,24 @@ object SparkEnv extends Logging {
}
}
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster(conf)
+ } else {
+ new MapOutputTrackerWorker(conf)
+ }
+
+ // Have to assign trackerActor after initialization as MapOutputTrackerActor
+ // requires the MapOutputTracker itself
+ mapOutputTracker.trackerActor = registerOrLookup(
+ "MapOutputTracker",
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
+
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf)
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf, securityManager)
+ serializer, conf, securityManager, mapOutputTracker)
val connectionManager = blockManager.connectionManager
@@ -193,17 +205,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- // Have to assign trackerActor after initialization as MapOutputTrackerActor
- // requires the MapOutputTracker itself
- val mapOutputTracker = if (isDriver) {
- new MapOutputTrackerMaster(conf)
- } else {
- new MapOutputTracker(conf)
- }
- mapOutputTracker.trackerActor = registerOrLookup(
- "MapOutputTracker",
- new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
-
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index e3c3a12d16..738a3b1bed 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -18,9 +18,8 @@
package org.apache.spark.broadcast
import java.io.Serializable
-import java.util.concurrent.atomic.AtomicLong
-import org.apache.spark._
+import org.apache.spark.SparkException
/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
@@ -29,7 +28,8 @@ import org.apache.spark._
* attempts to distribute broadcast variables using efficient broadcast algorithms to reduce
* communication cost.
*
- * Broadcast variables are created from a variable `v` by calling [[SparkContext#broadcast]].
+ * Broadcast variables are created from a variable `v` by calling
+ * [[org.apache.spark.SparkContext#broadcast]].
* The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the
* `value` method. The interpreter session below shows this:
*
@@ -51,49 +51,80 @@ import org.apache.spark._
* @tparam T Type of the data contained in the broadcast variable.
*/
abstract class Broadcast[T](val id: Long) extends Serializable {
- def value: T
- // We cannot have an abstract readObject here due to some weird issues with
- // readObject having to be 'private' in sub-classes.
+ /**
+ * Flag signifying whether the broadcast variable is valid
+ * (that is, not already destroyed) or not.
+ */
+ @volatile private var _isValid = true
- override def toString = "Broadcast(" + id + ")"
-}
-
-private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
- extends Logging with Serializable {
-
- private var initialized = false
- private var broadcastFactory: BroadcastFactory = null
-
- initialize()
-
- // Called by SparkContext or Executor before using Broadcast
- private def initialize() {
- synchronized {
- if (!initialized) {
- val broadcastFactoryClass = conf.get(
- "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
-
- broadcastFactory =
- Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+ /** Get the broadcasted value. */
+ def value: T = {
+ assertValid()
+ getValue()
+ }
- // Initialize appropriate BroadcastFactory and BroadcastObject
- broadcastFactory.initialize(isDriver, conf, securityManager)
+ /**
+ * Asynchronously delete cached copies of this broadcast on the executors.
+ * If the broadcast is used after this is called, it will need to be re-sent to each executor.
+ */
+ def unpersist() {
+ unpersist(blocking = false)
+ }
- initialized = true
- }
- }
+ /**
+ * Delete cached copies of this broadcast on the executors. If the broadcast is used after
+ * this is called, it will need to be re-sent to each executor.
+ * @param blocking Whether to block until unpersisting has completed
+ */
+ def unpersist(blocking: Boolean) {
+ assertValid()
+ doUnpersist(blocking)
}
- def stop() {
- broadcastFactory.stop()
+ /**
+ * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+ * once a broadcast variable has been destroyed, it cannot be used again.
+ */
+ private[spark] def destroy(blocking: Boolean) {
+ assertValid()
+ _isValid = false
+ doDestroy(blocking)
}
- private val nextBroadcastId = new AtomicLong(0)
+ /**
+ * Whether this Broadcast is actually usable. This should be false once persisted state is
+ * removed from the driver.
+ */
+ private[spark] def isValid: Boolean = {
+ _isValid
+ }
- def newBroadcast[T](value_ : T, isLocal: Boolean) =
- broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+ /**
+ * Actually get the broadcasted value. Concrete implementations of Broadcast class must
+ * define their own way to get the value.
+ */
+ private[spark] def getValue(): T
+
+ /**
+ * Actually unpersist the broadcasted value on the executors. Concrete implementations of
+ * Broadcast class must define their own logic to unpersist their own data.
+ */
+ private[spark] def doUnpersist(blocking: Boolean)
+
+ /**
+ * Actually destroy all data and metadata related to this broadcast variable.
+ * Implementation of Broadcast class must define their own logic to destroy their own
+ * state.
+ */
+ private[spark] def doDestroy(blocking: Boolean)
+
+ /** Check if this broadcast is valid. If not valid, exception is thrown. */
+ private[spark] def assertValid() {
+ if (!_isValid) {
+ throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
+ }
+ }
- def isDriver = _isDriver
+ override def toString = "Broadcast(" + id + ")"
}
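The reworked Broadcast API gives each variable an explicit lifecycle: value reads the data (after an assertValid check), unpersist drops the cached copies on the executors so the next use re-fetches them, and destroy removes all state and makes any further access throw a SparkException. A short usage sketch under those semantics (the object and method names here are illustrative):

    import org.apache.spark.SparkContext

    object BroadcastLifecycleSketch {
      // `sc` is an existing SparkContext; the helper name is illustrative.
      def run(sc: SparkContext): Unit = {
        val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))

        // Normal use: each executor deserializes the broadcast lazily inside its tasks.
        val hits = sc.parallelize(Seq("a", "b", "c")).filter(x => lookup.value.contains(x)).count()
        println(s"hits = $hits")

        // Drop executor-side copies; the variable stays valid and is re-sent on next use.
        lookup.unpersist(blocking = false)
        val values = sc.parallelize(Seq("a")).map(x => lookup.value.get(x)).collect()
        println(values.mkString(","))

        // destroy(blocking) is private[spark] in this patch, so only Spark internals
        // (e.g. the ContextCleaner) can call it; after destruction, reading lookup.value
        // throws a SparkException.
      }
    }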
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 6beecaeced..c7f7c59cfb 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -27,7 +27,8 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
new file mode 100644
index 0000000000..cf62aca4d4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark._
+
+private[spark] class BroadcastManager(
+ val isDriver: Boolean,
+ conf: SparkConf,
+ securityManager: SecurityManager)
+ extends Logging {
+
+ private var initialized = false
+ private var broadcastFactory: BroadcastFactory = null
+
+ initialize()
+
+ // Called by SparkContext or Executor before using Broadcast
+ private def initialize() {
+ synchronized {
+ if (!initialized) {
+ val broadcastFactoryClass =
+ conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+
+ broadcastFactory =
+ Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
+
+ // Initialize appropriate BroadcastFactory and BroadcastObject
+ broadcastFactory.initialize(isDriver, conf, securityManager)
+
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ broadcastFactory.stop()
+ }
+
+ private val nextBroadcastId = new AtomicLong(0)
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean) = {
+ broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+ }
+
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
+ }
+}
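The manager loads its BroadcastFactory reflectively from spark.broadcast.factory, so switching broadcast implementations is purely a configuration change. A minimal sketch of selecting the torrent implementation, using the same reflection pattern as initialize() above (the standalone object is illustrative only):

    import org.apache.spark.SparkConf
    import org.apache.spark.broadcast.BroadcastFactory

    object FactorySelectionSketch {
      def main(args: Array[String]): Unit = {
        // Choose the torrent-based implementation instead of the HTTP default.
        val conf = new SparkConf()
          .set("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")

        // Same reflection pattern BroadcastManager.initialize() uses to instantiate it.
        val className = conf.get(
          "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
        val factory = Class.forName(className).newInstance.asInstanceOf[BroadcastFactory]
        println(s"Loaded broadcast factory: ${factory.getClass.getName}")
      }
    }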
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index e8eb04bb10..f6a8a8af91 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -17,34 +17,65 @@
package org.apache.spark.broadcast
-import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
-import java.net.{URL, URLConnection, URI}
+import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream}
+import java.net.{URI, URL, URLConnection}
import java.util.concurrent.TimeUnit
-import it.unimi.dsi.fastutil.io.FastBufferedInputStream
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream}
-import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv}
+import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
+/**
+ * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses an HTTP server
+ * as the broadcast mechanism. The first time an HTTP broadcast variable (sent as part of a
+ * task) is deserialized in the executor, the broadcasted data is fetched from the driver
+ * (through an HTTP server running at the driver) and stored in the BlockManager of the
+ * executor to speed up future accesses.
+ */
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
- def value = value_
+ def getValue = value_
- def blockId = BroadcastBlockId(id)
+ val blockId = BroadcastBlockId(id)
+ /*
+ * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster
+ * does not need to be told about this block, since no other node needs to know about this data block.
+ */
HttpBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ SparkEnv.get.blockManager.putSingle(
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
if (!isLocal) {
HttpBroadcast.write(id, value_)
}
- // Called by JVM when deserializing an object
+ /**
+ * Remove all persisted state associated with this HTTP broadcast on the executors.
+ */
+ def doUnpersist(blocking: Boolean) {
+ HttpBroadcast.unpersist(id, removeFromDriver = false, blocking)
+ }
+
+ /**
+ * Remove all persisted state associated with this HTTP broadcast on the executors and driver.
+ */
+ def doDestroy(blocking: Boolean) {
+ HttpBroadcast.unpersist(id, removeFromDriver = true, blocking)
+ }
+
+ /** Used by the JVM when serializing this object. */
+ private def writeObject(out: ObjectOutputStream) {
+ assertValid()
+ out.defaultWriteObject()
+ }
+
+ /** Used by the JVM when deserializing this object. */
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
@@ -54,7 +85,13 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ /*
+ * We cache broadcast data in the BlockManager so that subsequent tasks using it
+ * do not need to re-fetch. This data is only used locally and no other node
+ * needs to fetch this block, so we don't notify the master.
+ */
+ SparkEnv.get.blockManager.putSingle(
+ blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
@@ -63,23 +100,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
}
}
-/**
- * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
- */
-class HttpBroadcastFactory extends BroadcastFactory {
- def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- HttpBroadcast.initialize(isDriver, conf, securityMgr)
- }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new HttpBroadcast[T](value_, isLocal, id)
-
- def stop() { HttpBroadcast.stop() }
-}
-
-private object HttpBroadcast extends Logging {
+private[spark] object HttpBroadcast extends Logging {
private var initialized = false
-
private var broadcastDir: File = null
private var compress: Boolean = false
private var bufferSize: Int = 65536
@@ -89,11 +111,9 @@ private object HttpBroadcast extends Logging {
// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[String]
- private var cleaner: MetadataCleaner = null
-
private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
-
private var compressionCodec: CompressionCodec = null
+ private var cleaner: MetadataCleaner = null
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
@@ -136,8 +156,10 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}
+ def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)
+
def write(id: Long, value: Any) {
- val file = new File(broadcastDir, BroadcastBlockId(id).name)
+ val file = getFile(id)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -160,7 +182,7 @@ private object HttpBroadcast extends Logging {
if (securityManager.isAuthenticationEnabled()) {
logDebug("broadcast security enabled")
val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
- uc = newuri.toURL().openConnection()
+ uc = newuri.toURL.openConnection()
uc.setAllowUserInteraction(false)
} else {
logDebug("broadcast not using security")
@@ -169,7 +191,7 @@ private object HttpBroadcast extends Logging {
val in = {
uc.setReadTimeout(httpReadTimeout)
- val inputStream = uc.getInputStream();
+ val inputStream = uc.getInputStream
if (compress) {
compressionCodec.compressedInputStream(inputStream)
} else {
@@ -183,20 +205,48 @@ private object HttpBroadcast extends Logging {
obj
}
- def cleanup(cleanupTime: Long) {
+ /**
+ * Remove all persisted blocks associated with this HTTP broadcast on the executors.
+ * If removeFromDriver is true, also remove these persisted blocks on the driver
+ * and delete the associated broadcast file.
+ */
+ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+ SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
+ if (removeFromDriver) {
+ val file = getFile(id)
+ files.remove(file.toString)
+ deleteBroadcastFile(file)
+ }
+ }
+
+ /**
+ * Periodically clean up old broadcasts by removing the associated map entries and
+ * deleting the associated files.
+ */
+ private def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
- try {
- iterator.remove()
- new File(file.toString).delete()
- logInfo("Deleted broadcast file '" + file + "'")
- } catch {
- case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
+ iterator.remove()
+ deleteBroadcastFile(new File(file.toString))
+ }
+ }
+ }
+
+ private def deleteBroadcastFile(file: File) {
+ try {
+ if (file.exists) {
+ if (file.delete()) {
+ logInfo("Deleted broadcast file: %s".format(file))
+ } else {
+ logWarning("Could not delete broadcast file: %s".format(file))
}
}
+ } catch {
+ case e: Exception =>
+ logError("Exception while deleting broadcast file: %s".format(file), e)
}
}
}
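The executor-side read path boils down to: open the driver's HTTP URL for the broadcast block, apply a read timeout, optionally wrap the stream in a compression codec, and deserialize. A stripped-down sketch of that pattern using only JDK classes (the GZIP codec and the helper name are assumptions; Spark uses its own CompressionCodec and serialization utilities):

    import java.io.{InputStream, ObjectInputStream}
    import java.net.URL
    import java.util.zip.GZIPInputStream

    object HttpFetchSketch {
      // Illustrative stand-in for HttpBroadcast.read: fetch one serialized object over HTTP.
      def fetchSerializedObject[T](url: String, readTimeoutMs: Int, compressed: Boolean): T = {
        val connection = new URL(url).openConnection()
        connection.setReadTimeout(readTimeoutMs)
        connection.setAllowUserInteraction(false)
        val raw: InputStream = connection.getInputStream
        val in = if (compressed) new GZIPInputStream(raw) else raw
        val objIn = new ObjectInputStream(in)
        try {
          objIn.readObject().asInstanceOf[T]
        } finally {
          objIn.close()
        }
      }
    }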
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
new file mode 100644
index 0000000000..e3f6cdc615
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/**
+ * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses an
+ * HTTP server as the broadcast mechanism. Refer to
+ * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism.
+ */
+class HttpBroadcastFactory extends BroadcastFactory {
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ HttpBroadcast.initialize(isDriver, conf, securityMgr)
+ }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new HttpBroadcast[T](value_, isLocal, id)
+
+ def stop() { HttpBroadcast.stop() }
+
+ /**
+ * Remove all persisted state associated with the HTTP broadcast with the given ID.
+ * @param removeFromDriver Whether to remove state from the driver
+ * @param blocking Whether to block until unbroadcasted
+ */
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ HttpBroadcast.unpersist(id, removeFromDriver, blocking)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 2595c15104..2b32546c68 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -17,24 +17,43 @@
package org.apache.spark.broadcast
-import java.io._
+import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
import scala.math
import scala.util.Random
-import org.apache.spark._
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils
+/**
+ * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
+ * protocol to do a distributed transfer of the broadcasted data to the executors.
+ * The mechanism is as follows. The driver divides the serializes the broadcasted data,
+ * The mechanism is as follows. The driver serializes the broadcasted data,
+ * divides it into smaller chunks, and stores them in the BlockManager of the driver.
+ * These chunks are reported to the BlockManagerMaster so that all the executors can
+ * learn the location of those chunks. The first time the broadcast variable (sent as
+ * part of a task) is deserialized at an executor, all the chunks are fetched using
+ * BlockManager), they are combined and deserialized to recreate the broadcasted data.
+ * However, the chunks are also stored in the BlockManager and reported to the
+ * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
+ * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
+ * made to other executors who already have those chunks, resulting in a distributed
+ * fetching. This prevents the driver from being the bottleneck in sending out multiple
+ * copies of the broadcast data (one per executor) as done by the
+ * [[org.apache.spark.broadcast.HttpBroadcast]].
+ */
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
+ extends Broadcast[T](id) with Logging with Serializable {
- def value = value_
+ def getValue = value_
- def broadcastId = BroadcastBlockId(id)
+ val broadcastId = BroadcastBlockId(id)
TorrentBroadcast.synchronized {
- SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
}
@transient var arrayOfBlocks: Array[TorrentBlock] = null
@@ -46,32 +65,52 @@ extends Broadcast[T](id) with Logging with Serializable {
sendBroadcast()
}
- def sendBroadcast() {
- var tInfo = TorrentBroadcast.blockifyObject(value_)
+ /**
+ * Remove all persisted state associated with this Torrent broadcast on the executors.
+ */
+ def doUnpersist(blocking: Boolean) {
+ TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
+ }
+
+ /**
+ * Remove all persisted state associated with this Torrent broadcast on the executors
+ * and driver.
+ */
+ def doDestroy(blocking: Boolean) {
+ TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
+ }
+ def sendBroadcast() {
+ val tInfo = TorrentBroadcast.blockifyObject(value_)
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
hasBlocks = tInfo.totalBlocks
// Store meta-info
- val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaId = BroadcastBlockId(id, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
- metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
}
// Store individual pieces
for (i <- 0 until totalBlocks) {
- val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ val pieceId = BroadcastBlockId(id, "piece" + i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
- pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
}
}
}
- // Called by JVM when deserializing an object
+ /** Used by the JVM when serializing this object. */
+ private def writeObject(out: ObjectOutputStream) {
+ assertValid()
+ out.defaultWriteObject()
+ }
+
+ /** Used by the JVM when deserializing this object. */
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
@@ -86,18 +125,22 @@ extends Broadcast[T](id) with Logging with Serializable {
// Initialize @transient variables that will receive garbage values from the master.
resetWorkerVariables()
- if (receiveBroadcast(id)) {
+ if (receiveBroadcast()) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
- // This creates a tradeoff between memory usage and latency.
- // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ * This creates a trade-off between memory usage and latency. Storing a copy doubles
+ * the memory footprint; not storing it doubles the deserialization cost. Also,
+ * this does not need to be reported to the BlockManagerMaster since other executors
+ * do not need to access this block (they only need to fetch the chunks,
+ * which are reported).
+ */
SparkEnv.get.blockManager.putSingle(
- broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
- } else {
+ } else {
logError("Reading broadcast variable " + id + " failed")
}
@@ -114,9 +157,10 @@ extends Broadcast[T](id) with Logging with Serializable {
hasBlocks = 0
}
- def receiveBroadcast(variableID: Long): Boolean = {
- // Receive meta-info
- val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ def receiveBroadcast(): Boolean = {
+ // Receive meta-info about the size of broadcast data,
+ // the number of chunks it is divided into, etc.
+ val metaId = BroadcastBlockId(id, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
@@ -138,17 +182,21 @@ extends Broadcast[T](id) with Logging with Serializable {
return false
}
- // Receive actual blocks
+ /*
+ * Fetch actual chunks of data. Note that all these chunks are stored in
+ * the BlockManager and reported to the master, so that other executors
+ * can find out and pull the chunks from this executor.
+ */
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
- val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ val pieceId = BroadcastBlockId(id, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
SparkEnv.get.blockManager.putSingle(
- pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
@@ -156,16 +204,16 @@ extends Broadcast[T](id) with Logging with Serializable {
}
}
- (hasBlocks == totalBlocks)
+ hasBlocks == totalBlocks
}
}
-private object TorrentBroadcast
-extends Logging {
-
+private[spark] object TorrentBroadcast extends Logging {
+ private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
private var initialized = false
private var conf: SparkConf = null
+
def initialize(_isDriver: Boolean, conf: SparkConf) {
TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
synchronized {
@@ -179,39 +227,37 @@ extends Logging {
initialized = false
}
- lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
-
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
val bais = new ByteArrayInputStream(byteArray)
- var blockNum = (byteArray.length / BLOCK_SIZE)
+ var blockNum = byteArray.length / BLOCK_SIZE
if (byteArray.length % BLOCK_SIZE != 0) {
blockNum += 1
}
- var retVal = new Array[TorrentBlock](blockNum)
- var blockID = 0
+ val blocks = new Array[TorrentBlock](blockNum)
+ var blockId = 0
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+ val tempByteArray = new Array[Byte](thisBlockSize)
+ bais.read(tempByteArray, 0, thisBlockSize)
- retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
- blockID += 1
+ blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+ blockId += 1
}
bais.close()
- val tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
- tInfo.hasBlocks = blockNum
-
- tInfo
+ val info = TorrentInfo(blocks, blockNum, byteArray.length)
+ info.hasBlocks = blockNum
+ info
}
- def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
- totalBytes: Int,
- totalBlocks: Int): T = {
+ def unBlockifyObject[T](
+ arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
val retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
@@ -220,6 +266,13 @@ extends Logging {
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
}
+ /**
+ * Remove all persisted blocks associated with this torrent broadcast on the executors.
+ * If removeFromDriver is true, also remove these persisted blocks on the driver.
+ */
+ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized {
+ SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
+ }
}
private[spark] case class TorrentBlock(
@@ -228,25 +281,10 @@ private[spark] case class TorrentBlock(
extends Serializable
private[spark] case class TorrentInfo(
- @transient arrayOfBlocks : Array[TorrentBlock],
+ @transient arrayOfBlocks: Array[TorrentBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}
-
-/**
- * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast.
- */
-class TorrentBroadcastFactory extends BroadcastFactory {
-
- def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- TorrentBroadcast.initialize(isDriver, conf)
- }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new TorrentBroadcast[T](value_, isLocal, id)
-
- def stop() { TorrentBroadcast.stop() }
-}
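blockifyObject and unBlockifyObject amount to fixed-size chunking of a serialized payload and its later reassembly. The round trip is easy to see in a standalone sketch (ChunkingSketch, blockify, and unBlockify are simplified stand-ins, not the TorrentBroadcast methods; BLOCK_SIZE mirrors the 4 MB default from spark.broadcast.blockSize):

    object ChunkingSketch {
      val BLOCK_SIZE: Int = 4096 * 1024   // 4 MB, matching the default block size

      /** Split a serialized payload into fixed-size chunks (the last one may be shorter). */
      def blockify(bytes: Array[Byte]): Array[Array[Byte]] =
        bytes.grouped(BLOCK_SIZE).toArray

      /** Reassemble the chunks, in order, back into the original payload. */
      def unBlockify(blocks: Array[Array[Byte]]): Array[Byte] =
        blocks.flatten

      def main(args: Array[String]): Unit = {
        val payload = Array.fill[Byte](10 * 1024 * 1024)(1)      // pretend serialized data
        val blocks = blockify(payload)
        println(s"${blocks.length} blocks")                      // 3 blocks for 10 MB
        assert(unBlockify(blocks).sameElements(payload))
      }
    }

Fetching the chunks in a random order, as receiveBroadcast does above, spreads the load across the executors that already hold them.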
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
new file mode 100644
index 0000000000..d216b58718
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/**
+ * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a BitTorrent-like
+ * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to
+ * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details.
+ */
+class TorrentBroadcastFactory extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
+ TorrentBroadcast.initialize(isDriver, conf)
+ }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+
+ /**
+ * Remove all persisted state associated with the torrent broadcast with the given ID.
+ * @param removeFromDriver Whether to remove state from the driver.
+ * @param blocking Whether to block until unbroadcasted
+ */
+ def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
+ TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 6b0a972f0b..bdf586351a 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -17,7 +17,6 @@
package org.apache.spark.network
-import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index c43823bd76..bf3c57ad41 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -138,6 +138,8 @@ abstract class RDD[T: ClassTag](
"Cannot change storage level of an RDD after it was already assigned a level")
}
sc.persistRDD(this)
+ // Register the RDD with the ContextCleaner for automatic GC-based cleanup
+ sc.cleaner.foreach(_.registerRDDForCleanup(this))
storageLevel = newLevel
this
}
@@ -156,7 +158,7 @@ abstract class RDD[T: ClassTag](
*/
def unpersist(blocking: Boolean = true): RDD[T] = {
logInfo("Removing RDD " + id + " from persistence list")
- sc.unpersistRDD(this, blocking)
+ sc.unpersistRDD(id, blocking)
storageLevel = StorageLevel.NONE
this
}
@@ -1141,5 +1143,4 @@ abstract class RDD[T: ClassTag](
def toJavaRDD() : JavaRDD[T] = {
new JavaRDD(this)(elementClassTag)
}
-
}
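For context, a hedged usage sketch of the changed persist/unpersist path: persisting an RDD now also registers it with the ContextCleaner (when one is configured), and unpersist(blocking = false) issues the block removal asynchronously by RDD id. This assumes a local SparkContext and is illustrative only.

import org.apache.spark.{SparkConf, SparkContext}

object PersistExample extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("persist-example"))
  val rdd = sc.parallelize(1 to 100).cache()   // also registers the RDD for GC-based cleanup
  rdd.count()                                  // materialize the cached blocks
  rdd.unpersist(blocking = false)              // remove blocks without waiting for confirmation
  sc.stop()
}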
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 442a95bb2c..6368665f24 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -32,7 +32,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.Utils
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -80,13 +80,13 @@ class DAGScheduler(
private[scheduler] def numTotalJobs: Int = nextJobId.get()
private val nextStageId = new AtomicInteger(0)
- private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]]
- private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]
- private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage]
- private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
+ private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]]
+ private[scheduler] val stageIdToStage = new HashMap[Int, Stage]
+ private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage]
private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob]
private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob]
- private[scheduler] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
+ private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo]
// Stages we need to run whose parents aren't done
private[scheduler] val waitingStages = new HashSet[Stage]
@@ -98,7 +98,7 @@ class DAGScheduler(
private[scheduler] val failedStages = new HashSet[Stage]
// Missing tasks from each stage
- private[scheduler] val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]]
+ private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]]
private[scheduler] val activeJobs = new HashSet[ActiveJob]
@@ -113,9 +113,6 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]
- private val metadataCleaner =
- new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf)
-
taskScheduler.setDAGScheduler(this)
/**
@@ -258,7 +255,7 @@ class DAGScheduler(
: Stage =
{
val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
- if (mapOutputTracker.has(shuffleDep.shuffleId)) {
+ if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
for (i <- 0 until locs.size) {
@@ -390,6 +387,9 @@ class DAGScheduler(
stageIdToStage -= stageId
stageIdToJobIds -= stageId
+ ShuffleMapTask.removeStage(stageId)
+ ResultTask.removeStage(stageId)
+
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -1084,26 +1084,10 @@ class DAGScheduler(
Nil
}
- private def cleanup(cleanupTime: Long) {
- Map(
- "stageIdToStage" -> stageIdToStage,
- "shuffleToMapStage" -> shuffleToMapStage,
- "pendingTasks" -> pendingTasks,
- "stageToInfos" -> stageToInfos,
- "jobIdToStageIds" -> jobIdToStageIds,
- "stageIdToJobIds" -> stageIdToJobIds).
- foreach { case (s, t) =>
- val sizeBefore = t.size
- t.clearOldValues(cleanupTime)
- logInfo("%s %d --> %d".format(s, sizeBefore, t.size))
- }
- }
-
def stop() {
if (eventProcessActor != null) {
eventProcessActor ! StopDAGScheduler
}
- metadataCleaner.cancel()
taskScheduler.stop()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 3fc6cc9850..083fb895d8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -20,21 +20,17 @@ package org.apache.spark.scheduler
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+import scala.collection.mutable.HashMap
+
import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.RDDCheckpointData
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
private[spark] object ResultTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
-
- // TODO: This object shouldn't have global variables
- val metadataCleaner = new MetadataCleaner(
- MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf)
+ private val serializedInfoCache = new HashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
{
@@ -67,6 +63,10 @@ private[spark] object ResultTask {
(rdd, func)
}
+ def removeStage(stageId: Int) {
+ serializedInfoCache.remove(stageId)
+ }
+
def clearCache() {
synchronized {
serializedInfoCache.clear()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 2a9edf4a76..23f3b3e824 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -24,22 +24,16 @@ import scala.collection.mutable.HashMap
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.RDDCheckpointData
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
private[spark] object ShuffleMapTask {
// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
- val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
-
- // TODO: This object shouldn't have global variables
- val metadataCleaner = new MetadataCleaner(
- MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf)
+ private val serializedInfoCache = new HashMap[Int, Array[Byte]]
def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
synchronized {
@@ -80,6 +74,10 @@ private[spark] object ShuffleMapTask {
HashMap(set.toSeq: _*)
}
+ def removeStage(stageId: Int) {
+ serializedInfoCache.remove(stageId)
+ }
+
def clearCache() {
synchronized {
serializedInfoCache.clear()
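Both ResultTask and ShuffleMapTask now follow the same pattern: a plain HashMap cache of serialized task info that is emptied per stage via removeStage, rather than by a TTL-driven MetadataCleaner. A standalone sketch of that pattern (the object and method names here are illustrative, not the actual Spark ones):

import scala.collection.mutable.HashMap

object SerializedInfoCacheSketch {
  private val cache = new HashMap[Int, Array[Byte]]

  def getOrSerialize(stageId: Int)(serialize: => Array[Byte]): Array[Byte] = synchronized {
    cache.getOrElseUpdate(stageId, serialize)
  }

  // Called when the scheduler removes a stage, so entries never outlive the stage.
  def removeStage(stageId: Int): Unit = synchronized { cache.remove(stageId) }
}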
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index a92922166f..acd152dda8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -42,7 +42,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
*
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
- * SchedulerBackends sycnchronize on themselves when they want to send events here, and then
+ * SchedulerBackends synchronize on themselves when they want to send events here, and then
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 301d784b35..cffea28fbf 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
- def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId]
override def toString = name
override def hashCode = name.hashCode
@@ -48,18 +48,13 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI
def name = "rdd_" + rddId + "_" + splitIndex
}
-private[spark]
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+ extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
-private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
- def name = "broadcast_" + broadcastId
-}
-
-private[spark]
-case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
- def name = broadcastId.name + "_" + hType
+private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
+ def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
@@ -83,8 +78,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
- val BROADCAST = "broadcast_([0-9]+)".r
- val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+ val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
@@ -95,10 +89,8 @@ private[spark] object BlockId {
RDDBlockId(rddId.toInt, splitIndex.toInt)
case SHUFFLE(shuffleId, mapId, reduceId) =>
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
- case BROADCAST(broadcastId) =>
- BroadcastBlockId(broadcastId.toLong)
- case BROADCAST_HELPER(broadcastId, hType) =>
- BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+ case BROADCAST(broadcastId, field) =>
+ BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
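A standalone illustration of the widened BROADCAST pattern: names that used to be BroadcastHelperBlockId now parse into BroadcastBlockId with an optional field. The object below re-implements only the relevant regex and parsing logic for demonstration; it is not the Spark BlockId object itself.

object BlockIdSketch {
  case class BroadcastBlockId(broadcastId: Long, field: String = "") {
    def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
  }

  private val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r

  def parse(name: String): Option[BroadcastBlockId] = name match {
    case BROADCAST(id, field) => Some(BroadcastBlockId(id.toLong, field.stripPrefix("_")))
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    println(parse("broadcast_5"))       // plain broadcast block: empty field
    println(parse("broadcast_5_meta"))  // former "helper" block: field = "meta"
  }
}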
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 19138d9dde..b021564477 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -19,20 +19,22 @@ package org.apache.spark.storage
import java.io.{File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
+
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.util.Random
+
import akka.actor.{ActorSystem, Cancellable, Props}
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import sun.nio.ch.DirectBuffer
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
+
+import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
import org.apache.spark.util._
-
sealed trait Values
case class ByteBufferValues(buffer: ByteBuffer) extends Values
@@ -46,7 +48,8 @@ private[spark] class BlockManager(
val defaultSerializer: Serializer,
maxMemory: Long,
val conf: SparkConf,
- securityManager: SecurityManager)
+ securityManager: SecurityManager,
+ mapOutputTracker: MapOutputTracker)
extends Logging {
val shuffleBlockManager = new ShuffleBlockManager(this)
@@ -55,7 +58,7 @@ private[spark] class BlockManager(
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
- private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
+ private[storage] val memoryStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore = new DiskStore(this, diskBlockManager)
var tachyonInitialized = false
private[storage] lazy val tachyonStore: TachyonStore = {
@@ -98,7 +101,7 @@ private[spark] class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf)
- val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+ val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
// Pending re-registration action being executed asynchronously or null if none
@@ -137,9 +140,10 @@ private[spark] class BlockManager(
master: BlockManagerMaster,
serializer: Serializer,
conf: SparkConf,
- securityManager: SecurityManager) = {
+ securityManager: SecurityManager,
+ mapOutputTracker: MapOutputTracker) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, securityManager)
+ conf, securityManager, mapOutputTracker)
}
/**
@@ -217,9 +221,26 @@ private[spark] class BlockManager(
}
/**
- * Get storage level of local block. If no info exists for the block, then returns null.
+ * Get the BlockStatus for the block identified by the given ID, if it exists.
+ * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon.
+ */
+ def getStatus(blockId: BlockId): Option[BlockStatus] = {
+ blockInfo.get(blockId).map { info =>
+ val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
+ val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L
+ // Assume that block is not in Tachyon
+ BlockStatus(info.level, memSize, diskSize, 0L)
+ }
+ }
+
+ /**
+ * Get the ids of existing blocks that match the given filter. Note that this will
+ * query the blocks stored in the disk block manager (which the block manager
+ * may not be aware of).
*/
- def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = {
+ (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq
+ }
/**
* Tell the master about the current storage status of a block. This will send a block update
@@ -525,9 +546,8 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
- * The Block will be appended to the File specified by filename.
- * This is currently used for writing shuffle files out. Callers should handle error
- * cases.
+ * The Block will be appended to the File specified by filename. This is currently used for
+ * writing shuffle files out. Callers should handle error cases.
*/
def getDiskWriter(
blockId: BlockId,
@@ -863,11 +883,22 @@ private[spark] class BlockManager(
* @return The number of blocks removed.
*/
def removeRdd(rddId: Int): Int = {
- // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
- // from RDD.id to blocks.
+ // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks.
logInfo("Removing RDD " + rddId)
val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
- blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
+ blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) }
+ blocksToRemove.size
+ }
+
+ /**
+ * Remove all blocks belonging to the given broadcast.
+ */
+ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = {
+ logInfo("Removing broadcast " + broadcastId)
+ val blocksToRemove = blockInfo.keys.collect {
+ case bid @ BroadcastBlockId(`broadcastId`, _) => bid
+ }
+ blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) }
blocksToRemove.size
}
@@ -908,10 +939,10 @@ private[spark] class BlockManager(
}
private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
- val iterator = blockInfo.internalMap.entrySet().iterator()
+ val iterator = blockInfo.getEntrySet.iterator
while (iterator.hasNext) {
val entry = iterator.next()
- val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
+ val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
@@ -935,7 +966,7 @@ private[spark] class BlockManager(
def shouldCompress(blockId: BlockId): Boolean = blockId match {
case ShuffleBlockId(_, _, _) => compressShuffle
- case BroadcastBlockId(_) => compressBroadcast
+ case BroadcastBlockId(_, _) => compressBroadcast
case RDDBlockId(_, _) => compressRdds
case TempBlockId(_) => compressShuffleSpill
case _ => false
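The new removeBroadcast relies on a backtick (stable identifier) pattern to collect only the block ids of the broadcast being removed. A self-contained sketch of that collect pattern, with a local stand-in for BroadcastBlockId:

object RemoveBroadcastSketch extends App {
  case class BroadcastBlockId(broadcastId: Long, field: String = "")

  val blocks = Set(BroadcastBlockId(1), BroadcastBlockId(1, "meta"), BroadcastBlockId(2))
  val broadcastId = 1L
  // The backticks match against the existing val instead of binding a new variable.
  val toRemove = blocks.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid }
  println(toRemove)  // only the two blocks belonging to broadcast 1
}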
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 4bc1b407ad..7897fade2d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -81,6 +81,14 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
+ /**
+ * Check if the block manager master has a block. Note that this can only check blocks
+ * that have been reported to the block manager master.
+ */
+ def contains(blockId: BlockId) = {
+ !getLocations(blockId).isEmpty
+ }
+
/** Get ids of other nodes in the cluster from the driver */
def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = {
val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers))
@@ -99,12 +107,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply(RemoveBlock(blockId))
}
- /**
- * Remove all blocks belonging to the given RDD.
- */
+ /** Remove all blocks belonging to the given RDD. */
def removeRdd(rddId: Int, blocking: Boolean) {
val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
- future onFailure {
+ future.onFailure {
case e: Throwable => logError("Failed to remove RDD " + rddId, e)
}
if (blocking) {
@@ -112,6 +118,31 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
}
}
+ /** Remove all blocks belonging to the given shuffle. */
+ def removeShuffle(shuffleId: Int, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
+ future.onFailure {
+ case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
+ /** Remove all blocks belonging to the given broadcast. */
+ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
+ val future = askDriverWithReply[Future[Seq[Int]]](
+ RemoveBroadcast(broadcastId, removeFromMaster))
+ future.onFailure {
+ case e: Throwable =>
+ logError("Failed to remove broadcast " + broadcastId +
+ " with removeFromMaster = " + removeFromMaster, e)
+ }
+ if (blocking) {
+ Await.result(future, timeout)
+ }
+ }
+
/**
* Return the memory status for each block manager, in the form of a map from
* the block manager's id to two long values. The first value is the maximum
@@ -126,6 +157,51 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
}
+ /**
+ * Return the block's status on all block managers, if any. NOTE: This is a
+ * potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, this invokes the master to query each block manager for the most
+ * updated block statuses. This is useful when the master is not informed of the given block
+ * by all block managers.
+ */
+ def getBlockStatus(
+ blockId: BlockId,
+ askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
+ val msg = GetBlockStatus(blockId, askSlaves)
+ /*
+ * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
+ * should not block waiting for a block manager, which can in turn be waiting for the
+ * master actor for a response to a prior message.
+ */
+ val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
+ val (blockManagerIds, futures) = response.unzip
+ val result = Await.result(Future.sequence(futures), timeout)
+ if (result == null) {
+ throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId)
+ }
+ val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]]
+ blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) =>
+ status.map { s => (blockManagerId, s) }
+ }.toMap
+ }
+
+ /**
+ * Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This
+ * is a potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, this invokes the master to query each block manager for the most
+ * updated block statuses. This is useful when the master is not informed of the given block
+ * by all block managers.
+ */
+ def getMatchingBlockIds(
+ filter: BlockId => Boolean,
+ askSlaves: Boolean): Seq[BlockId] = {
+ val msg = GetMatchingBlockIds(filter, askSlaves)
+ val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
+ Await.result(future, timeout)
+ }
+
/** Stop the driver actor, called only on the Spark driver node */
def stop() {
if (driverActor != null) {
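The status queries above deliberately hand back Futures so the master never blocks synchronously on a slave that may itself be waiting on the master. A self-contained sketch of the sequence-then-await pattern used by getBlockStatus; the per-manager queries here are plain Futures standing in for actor asks:

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global

object BlockStatusQuerySketch extends App {
  case class BlockStatus(memSize: Long, diskSize: Long)

  // Stand-ins for per-slave queries; each would normally be an ask() to a slave actor.
  val perManagerQueries: Seq[Future[Option[BlockStatus]]] = Seq(
    Future(Some(BlockStatus(memSize = 1024, diskSize = 0))),
    Future(None)
  )

  // Sequence all futures first, then block once at the very end.
  val allStatuses = Await.result(Future.sequence(perManagerQueries), 10.seconds)
  println(allStatuses.flatten)  // keep only the managers that actually hold the block
}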
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 378f4cadc1..c57b6e8391 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -94,9 +94,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetStorageStatus =>
sender ! storageStatus
+ case GetBlockStatus(blockId, askSlaves) =>
+ sender ! blockStatus(blockId, askSlaves)
+
+ case GetMatchingBlockIds(filter, askSlaves) =>
+ sender ! getMatchingBlockIds(filter, askSlaves)
+
case RemoveRdd(rddId) =>
sender ! removeRdd(rddId)
+ case RemoveShuffle(shuffleId) =>
+ sender ! removeShuffle(shuffleId)
+
+ case RemoveBroadcast(broadcastId, removeFromDriver) =>
+ sender ! removeBroadcast(broadcastId, removeFromDriver)
+
case RemoveBlock(blockId) =>
removeBlockFromWorkers(blockId)
sender ! true
@@ -140,9 +152,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
// The dispatcher is used as an implicit argument into the Future sequence construction.
import context.dispatcher
val removeMsg = RemoveRdd(rddId)
- Future.sequence(blockManagerInfo.values.map { bm =>
- bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
- }.toSeq)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq
+ )
+ }
+
+ private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
+ // Nothing to do in the BlockManagerMasterActor data structures
+ import context.dispatcher
+ val removeMsg = RemoveShuffle(shuffleId)
+ Future.sequence(
+ blockManagerInfo.values.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean]
+ }.toSeq
+ )
+ }
+
+ /**
+ * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
+ * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
+ * from the executors, but not from the driver.
+ */
+ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
+ // TODO: Consolidate usages of <driver>
+ import context.dispatcher
+ val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
+ val requiredBlockManagers = blockManagerInfo.values.filter { info =>
+ removeFromDriver || info.blockManagerId.executorId != "<driver>"
+ }
+ Future.sequence(
+ requiredBlockManagers.map { bm =>
+ bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int]
+ }.toSeq
+ )
}
private def removeBlockManager(blockManagerId: BlockManagerId) {
@@ -225,6 +269,61 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}.toArray
}
+ /**
+ * Return the block's status for all block managers, if any. NOTE: This is a
+ * potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def blockStatus(
+ blockId: BlockId,
+ askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
+ import context.dispatcher
+ val getBlockStatus = GetBlockStatus(blockId)
+ /*
+ * Rather than blocking on the block status query, the master actor should simply return
+ * Futures to avoid potential deadlocks. This can arise if there exists a block manager
+ * that is also waiting for this master actor's response to a previous message.
+ */
+ blockManagerInfo.values.map { info =>
+ val blockStatusFuture =
+ if (askSlaves) {
+ info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
+ } else {
+ Future { info.getStatus(blockId) }
+ }
+ (info.blockManagerId, blockStatusFuture)
+ }.toMap
+ }
+
+ /**
+ * Return the ids of blocks present in all the block managers that match the given filter.
+ * NOTE: This is a potentially expensive operation and should only be used for testing.
+ *
+ * If askSlaves is true, the master queries each block manager for the most updated block
+ * statuses. This is useful when the master is not informed of the given block by all block
+ * managers.
+ */
+ private def getMatchingBlockIds(
+ filter: BlockId => Boolean,
+ askSlaves: Boolean): Future[Seq[BlockId]] = {
+ import context.dispatcher
+ val getMatchingBlockIds = GetMatchingBlockIds(filter)
+ Future.sequence(
+ blockManagerInfo.values.map { info =>
+ val future =
+ if (askSlaves) {
+ info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]]
+ } else {
+ Future { info.blocks.keys.filter(filter).toSeq }
+ }
+ future
+ }
+ ).map(_.flatten.toSeq)
+ }
+
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -334,6 +433,8 @@ private[spark] class BlockManagerInfo(
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem)))
+ def getStatus(blockId: BlockId) = Option(_blocks.get(blockId))
+
def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 8a36b5cc42..2b53bf33b5 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -34,6 +34,13 @@ private[storage] object BlockManagerMessages {
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
+ // Remove all blocks belonging to a specific shuffle.
+ case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave
+
+ // Remove all blocks belonging to a specific broadcast.
+ case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true)
+ extends ToBlockManagerSlave
+
//////////////////////////////////////////////////////////////////////////////////
// Messages from slaves to the master.
@@ -80,7 +87,8 @@ private[storage] object BlockManagerMessages {
}
object UpdateBlockInfo {
- def apply(blockManagerId: BlockManagerId,
+ def apply(
+ blockManagerId: BlockManagerId,
blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
@@ -108,7 +116,13 @@ private[storage] object BlockManagerMessages {
case object GetMemoryStatus extends ToBlockManagerMaster
- case object ExpireDeadHosts extends ToBlockManagerMaster
-
case object GetStorageStatus extends ToBlockManagerMaster
+
+ case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true)
+ extends ToBlockManagerMaster
+
+ case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true)
+ extends ToBlockManagerMaster
+
+ case object ExpireDeadHosts extends ToBlockManagerMaster
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index bcfb82d3c7..6d4db064df 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -17,8 +17,11 @@
package org.apache.spark.storage
-import akka.actor.Actor
+import scala.concurrent.Future
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark.{Logging, MapOutputTracker}
import org.apache.spark.storage.BlockManagerMessages._
/**
@@ -26,14 +29,59 @@ import org.apache.spark.storage.BlockManagerMessages._
* this is used to remove blocks from the slave's BlockManager.
*/
private[storage]
-class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
- override def receive = {
+class BlockManagerSlaveActor(
+ blockManager: BlockManager,
+ mapOutputTracker: MapOutputTracker)
+ extends Actor with Logging {
+
+ import context.dispatcher
+ // Operations that involve removing blocks may be slow and should be done asynchronously
+ override def receive = {
case RemoveBlock(blockId) =>
- blockManager.removeBlock(blockId)
+ doAsync[Boolean]("removing block " + blockId, sender) {
+ blockManager.removeBlock(blockId)
+ true
+ }
case RemoveRdd(rddId) =>
- val numBlocksRemoved = blockManager.removeRdd(rddId)
- sender ! numBlocksRemoved
+ doAsync[Int]("removing RDD " + rddId, sender) {
+ blockManager.removeRdd(rddId)
+ }
+
+ case RemoveShuffle(shuffleId) =>
+ doAsync[Boolean]("removing shuffle " + shuffleId, sender) {
+ if (mapOutputTracker != null) {
+ mapOutputTracker.unregisterShuffle(shuffleId)
+ }
+ blockManager.shuffleBlockManager.removeShuffle(shuffleId)
+ }
+
+ case RemoveBroadcast(broadcastId, tellMaster) =>
+ doAsync[Int]("removing broadcast " + broadcastId, sender) {
+ blockManager.removeBroadcast(broadcastId, tellMaster)
+ }
+
+ case GetBlockStatus(blockId, _) =>
+ sender ! blockManager.getStatus(blockId)
+
+ case GetMatchingBlockIds(filter, _) =>
+ sender ! blockManager.getMatchingBlockIds(filter)
+ }
+
+ private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) {
+ val future = Future {
+ logDebug(actionMessage)
+ body
+ }
+ future.onSuccess { case response =>
+ logDebug("Done " + actionMessage + ", response is " + response)
+ responseActor ! response
+ logDebug("Sent response: " + response + " to " + responseActor)
+ }
+ future.onFailure { case t: Throwable =>
+ logError("Error in " + actionMessage, t)
+ responseActor ! null.asInstanceOf[T]
+ }
}
}
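A standalone sketch of the doAsync shape used above: run the potentially slow removal off the actor's receive path and reply from the future's callback. The reply function below stands in for responseActor ! response, and the sketch uses a single onComplete instead of separate onSuccess/onFailure callbacks.

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

object DoAsyncSketch {
  def doAsync[T](actionMessage: String)(reply: Any => Unit)(body: => T): Unit = {
    Future {
      println(s"starting: $actionMessage")
      body
    }.onComplete {
      case Success(response) =>
        reply(response)  // send the result back once the work has finished
      case Failure(t) =>
        println(s"error in $actionMessage: $t")
        reply(null)
    }
  }
}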
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f3e1c38744..7a24c8f57f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -90,6 +90,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
def getFile(blockId: BlockId): File = getFile(blockId.name)
+ /** Check if disk block manager has a block. */
+ def containsBlock(blockId: BlockId): Boolean = {
+ getBlockLocation(blockId).file.exists()
+ }
+
+ /** List all the blocks currently stored on disk by the disk manager. */
+ def getAllBlocks(): Seq[BlockId] = {
+ // Get all the files inside the array of arrays of directories
+ subDirs.flatten.filter(_ != null).flatMap { dir =>
+ val files = dir.list()
+ if (files != null) files else Seq.empty
+ }.map(BlockId.apply)
+ }
+
/** Produces a unique block id and File suitable for intermediate results. */
def createTempBlock(): (TempBlockId, File) = {
var blockId = new TempBlockId(UUID.randomUUID())
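A self-contained sketch of the getAllBlocks() traversal: flatten the two-level array of directories, skip slots that were never allocated, guard against a null listing, and return the file names (turning names back into ids is omitted; BlockId.apply does that in the real code):

import java.io.File

object ListBlocksSketch {
  def getAllBlockNames(subDirs: Array[Array[File]]): Seq[String] = {
    subDirs.flatten.filter(_ != null).flatMap { dir =>
      val files = dir.list()                       // may be null if the directory vanished
      if (files != null) files.toSeq else Seq.empty[String]
    }
  }
}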
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index bb07c8cb13..4cd4cdbd99 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -169,23 +169,43 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
throw new IllegalStateException("Failed to find shuffle block: " + id)
}
+ /** Remove all the blocks / files and metadata related to a particular shuffle. */
+ def removeShuffle(shuffleId: ShuffleId): Boolean = {
+ // Do not change the ordering here: shuffleStates should be removed only
+ // after the corresponding shuffle blocks have been removed
+ val cleaned = removeShuffleBlocks(shuffleId)
+ shuffleStates.remove(shuffleId)
+ cleaned
+ }
+
+ /** Remove all the blocks / files related to a particular shuffle. */
+ private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
+ shuffleStates.get(shuffleId) match {
+ case Some(state) =>
+ if (consolidateShuffleFiles) {
+ for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
+ file.delete()
+ }
+ } else {
+ for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
+ val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
+ blockManager.diskBlockManager.getFile(blockId).delete()
+ }
+ }
+ logInfo("Deleted all files for shuffle " + shuffleId)
+ true
+ case None =>
+ logInfo("Could not find files for shuffle " + shuffleId + " for deleting")
+ false
+ }
+ }
+
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
private def cleanup(cleanupTime: Long) {
- shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => {
- if (consolidateShuffleFiles) {
- for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
- file.delete()
- }
- } else {
- for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) {
- val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)
- blockManager.diskBlockManager.getFile(blockId).delete()
- }
- }
- })
+ shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 226ed2a132..a107c5182b 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.ArrayBlockingQueue
import akka.actor._
import util.Random
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.KryoSerializer
@@ -48,7 +48,7 @@ private[spark] object ThreadingTest {
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
val startTime = System.currentTimeMillis()
- manager.put(blockId, block.iterator, level, true)
+ manager.put(blockId, block.iterator, level, tellMaster = true)
println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms")
queue.add((blockId, block))
}
@@ -101,7 +101,7 @@ private[spark] object ThreadingTest {
conf)
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
- new SecurityManager(conf))
+ new SecurityManager(conf), new MapOutputTrackerMaster(conf))
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
producers.foreach(_.start)
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0448919e09..7ebed5105b 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -62,8 +62,8 @@ private[spark] class MetadataCleaner(
private[spark] object MetadataCleanerType extends Enumeration {
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
- SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER,
+ SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
@@ -78,15 +78,16 @@ private[spark] object MetadataCleaner {
conf.getInt("spark.cleaner.ttl", -1)
}
- def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int =
- {
- conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString)
- .toInt
+ def getDelaySeconds(
+ conf: SparkConf,
+ cleanerType: MetadataCleanerType.MetadataCleanerType): Int = {
+ conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt
}
- def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType,
- delay: Int)
- {
+ def setDelaySeconds(
+ conf: SparkConf,
+ cleanerType: MetadataCleanerType.MetadataCleanerType,
+ delay: Int) {
conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString)
}
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index ddbd084ed7..8de75ba9a9 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -17,48 +17,54 @@
package org.apache.spark.util
+import java.util.Set
+import java.util.Map.Entry
import java.util.concurrent.ConcurrentHashMap
-import scala.collection.JavaConversions
-import scala.collection.immutable
-import scala.collection.mutable.Map
+import scala.collection.{JavaConversions, mutable}
import org.apache.spark.Logging
+private[spark] case class TimeStampedValue[V](value: V, timestamp: Long)
+
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
* timestamp along with each key-value pair. If specified, the timestamp of each pair can be
* updated every time it is accessed. Key-value pairs whose timestamps are older than a particular
* threshold time can then be removed using the clearOldValues method. This is intended to
* be a drop-in replacement of scala.collection.mutable.HashMap.
- * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be
- * updated when it is accessed
+ *
+ * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed
*/
-class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
- extends Map[A, B]() with Logging {
- val internalMap = new ConcurrentHashMap[A, (B, Long)]()
+private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+ extends mutable.Map[A, B]() with Logging {
+
+ private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]()
def get(key: A): Option[B] = {
val value = internalMap.get(key)
if (value != null && updateTimeStampOnGet) {
- internalMap.replace(key, value, (value._1, currentTime))
+ internalMap.replace(key, value, TimeStampedValue(value.value, currentTime))
}
- Option(value).map(_._1)
+ Option(value).map(_.value)
}
def iterator: Iterator[(A, B)] = {
- val jIterator = internalMap.entrySet().iterator()
- JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1))
+ val jIterator = getEntrySet.iterator
+ JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value))
}
- override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = {
+ def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet
+
+ override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
val newMap = new TimeStampedHashMap[A, B1]
- newMap.internalMap.putAll(this.internalMap)
- newMap.internalMap.put(kv._1, (kv._2, currentTime))
+ val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]]
+ newMap.internalMap.putAll(oldInternalMap)
+ kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) }
newMap
}
- override def - (key: A): Map[A, B] = {
+ override def - (key: A): mutable.Map[A, B] = {
val newMap = new TimeStampedHashMap[A, B]
newMap.internalMap.putAll(this.internalMap)
newMap.internalMap.remove(key)
@@ -66,17 +72,10 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
}
override def += (kv: (A, B)): this.type = {
- internalMap.put(kv._1, (kv._2, currentTime))
+ kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) }
this
}
- // Should we return previous value directly or as Option ?
- def putIfAbsent(key: A, value: B): Option[B] = {
- val prev = internalMap.putIfAbsent(key, (value, currentTime))
- if (prev != null) Some(prev._1) else None
- }
-
-
override def -= (key: A): this.type = {
internalMap.remove(key)
this
@@ -87,53 +86,65 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false)
}
override def apply(key: A): B = {
- val value = internalMap.get(key)
- if (value == null) throw new NoSuchElementException()
- value._1
+ get(key).getOrElse { throw new NoSuchElementException() }
}
- override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
- JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+ override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = {
+ JavaConversions.mapAsScalaConcurrentMap(internalMap)
+ .map { case (k, TimeStampedValue(v, t)) => (k, v) }
+ .filter(p)
}
- override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
+ override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]()
override def size: Int = internalMap.size
override def foreach[U](f: ((A, B)) => U) {
- val iterator = internalMap.entrySet().iterator()
- while(iterator.hasNext) {
- val entry = iterator.next()
- val kv = (entry.getKey, entry.getValue._1)
+ val it = getEntrySet.iterator
+ while(it.hasNext) {
+ val entry = it.next()
+ val kv = (entry.getKey, entry.getValue.value)
f(kv)
}
}
- def toMap: immutable.Map[A, B] = iterator.toMap
+ def putIfAbsent(key: A, value: B): Option[B] = {
+ val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime))
+ Option(prev).map(_.value)
+ }
+
+ def putAll(map: Map[A, B]) {
+ map.foreach { case (k, v) => update(k, v) }
+ }
+
+ def toMap: Map[A, B] = iterator.toMap
- /**
- * Removes old key-value pairs that have timestamp earlier than `threshTime`,
- * calling the supplied function on each such entry before removing.
- */
def clearOldValues(threshTime: Long, f: (A, B) => Unit) {
- val iterator = internalMap.entrySet().iterator()
- while (iterator.hasNext) {
- val entry = iterator.next()
- if (entry.getValue._2 < threshTime) {
- f(entry.getKey, entry.getValue._1)
+ val it = getEntrySet.iterator
+ while (it.hasNext) {
+ val entry = it.next()
+ if (entry.getValue.timestamp < threshTime) {
+ f(entry.getKey, entry.getValue.value)
logDebug("Removing key " + entry.getKey)
- iterator.remove()
+ it.remove()
}
}
}
- /**
- * Removes old key-value pairs that have timestamp earlier than `threshTime`
- */
+ /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */
def clearOldValues(threshTime: Long) {
clearOldValues(threshTime, (_, _) => ())
}
- private def currentTime: Long = System.currentTimeMillis()
+ private def currentTime: Long = System.currentTimeMillis
+ // For testing
+
+ def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = {
+ Option(internalMap.get(key))
+ }
+
+ def getTimestamp(key: A): Option[Long] = {
+ getTimeStampedValue(key).map(_.timestamp)
+ }
}
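A standalone miniature of the TimeStampedValue idea introduced above: each entry carries its insertion time, and clearOldValues drops entries older than a threshold. The real class wraps this in a mutable.Map interface over a ConcurrentHashMap; the sketch below is illustrative only.

import java.util.concurrent.ConcurrentHashMap

object TimeStampedSketch extends App {
  case class TimeStampedValue[V](value: V, timestamp: Long)

  val map = new ConcurrentHashMap[String, TimeStampedValue[Int]]()
  map.put("a", TimeStampedValue(1, System.currentTimeMillis))

  // Remove entries whose insertion timestamp is older than the threshold.
  def clearOldValues(threshTime: Long): Unit = {
    val it = map.entrySet().iterator()
    while (it.hasNext) {
      if (it.next().getValue.timestamp < threshTime) it.remove()
    }
  }

  clearOldValues(System.currentTimeMillis + 1)  // everything is "old" relative to a future threshold
  println(map.isEmpty)                          // true
}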
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
new file mode 100644
index 0000000000..b65017d680
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.lang.ref.WeakReference
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+
+/**
+ * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped.
+ *
+ * If a value has been garbage collected and its weak reference cleared, get() behaves as if the
+ * entry does not exist. Such entries are removed from the map periodically (every N inserts), as
+ * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are
+ * older than a particular threshold can be removed using the clearOldValues method.
+ *
+ * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it
+ * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap,
+ * so all operations on this HashMap are thread-safe.
+ *
+ * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed.
+ */
+private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false)
+ extends mutable.Map[A, B]() with Logging {
+
+ import TimeStampedWeakValueHashMap._
+
+ private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet)
+ private val insertCount = new AtomicInteger(0)
+
+ /** Return a map consisting only of entries whose values are still strongly reachable. */
+ private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null }
+
+ def get(key: A): Option[B] = internalMap.get(key)
+
+ def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator
+
+ override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = {
+ val newMap = new TimeStampedWeakValueHashMap[A, B1]
+ val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]]
+ newMap.internalMap.putAll(oldMap.toMap)
+ newMap.internalMap += kv
+ newMap
+ }
+
+ override def - (key: A): mutable.Map[A, B] = {
+ val newMap = new TimeStampedWeakValueHashMap[A, B]
+ newMap.internalMap.putAll(nonNullReferenceMap.toMap)
+ newMap.internalMap -= key
+ newMap
+ }
+
+ override def += (kv: (A, B)): this.type = {
+ internalMap += kv
+ if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) {
+ clearNullValues()
+ }
+ this
+ }
+
+ override def -= (key: A): this.type = {
+ internalMap -= key
+ this
+ }
+
+ override def update(key: A, value: B) = this += ((key, value))
+
+ override def apply(key: A): B = internalMap.apply(key)
+
+ override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p)
+
+ override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]()
+
+ override def size: Int = internalMap.size
+
+ override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f)
+
+ def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value)
+
+ def toMap: Map[A, B] = iterator.toMap
+
+ /** Remove old key-value pairs with timestamps earlier than `threshTime`. */
+ def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime)
+
+ /** Remove entries with values that are no longer strongly reachable. */
+ def clearNullValues() {
+ val it = internalMap.getEntrySet.iterator
+ while (it.hasNext) {
+ val entry = it.next()
+ if (entry.getValue.value.get == null) {
+ logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.")
+ it.remove()
+ }
+ }
+ }
+
+ // For testing
+
+ def getTimestamp(key: A): Option[Long] = {
+ internalMap.getTimeStampedValue(key).map(_.timestamp)
+ }
+
+ def getReference(key: A): Option[WeakReference[B]] = {
+ internalMap.getTimeStampedValue(key).map(_.value)
+ }
+}
+
+/**
+ * Helper methods for converting to and from WeakReferences.
+ */
+private object TimeStampedWeakValueHashMap {
+
+ // Number of inserts after which entries with null references are removed
+ val CLEAR_NULL_VALUES_INTERVAL = 100
+
+ /* Implicit conversion methods to WeakReferences. */
+
+ implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v)
+
+ implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = {
+ kv match { case (k, v) => (k, toWeakReference(v)) }
+ }
+
+ implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = {
+ (kv: (K, WeakReference[V])) => p(kv)
+ }
+
+ /* Implicit conversion methods from WeakReferences. */
+
+ implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get
+
+ implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = {
+ v match {
+ case Some(ref) => Option(fromWeakReference(ref))
+ case None => None
+ }
+ }
+
+ implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = {
+ kv match { case (k, v) => (k, fromWeakReference(v)) }
+ }
+
+ implicit def fromWeakReferenceIterator[K, V](
+ it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = {
+ it.map(fromWeakReferenceTuple)
+ }
+
+ implicit def fromWeakReferenceMap[K, V](
+ map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = {
+ mutable.Map(map.mapValues(fromWeakReference).toSeq: _*)
+ }
+}
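A minimal demonstration of why the wrapper must tolerate cleared weak references: once the last strong reference is dropped and a collection runs, WeakReference.get may return null, which the map treats as a missing entry. Note that System.gc() is only a request, so the final println may still show a live reference on some JVMs.

import java.lang.ref.WeakReference

object WeakRefSketch extends App {
  var payload: Object = new Object
  val ref = new WeakReference[Object](payload)

  println(ref.get != null)  // true: still strongly reachable via `payload`
  payload = null            // drop the only strong reference
  System.gc()               // request a collection (best effort)
  println(ref.get)          // likely null after collection; callers must handle this
}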
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 4435b21a75..59da51f3e0 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -499,10 +499,10 @@ private[spark] object Utils extends Logging {
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
def parseHostPort(hostPort: String): (String, Int) = {
- {
- // Check cache first.
- val cached = hostPortParseResults.get(hostPort)
- if (cached != null) return cached
+ // Check cache first.
+ val cached = hostPortParseResults.get(hostPort)
+ if (cached != null) {
+ return cached
}
val indx: Int = hostPort.lastIndexOf(':')
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
index d2e303d81c..c5f24c66ce 100644
--- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -56,7 +56,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = conf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@@ -93,7 +93,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = badconf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@@ -147,7 +147,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = goodconf, securityManager = securityManagerGood)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
@@ -200,7 +200,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
conf = badconf, securityManager = securityManagerBad)
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index 96ba3929c1..c9936256a5 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -19,68 +19,297 @@ package org.apache.spark
import org.scalatest.FunSuite
-class BroadcastSuite extends FunSuite with LocalSparkContext {
+import org.apache.spark.storage._
+import org.apache.spark.broadcast.{Broadcast, HttpBroadcast}
+import org.apache.spark.storage.BroadcastBlockId
+class BroadcastSuite extends FunSuite with LocalSparkContext {
- override def afterEach() {
- super.afterEach()
- System.clearProperty("spark.broadcast.factory")
- }
+ private val httpConf = broadcastConf("HttpBroadcastFactory")
+ private val torrentConf = broadcastConf("TorrentBroadcastFactory")
test("Using HttpBroadcast locally") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
- sc = new SparkContext("local", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ sc = new SparkContext("local", "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === Set((1, 10), (2, 10)))
}
test("Accessing HttpBroadcast variables from multiple threads") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
- sc = new SparkContext("local[10]", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local[10]", "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing HttpBroadcast variables in a local cluster") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
val numSlaves = 4
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
test("Using TorrentBroadcast locally") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
- sc = new SparkContext("local", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ sc = new SparkContext("local", "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === Set((1, 10), (2, 10)))
}
test("Accessing TorrentBroadcast variables from multiple threads") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
- sc = new SparkContext("local[10]", "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local[10]", "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing TorrentBroadcast variables in a local cluster") {
- System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
val numSlaves = 4
- sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
- val list = List(1, 2, 3, 4)
- val listBroadcast = sc.broadcast(list)
- val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
- assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf)
+ val list = List[Int](1, 2, 3, 4)
+ val broadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Unpersisting HttpBroadcast on executors only in local mode") {
+ testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
+ }
+
+ test("Unpersisting HttpBroadcast on executors and driver in local mode") {
+ testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true)
+ }
+
+ test("Unpersisting HttpBroadcast on executors only in distributed mode") {
+ testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false)
+ }
+
+ test("Unpersisting HttpBroadcast on executors and driver in distributed mode") {
+ testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors only in local mode") {
+ testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors and driver in local mode") {
+ testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors only in distributed mode") {
+ testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false)
+ }
+
+ test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
+ testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
+ }
+ /**
+ * Verify the persistence of state associated with an HttpBroadcast in either local mode or
+ * local-cluster mode (when distributed = true).
+ *
+ * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+ * In between each step, this test verifies that the broadcast blocks and the broadcast file
+ * are present only on the expected nodes.
+ */
+ private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+ val numSlaves = if (distributed) 2 else 0
+
+ def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
+
+ // Verify that the broadcast file is created, and blocks are persisted only on the driver
+ def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, status) =>
+ assert(bm.executorId === "<driver>", "Block should only be on the driver")
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store on the driver")
+ assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+ }
+ if (distributed) {
+ // this file is only generated in distributed mode
+ assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+ }
+ }
+
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ assert(statuses.size === numSlaves + 1)
+ statuses.foreach { case (_, status) =>
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store")
+ assert(status.diskSize === 0, "Block should not be in disk store")
+ }
+ }
+
+ // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+ // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
+ def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ assert(blockIds.size === 1)
+ val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+ val expectedNumBlocks = if (removeFromDriver) 0 else 1
+ val possiblyNot = if (removeFromDriver) "" else " not"
+ assert(statuses.size === expectedNumBlocks,
+ "Block should%s be unpersisted on the driver".format(possiblyNot))
+ if (distributed && removeFromDriver) {
+ // this file is only generated in distributed mode
+ assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+ "Broadcast file should%s be deleted".format(possiblyNot))
+ }
+ }
+
+ testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+ afterUsingBroadcast, afterUnpersist, removeFromDriver)
+ }
+
+ /**
+ * Verify the persistence of state associated with a TorrentBroadcast in either local mode
+ * or local-cluster mode (when distributed = true).
+ *
+ * This test creates a broadcast variable, uses it on all executors, and then unpersists it.
+ * In between each step, this test verifies that the broadcast blocks are present only on the
+ * expected nodes.
+ */
+ private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
+ val numSlaves = if (distributed) 2 else 0
+
+ def getBlockIds(id: Long) = {
+ val broadcastBlockId = BroadcastBlockId(id)
+ val metaBlockId = BroadcastBlockId(id, "meta")
+ // Assume broadcast value is small enough to fit into 1 piece
+ val pieceBlockId = BroadcastBlockId(id, "piece0")
+ if (distributed) {
+ // the metadata and piece blocks are generated only in distributed mode
+ Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+ } else {
+ Seq[BroadcastBlockId](broadcastBlockId)
+ }
+ }
+
+ // Verify that blocks are persisted only on the driver
+ def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, status) =>
+ assert(bm.executorId === "<driver>", "Block should only be on the driver")
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store on the driver")
+ assert(status.diskSize === 0, "Block should not be in disk store on the driver")
+ }
+ }
+ }
+
+ // Verify that blocks are persisted in both the executors and the driver
+ def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ if (blockId.field == "meta") {
+ // Meta data is only on the driver
+ assert(statuses.size === 1)
+ statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+ } else {
+ // Other blocks are on both the executors and the driver
+ assert(statuses.size === numSlaves + 1,
+ blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
+ statuses.foreach { case (_, status) =>
+ assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
+ assert(status.memSize > 0, "Block should be in memory store")
+ assert(status.diskSize === 0, "Block should not be in disk store")
+ }
+ }
+ }
+ }
+
+ // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
+ // is true.
+ def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
+ val expectedNumBlocks = if (removeFromDriver) 0 else 1
+ val possiblyNot = if (removeFromDriver) "" else " not"
+ blockIds.foreach { blockId =>
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === expectedNumBlocks,
+ "Block should%s be unpersisted on the driver".format(possiblyNot))
+ }
+ }
+
+ testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation,
+ afterUsingBroadcast, afterUnpersist, removeFromDriver)
+ }
+
+ /**
+ * This test runs in 4 steps:
+ *
+ * 1) Create broadcast variable, and verify that all state is persisted on the driver.
+ * 2) Use the broadcast variable on all executors, and verify that all state is persisted
+ * on both the driver and the executors.
+ * 3) Unpersist the broadcast, and verify that all state is removed where it should be.
+ * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable.
+ */
+ private def testUnpersistBroadcast(
+ distributed: Boolean,
+ numSlaves: Int, // used only when distributed = true
+ broadcastConf: SparkConf,
+ getBlockIds: Long => Seq[BroadcastBlockId],
+ afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+ removeFromDriver: Boolean) {
+
+ sc = if (distributed) {
+ new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
+ } else {
+ new SparkContext("local", "test", broadcastConf)
+ }
+ val blockManagerMaster = sc.env.blockManager.master
+ val list = List[Int](1, 2, 3, 4)
+
+ // Create broadcast variable
+ val broadcast = sc.broadcast(list)
+ val blocks = getBlockIds(broadcast.id)
+ afterCreation(blocks, blockManagerMaster)
+
+ // Use broadcast variable on all executors
+ val partitions = 10
+ assert(partitions > numSlaves)
+ val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+ afterUsingBroadcast(blocks, blockManagerMaster)
+
+ // Unpersist broadcast
+ if (removeFromDriver) {
+ broadcast.destroy(blocking = true)
+ } else {
+ broadcast.unpersist(blocking = true)
+ }
+ afterUnpersist(blocks, blockManagerMaster)
+
+ // If the broadcast is removed from driver, all subsequent uses of the broadcast variable
+ // should throw SparkExceptions. Otherwise, the result should be the same as before.
+ if (removeFromDriver) {
+ // Using this variable on the executors crashes them, which hangs the test.
+ // Instead, crash the driver by directly accessing the broadcast value.
+ intercept[SparkException] { broadcast.value }
+ intercept[SparkException] { broadcast.unpersist() }
+ intercept[SparkException] { broadcast.destroy(blocking = true) }
+ } else {
+ val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
+ assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
+ }
}
+ /** Helper method to create a SparkConf that uses the given broadcast factory. */
+ private def broadcastConf(factoryName: String): SparkConf = {
+ val conf = new SparkConf
+ conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName))
+ conf
+ }
}
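A quick illustration of the driver-side lifecycle these new unpersist tests exercise, from the point of view of calling code (a minimal sketch; it assumes an existing SparkContext `sc` and that `destroy` is visible to the caller, as it is to this suite):

  val words = sc.broadcast(Seq("a", "b", "c"))     // broadcast blocks are registered on the driver
  val sizes = sc.parallelize(1 to 4).map(_ => words.value.size).collect()
  words.unpersist(blocking = true)                 // drop executor copies; the broadcast stays usable, as the tests above verify
  words.destroy(blocking = true)                   // remove all state; further access throws SparkException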
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
new file mode 100644
index 0000000000..e50981cf6f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.lang.ref.WeakReference
+
+import scala.collection.mutable.{HashSet, SynchronizedSet}
+import scala.util.Random
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
+
+class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+ implicit val defaultTimeout = timeout(10000 millis)
+ val conf = new SparkConf()
+ .setMaster("local[2]")
+ .setAppName("ContextCleanerSuite")
+ .set("spark.cleaner.referenceTracking.blocking", "true")
+
+ before {
+ sc = new SparkContext(conf)
+ }
+
+ after {
+ if (sc != null) {
+ sc.stop()
+ sc = null
+ }
+ }
+
+
+ test("cleanup RDD") {
+ val rdd = newRDD.persist()
+ val collected = rdd.collect().toList
+ val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+
+ // Explicit cleanup
+ cleaner.doCleanupRDD(rdd.id, blocking = true)
+ tester.assertCleanup()
+
+ // Verify that RDDs can be re-executed after cleaning up
+ assert(rdd.collect().toList === collected)
+ }
+
+ test("cleanup shuffle") {
+ val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
+ val collected = rdd.collect().toList
+ val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
+
+ // Explicit cleanup
+ shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
+ tester.assertCleanup()
+
+ // Verify that shuffles can be re-executed after cleaning up
+ assert(rdd.collect().toList === collected)
+ }
+
+ test("cleanup broadcast") {
+ val broadcast = newBroadcast
+ val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+
+ // Explicit cleanup
+ cleaner.doCleanupBroadcast(broadcast.id, blocking = true)
+ tester.assertCleanup()
+ }
+
+ test("automatically cleanup RDD") {
+ var rdd = newRDD.persist()
+ rdd.count()
+
+ // Test that GC does not cause RDD cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes RDD cleanup after dereferencing the RDD
+ val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id))
+ rdd = null // Make RDD out of scope
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup shuffle") {
+ var rdd = newShuffleRDD
+ rdd.count()
+
+ // Test that GC does not cause shuffle cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes shuffle cleanup after dereferencing the RDD
+ val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+ rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup broadcast") {
+ var broadcast = newBroadcast
+
+ // Test that GC does not cause broadcast cleanup due to a strong reference
+ val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes broadcast cleanup after dereferencing the broadcast variable
+ val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
+ broadcast = null // Make broadcast variable out of scope
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup RDD + shuffle + broadcast") {
+ val numRdds = 100
+ val numBroadcasts = 4 // Broadcasts are more costly
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+ val rddIds = sc.persistentRdds.keys.toSeq
+ val shuffleIds = 0 until sc.newShuffleId
+ val broadcastIds = 0L until numBroadcasts
+
+ val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC triggers the cleanup of all variables after dereferencing them
+ val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ broadcastBuffer.clear()
+ rddBuffer.clear()
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
+ sc.stop()
+
+ val conf2 = new SparkConf()
+ .setMaster("local-cluster[2, 1, 512]")
+ .setAppName("ContextCleanerSuite")
+ .set("spark.cleaner.referenceTracking.blocking", "true")
+ sc = new SparkContext(conf2)
+
+ val numRdds = 10
+ val numBroadcasts = 4 // Broadcasts are more costly
+ val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+ val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+ val rddIds = sc.persistentRdds.keys.toSeq
+ val shuffleIds = 0 until sc.newShuffleId
+ val broadcastIds = 0L until numBroadcasts
+
+ val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC triggers the cleanup of all variables after dereferencing them
+ val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+ broadcastBuffer.clear()
+ rddBuffer.clear()
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
+ //------ Helper functions ------
+
+ def newRDD = sc.makeRDD(1 to 10)
+ def newPairRDD = newRDD.map(_ -> 1)
+ def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
+ def newBroadcast = sc.broadcast(1 to 100)
+ def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+ def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
+ rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
+ getAllDependencies(dep.rdd)
+ }
+ }
+ val rdd = newShuffleRDD
+
+ // Get all the shuffle dependencies
+ val shuffleDeps = getAllDependencies(rdd)
+ .filter(_.isInstanceOf[ShuffleDependency[_, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _]])
+ (rdd, shuffleDeps)
+ }
+
+ def randomRdd = {
+ val rdd: RDD[_] = Random.nextInt(3) match {
+ case 0 => newRDD
+ case 1 => newShuffleRDD
+ case 2 => newPairRDD.join(newPairRDD)
+ }
+ if (Random.nextBoolean()) rdd.persist()
+ rdd.count()
+ rdd
+ }
+
+ def randomBroadcast = {
+ sc.broadcast(Random.nextInt(Int.MaxValue))
+ }
+
+ /** Run GC and make sure it actually has run */
+ def runGC() {
+ val weakRef = new WeakReference(new Object())
+ val startTime = System.currentTimeMillis
+ System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+ // Wait until a weak reference object has been GCed
+ while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+ System.gc()
+ Thread.sleep(200)
+ }
+ }
+
+ def cleaner = sc.cleaner.get
+}
+
+
+/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */
+class CleanerTester(
+ sc: SparkContext,
+ rddIds: Seq[Int] = Seq.empty,
+ shuffleIds: Seq[Int] = Seq.empty,
+ broadcastIds: Seq[Long] = Seq.empty)
+ extends Logging {
+
+ val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds
+ val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds
+ val toBeCleanedBroadcastIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds
+ val isDistributed = !sc.isLocal
+
+ val cleanerListener = new CleanerListener {
+ def rddCleaned(rddId: Int): Unit = {
+ toBeCleanedRDDIds -= rddId
+ logInfo("RDD " + rddId + " cleaned")
+ }
+
+ def shuffleCleaned(shuffleId: Int): Unit = {
+ toBeCleanedShuffleIds -= shuffleId
+ logInfo("Shuffle " + shuffleId + " cleaned")
+ }
+
+ def broadcastCleaned(broadcastId: Long): Unit = {
+ toBeCleanedBroadcastIds -= broadcastId
+ logInfo("Broadcast " + broadcastId + " cleaned")
+ }
+ }
+
+ val MAX_VALIDATION_ATTEMPTS = 10
+ val VALIDATION_ATTEMPT_INTERVAL = 100
+
+ logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString)
+ preCleanupValidate()
+ sc.cleaner.get.attachListener(cleanerListener)
+
+ /** Assert that all the stuff has been cleaned up */
+ def assertCleanup()(implicit waitTimeout: Eventually.Timeout) {
+ try {
+ eventually(waitTimeout, interval(100 millis)) {
+ assert(isAllCleanedUp)
+ }
+ postCleanupValidate()
+ } finally {
+ logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString)
+ }
+ }
+
+ /** Verify that RDDs, shuffles, etc. occupy resources */
+ private def preCleanupValidate() {
+ assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup")
+
+ // Verify the RDDs have been persisted and blocks are present
+ rddIds.foreach { rddId =>
+ assert(
+ sc.persistentRdds.contains(rddId),
+ "RDD " + rddId + " has not been persisted, cannot start cleaner test"
+ )
+
+ assert(
+ !getRDDBlocks(rddId).isEmpty,
+ "Blocks of RDD " + rddId + " cannot be found in block manager, " +
+ "cannot start cleaner test"
+ )
+ }
+
+ // Verify the shuffle ids are registered and blocks are present
+ shuffleIds.foreach { shuffleId =>
+ assert(
+ mapOutputTrackerMaster.containsShuffle(shuffleId),
+ "Shuffle " + shuffleId + " has not been registered, cannot start cleaner test"
+ )
+
+ assert(
+ !getShuffleBlocks(shuffleId).isEmpty,
+ "Blocks of shuffle " + shuffleId + " cannot be found in block manager, " +
+ "cannot start cleaner test"
+ )
+ }
+
+ // Verify that the broadcast blocks are present
+ broadcastIds.foreach { broadcastId =>
+ assert(
+ !getBroadcastBlocks(broadcastId).isEmpty,
+ "Blocks of broadcast " + broadcastId + " cannot be found in block manager, " +
+ "cannot start cleaner test"
+ )
+ }
+ }
+
+ /**
+ * Verify that RDDs, shuffles, etc. do not occupy resources. Tests multiple times as there is
+ * no guarantee on how long it will take to clean up the resources.
+ */
+ private def postCleanupValidate() {
+ // Verify the RDDs have been unpersisted and their blocks cleared from the block manager
+ rddIds.foreach { rddId =>
+ assert(
+ !sc.persistentRdds.contains(rddId),
+ "RDD " + rddId + " was not cleared from sc.persistentRdds"
+ )
+
+ assert(
+ getRDDBlocks(rddId).isEmpty,
+ "Blocks of RDD " + rddId + " were not cleared from block manager"
+ )
+ }
+
+ // Verify the shuffle ids have been deregistered and their blocks cleared from the block manager
+ shuffleIds.foreach { shuffleId =>
+ assert(
+ !mapOutputTrackerMaster.containsShuffle(shuffleId),
+ "Shuffle " + shuffleId + " was not deregistered from map output tracker"
+ )
+
+ assert(
+ getShuffleBlocks(shuffleId).isEmpty,
+ "Blocks of shuffle " + shuffleId + " were not cleared from block manager"
+ )
+ }
+
+ // Verify that the broadcast blocks have been cleared from the block manager
+ broadcastIds.foreach { broadcastId =>
+ assert(
+ getBroadcastBlocks(broadcastId).isEmpty,
+ "Blocks of broadcast " + broadcastId + " were not cleared from block manager"
+ )
+ }
+ }
+
+ private def uncleanedResourcesToString = {
+ s"""
+ |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")}
+ |\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")}
+ |\tBroadcasts = ${toBeCleanedBroadcastIds.toSeq.sorted.mkString("[", ", ", "]")}
+ """.stripMargin
+ }
+
+ private def isAllCleanedUp =
+ toBeCleanedRDDIds.isEmpty &&
+ toBeCleanedShuffleIds.isEmpty &&
+ toBeCleanedBroadcastIds.isEmpty
+
+ private def getRDDBlocks(rddId: Int): Seq[BlockId] = {
+ blockManager.master.getMatchingBlockIds( _ match {
+ case RDDBlockId(`rddId`, _) => true
+ case _ => false
+ }, askSlaves = true)
+ }
+
+ private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
+ blockManager.master.getMatchingBlockIds( _ match {
+ case ShuffleBlockId(`shuffleId`, _, _) => true
+ case _ => false
+ }, askSlaves = true)
+ }
+
+ private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = {
+ blockManager.master.getMatchingBlockIds( _ match {
+ case BroadcastBlockId(`broadcastId`, _) => true
+ case _ => false
+ }, askSlaves = true)
+ }
+
+ private def blockManager = sc.env.blockManager
+ private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+}
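The automatic-cleanup tests above all hinge on the same idiom: drop the last strong reference, force a GC, and only then assert that ContextCleaner has done its work. Pulled out as a standalone sketch (the helper name `waitForGC` is illustrative; the 10-second bound and 200 ms poll mirror `runGC` above):

  import java.lang.ref.WeakReference

  def waitForGC(timeoutMs: Long = 10000): Unit = {
    val ref = new WeakReference(new Object)    // becomes collectable as soon as a GC actually runs
    val start = System.currentTimeMillis
    while (System.currentTimeMillis - start < timeoutMs && ref.get != null) {
      System.gc()                              // best effort; usually triggers a collection
      Thread.sleep(200)
    }
  }

  var rdd = sc.makeRDD(1 to 10).persist()
  rdd.count()
  rdd = null                                   // make the RDD unreachable
  waitForGC()                                  // after this, the cleaner's reference queue can fire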
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index a5bd72eb0a..6b2571cd92 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -57,12 +57,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.stop()
}
- test("master register and fetch") {
+ test("master register shuffle and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
+ assert(tracker.containsShuffle(10))
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
@@ -77,7 +78,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.stop()
}
- test("master register and unregister and fetch") {
+ test("master register and unregister shuffle") {
+ val actorSystem = ActorSystem("test")
+ val tracker = new MapOutputTrackerMaster(conf)
+ tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
+ tracker.registerShuffle(10, 2)
+ val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+ val compressedSize10000 = MapOutputTracker.compressSize(10000L)
+ tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
+ Array(compressedSize1000, compressedSize10000)))
+ tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
+ Array(compressedSize10000, compressedSize1000)))
+ assert(tracker.containsShuffle(10))
+ assert(tracker.getServerStatuses(10, 0).nonEmpty)
+ tracker.unregisterShuffle(10)
+ assert(!tracker.containsShuffle(10))
+ assert(tracker.getServerStatuses(10, 0).isEmpty)
+ }
+
+ test("master register shuffle and unregister map output and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor =
@@ -114,7 +133,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
securityManager = new SecurityManager(conf))
- val slaveTracker = new MapOutputTracker(conf)
+ val slaveTracker = new MapOutputTrackerWorker(conf)
val selection = slaveSystem.actorSelection(
s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
val timeout = AkkaUtils.lookupTimeout(conf)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index b6dd052610..e10ec7d262 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.matchers.ShouldMatchers._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils}
@@ -42,6 +42,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
var oldArch: String = null
conf.set("spark.authenticate", "false")
val securityMgr = new SecurityManager(conf)
+ val mapOutputTracker = new MapOutputTrackerMaster(conf)
// Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
conf.set("spark.kryoserializer.buffer.mb", "1")
@@ -130,7 +131,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 1 manager interaction") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -160,9 +162,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("master + 2 managers interaction") {
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf,
- securityMgr)
+ securityMgr, mapOutputTracker)
val peers = master.getPeers(store.blockManagerId, 1)
assert(peers.size === 1, "master did not return the other manager as a peer")
@@ -177,7 +180,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing block") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -225,7 +229,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("removing rdd") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -257,9 +262,82 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
master.getLocations(rdd(0, 1)) should have size 0
}
+ test("removing broadcast") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
+ val driverStore = store
+ val executorStore = new BlockManager("executor", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
+ val a1 = new Array[Byte](400)
+ val a2 = new Array[Byte](400)
+ val a3 = new Array[Byte](400)
+ val a4 = new Array[Byte](400)
+
+ val broadcast0BlockId = BroadcastBlockId(0)
+ val broadcast1BlockId = BroadcastBlockId(1)
+ val broadcast2BlockId = BroadcastBlockId(2)
+ val broadcast2BlockId2 = BroadcastBlockId(2, "_")
+
+ // insert broadcast blocks in both the stores
+ Seq(driverStore, executorStore).foreach { case s =>
+ s.putSingle(broadcast0BlockId, a1, StorageLevel.DISK_ONLY)
+ s.putSingle(broadcast1BlockId, a2, StorageLevel.DISK_ONLY)
+ s.putSingle(broadcast2BlockId, a3, StorageLevel.DISK_ONLY)
+ s.putSingle(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY)
+ }
+
+ // verify whether the blocks exist in both the stores
+ Seq(driverStore, executorStore).foreach { case s =>
+ s.getLocal(broadcast0BlockId) should not be (None)
+ s.getLocal(broadcast1BlockId) should not be (None)
+ s.getLocal(broadcast2BlockId) should not be (None)
+ s.getLocal(broadcast2BlockId2) should not be (None)
+ }
+
+ // remove broadcast 0 block only from executors
+ master.removeBroadcast(0, removeFromMaster = false, blocking = true)
+
+ // only broadcast 0 block should be removed from the executor store
+ executorStore.getLocal(broadcast0BlockId) should be (None)
+ executorStore.getLocal(broadcast1BlockId) should not be (None)
+ executorStore.getLocal(broadcast2BlockId) should not be (None)
+
+ // nothing should be removed from the driver store
+ driverStore.getLocal(broadcast0BlockId) should not be (None)
+ driverStore.getLocal(broadcast1BlockId) should not be (None)
+ driverStore.getLocal(broadcast2BlockId) should not be (None)
+
+ // remove broadcast 0 block from the driver as well
+ master.removeBroadcast(0, removeFromMaster = true, blocking = true)
+ driverStore.getLocal(broadcast0BlockId) should be (None)
+ driverStore.getLocal(broadcast1BlockId) should not be (None)
+
+ // remove broadcast 1 block from both the stores asynchronously
+ // and verify all broadcast 1 blocks have been removed
+ master.removeBroadcast(1, removeFromMaster = true, blocking = false)
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ driverStore.getLocal(broadcast1BlockId) should be (None)
+ executorStore.getLocal(broadcast1BlockId) should be (None)
+ }
+
+ // remove broadcast 2 from both the stores asynchronously
+ // and verify all broadcast 2 blocks have been removed
+ master.removeBroadcast(2, removeFromMaster = true, blocking = false)
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+ driverStore.getLocal(broadcast2BlockId) should be (None)
+ driverStore.getLocal(broadcast2BlockId2) should be (None)
+ executorStore.getLocal(broadcast2BlockId) should be (None)
+ executorStore.getLocal(broadcast2BlockId2) should be (None)
+ }
+ executorStore.stop()
+ driverStore.stop()
+ store = null
+ }
+
test("reregistration on heart beat") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
@@ -275,7 +353,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("reregistration on block update") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
@@ -294,7 +373,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("reregistration doesn't dead lock") {
val heartBeat = PrivateMethod[Unit]('heartBeat)
- store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = List(new Array[Byte](400))
@@ -331,7 +411,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -350,7 +431,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU storage with serialization") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -369,7 +451,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of same RDD") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -388,7 +471,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU for partitions of multiple RDDs") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
@@ -414,7 +498,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
// TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar.
val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false)
if (tachyonUnitTestEnabled) {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -430,7 +515,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("on-disk storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -443,7 +529,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -458,7 +545,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with getLocalBytes") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -473,7 +561,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -488,7 +577,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("disk and memory storage with serialization and getLocalBytes") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -503,7 +593,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
@@ -525,7 +616,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("in-memory LRU with streams") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -549,7 +641,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("LRU with mixed storage levels and streams") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val list1 = List(new Array[Byte](200), new Array[Byte](200))
val list2 = List(new Array[Byte](200), new Array[Byte](200))
val list3 = List(new Array[Byte](200), new Array[Byte](200))
@@ -595,7 +688,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("overly large block") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 500, conf,
+ securityMgr, mapOutputTracker)
store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.getSingle("a1") === None, "a1 was in store")
store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK)
@@ -606,7 +700,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block compression") {
try {
conf.set("spark.shuffle.compress", "true")
- store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
"shuffle_0_0_0 was not compressed")
@@ -614,7 +709,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.shuffle.compress", "false")
- store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
"shuffle_0_0_0 was compressed")
@@ -622,7 +718,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "true")
- store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
"broadcast_0 was not compressed")
@@ -630,28 +727,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
store = null
conf.set("spark.broadcast.compress", "false")
- store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "true")
- store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
conf.set("spark.rdd.compress", "false")
- store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
// Check that any other block types are also kept uncompressed
- store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr)
+ store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf,
+ securityMgr, mapOutputTracker)
store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY)
assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed")
store.stop()
@@ -666,7 +767,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf,
- securityMgr)
+ securityMgr, mapOutputTracker)
// The put should fail since a1 is not serializable.
class UnserializableClass
@@ -682,7 +783,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
}
test("updated block statuses") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
val list = List.fill(2)(new Array[Byte](200))
val bigList = List.fill(8)(new Array[Byte](200))
@@ -735,8 +837,83 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
assert(!store.get("list5").isDefined, "list5 was in store")
}
+ test("query block statuses") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
+ val list = List.fill(2)(new Array[Byte](200))
+
+ // Tell master. By LRU, only list2 and list3 remain.
+ store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
+ store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+
+ // getLocations and getBlockStatus should yield the same locations
+ assert(store.master.getLocations("list1").size === 0)
+ assert(store.master.getLocations("list2").size === 1)
+ assert(store.master.getLocations("list3").size === 1)
+ assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0)
+ assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1)
+ assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1)
+ assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0)
+ assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1)
+ assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1)
+
+ // This time don't tell master and see what happens. By LRU, only list5 and list6 remain.
+ store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false)
+ store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false)
+
+ // getLocations should return nothing because the master is not informed
+ // getBlockStatus without asking slaves should have the same result
+ // getBlockStatus with asking slaves, however, should return the actual block statuses
+ assert(store.master.getLocations("list4").size === 0)
+ assert(store.master.getLocations("list5").size === 0)
+ assert(store.master.getLocations("list6").size === 0)
+ assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0)
+ assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0)
+ assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0)
+ assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0)
+ assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1)
+ assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1)
+ }
+
+ test("get matching blocks") {
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
+ val list = List.fill(2)(new Array[Byte](10))
+
+ // insert some blocks
+ store.put("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
+ store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
+ store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
+
+ // getMatchingBlockIds should return the blocks the master knows about
+ assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3)
+ assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1)
+
+ // insert some more blocks
+ store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
+ store.put("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+
+ // blocks not reported to the master are only visible when we ask the slaves
+ assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1)
+ assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3)
+
+ val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0))
+ blockIds.foreach { blockId =>
+ store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true)
+ }
+ val matchedBlockIds = store.master.getMatchingBlockIds(_ match {
+ case RDDBlockId(1, _) => true
+ case _ => false
+ }, askSlaves = true)
+ assert(matchedBlockIds.toSet === Set(RDDBlockId(1, 0), RDDBlockId(1, 1)))
+ }
+
test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") {
- store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf, securityMgr)
+ store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
+ securityMgr, mapOutputTracker)
store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// Access rdd_1_0 to ensure it's not least recently used.
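The "query block statuses" and "get matching blocks" tests above cover the two new BlockManagerMaster queries that the cleaner tests also lean on. As a small sketch of how a caller might check, cluster-wide, that a broadcast's blocks are really gone (assuming a BlockManagerMaster `master` and a broadcast id `bid`; this mirrors CleanerTester.getBroadcastBlocks):

  import org.apache.spark.storage.{BlockId, BroadcastBlockId}

  def remainingBroadcastBlocks(bid: Long): Seq[BlockId] =
    master.getMatchingBlockIds({
      case BroadcastBlockId(`bid`, _) => true
      case _ => false
    }, askSlaves = true)                       // askSlaves = true also finds blocks never reported to the master

  assert(remainingBroadcastBlocks(bid).isEmpty, "broadcast " + bid + " still has blocks")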
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 62f9b3cc7b..808ddfdcf4 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -59,8 +59,16 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
val newFile = diskBlockManager.getFile(blockId)
writeToFile(newFile, 10)
assertSegmentEquals(blockId, blockId.name, 0, 10)
-
+ assert(diskBlockManager.containsBlock(blockId))
newFile.delete()
+ assert(!diskBlockManager.containsBlock(blockId))
+ }
+
+ test("enumerating blocks") {
+ val ids = (1 to 100).map(i => TestBlockId("test_" + i))
+ val files = ids.map(id => diskBlockManager.getFile(id))
+ files.foreach(file => writeToFile(file, 10))
+ assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
}
test("block appending") {
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 054eb01a64..7bab7da8fe 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -108,8 +108,7 @@ class JsonProtocolSuite extends FunSuite {
// BlockId
testBlockId(RDDBlockId(1, 2))
testBlockId(ShuffleBlockId(1, 2, 3))
- testBlockId(BroadcastBlockId(1L))
- testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
+ testBlockId(BroadcastBlockId(1L, "insert_words_of_wisdom_here"))
testBlockId(TaskResultBlockId(1L))
testBlockId(StreamBlockId(1, 2L))
}
@@ -555,4 +554,4 @@ class JsonProtocolSuite extends FunSuite {
{"Event":"SparkListenerUnpersistRDD","RDD ID":12345}
"""
- }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala
new file mode 100644
index 0000000000..6a5653ed2f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.lang.ref.WeakReference
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
+
+import org.scalatest.FunSuite
+
+class TimeStampedHashMapSuite extends FunSuite {
+
+ // Test the testMap function - a Scala HashMap should obviously pass
+ testMap(new mutable.HashMap[String, String]())
+
+ // Test TimeStampedHashMap basic functionality
+ testMap(new TimeStampedHashMap[String, String]())
+ testMapThreadSafety(new TimeStampedHashMap[String, String]())
+
+ // Test TimeStampedWeakValueHashMap basic functionality
+ testMap(new TimeStampedWeakValueHashMap[String, String]())
+ testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]())
+
+ test("TimeStampedHashMap - clearing by timestamp") {
+ // clearing by insertion time
+ val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false)
+ map("k1") = "v1"
+ assert(map("k1") === "v1")
+ Thread.sleep(10)
+ val threshTime = System.currentTimeMillis
+ assert(map.getTimestamp("k1").isDefined)
+ assert(map.getTimestamp("k1").get < threshTime)
+ map.clearOldValues(threshTime)
+ assert(map.get("k1") === None)
+
+ // clearing by modification time
+ val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true)
+ map1("k1") = "v1"
+ map1("k2") = "v2"
+ assert(map1("k1") === "v1")
+ Thread.sleep(10)
+ val threshTime1 = System.currentTimeMillis
+ Thread.sleep(10)
+ assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime
+ assert(map1.getTimestamp("k1").isDefined)
+ assert(map1.getTimestamp("k1").get < threshTime1)
+ assert(map1.getTimestamp("k2").isDefined)
+ assert(map1.getTimestamp("k2").get >= threshTime1)
+ map1.clearOldValues(threshTime1) // should only clear k1
+ assert(map1.get("k1") === None)
+ assert(map1.get("k2").isDefined)
+ }
+
+ test("TimeStampedWeakValueHashMap - clearing by timestamp") {
+ // clearing by insertion time
+ val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false)
+ map("k1") = "v1"
+ assert(map("k1") === "v1")
+ Thread.sleep(10)
+ val threshTime = System.currentTimeMillis
+ assert(map.getTimestamp("k1").isDefined)
+ assert(map.getTimestamp("k1").get < threshTime)
+ map.clearOldValues(threshTime)
+ assert(map.get("k1") === None)
+
+ // clearing by modification time
+ val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true)
+ map1("k1") = "v1"
+ map1("k2") = "v2"
+ assert(map1("k1") === "v1")
+ Thread.sleep(10)
+ val threshTime1 = System.currentTimeMillis
+ Thread.sleep(10)
+ assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime
+ assert(map1.getTimestamp("k1").isDefined)
+ assert(map1.getTimestamp("k1").get < threshTime1)
+ assert(map1.getTimestamp("k2").isDefined)
+ assert(map1.getTimestamp("k2").get >= threshTime1)
+ map1.clearOldValues(threshTime1) // should only clear k1
+ assert(map1.get("k1") === None)
+ assert(map1.get("k2").isDefined)
+ }
+
+ test("TimeStampedWeakValueHashMap - clearing weak references") {
+ var strongRef = new Object
+ val weakRef = new WeakReference(strongRef)
+ val map = new TimeStampedWeakValueHashMap[String, Object]
+ map("k1") = strongRef
+ map("k2") = "v2"
+ map("k3") = "v3"
+ assert(map("k1") === strongRef)
+
+ // clear strong reference to "k1"
+ strongRef = null
+ val startTime = System.currentTimeMillis
+ System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+ System.runFinalization() // Make a best effort to call finalizer on all cleaned objects.
+ while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+ System.gc()
+ System.runFinalization()
+ Thread.sleep(100)
+ }
+ assert(map.getReference("k1").isDefined)
+ val ref = map.getReference("k1").get
+ assert(ref.get === null)
+ assert(map.get("k1") === None)
+
+ // map operations should skip entries whose weak values have been cleared
+ assert(map.iterator.forall { case (k, v) => k != "k1" })
+ assert(map.filter { case (k, v) => k != "k2" }.size === 1)
+ assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3")
+ assert(map.toMap.size === 2)
+ assert(map.toMap.forall { case (k, v) => k != "k1" })
+ val buffer = new ArrayBuffer[String]
+ map.foreach { case (k, v) => buffer += v.toString }
+ assert(buffer.size === 2)
+ assert(buffer.forall(_ != "k1"))
+ val plusMap = map + (("k4", "v4"))
+ assert(plusMap.size === 3)
+ assert(plusMap.forall { case (k, v) => k != "k1" })
+ val minusMap = map - "k2"
+ assert(minusMap.size === 1)
+ assert(minusMap.head._1 === "k3")
+
+ // clear null values - should only clear k1
+ map.clearNullValues()
+ assert(map.getReference("k1") === None)
+ assert(map.get("k1") === None)
+ assert(map.get("k2").isDefined)
+ assert(map.get("k2").get === "v2")
+ }
+
+ /** Test basic operations of a Scala mutable Map. */
+ def testMap(hashMapConstructor: => mutable.Map[String, String]) {
+ def newMap() = hashMapConstructor
+ val testMap1 = newMap()
+ val testMap2 = newMap()
+ val name = testMap1.getClass.getSimpleName
+
+ test(name + " - basic test") {
+ // put, get, and apply
+ testMap1 += (("k1", "v1"))
+ assert(testMap1.get("k1").isDefined)
+ assert(testMap1.get("k1").get === "v1")
+ testMap1("k2") = "v2"
+ assert(testMap1.get("k2").isDefined)
+ assert(testMap1.get("k2").get === "v2")
+ assert(testMap1("k2") === "v2")
+ testMap1.update("k3", "v3")
+ assert(testMap1.get("k3").isDefined)
+ assert(testMap1.get("k3").get === "v3")
+
+ // remove
+ testMap1.remove("k1")
+ assert(testMap1.get("k1").isEmpty)
+ testMap1.remove("k2")
+ intercept[NoSuchElementException] {
+ testMap1("k2") // Map.apply(<non-existent-key>) causes exception
+ }
+ testMap1 -= "k3"
+ assert(testMap1.get("k3").isEmpty)
+
+ // multi put
+ val keys = (1 to 100).map(_.toString)
+ val pairs = keys.map(x => (x, x * 2))
+ assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet)
+ testMap2 ++= pairs
+
+ // iterator
+ assert(testMap2.iterator.toSet === pairs.toSet)
+
+ // filter
+ val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 }
+ val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 }
+ assert(filtered.iterator.toSet === evenPairs.toSet)
+
+ // foreach
+ val buffer = new ArrayBuffer[(String, String)]
+ testMap2.foreach(x => buffer += x)
+ assert(testMap2.toSet === buffer.toSet)
+
+ // multi remove
+ testMap2("k1") = "v1"
+ testMap2 --= keys
+ assert(testMap2.size === 1)
+ assert(testMap2.iterator.toSeq.head === ("k1", "v1"))
+
+ // +
+ val testMap3 = testMap2 + (("k0", "v0"))
+ assert(testMap3.size === 2)
+ assert(testMap3.get("k1").isDefined)
+ assert(testMap3.get("k1").get === "v1")
+ assert(testMap3.get("k0").isDefined)
+ assert(testMap3.get("k0").get === "v0")
+
+ // -
+ val testMap4 = testMap3 - "k0"
+ assert(testMap4.size === 1)
+ assert(testMap4.get("k1").isDefined)
+ assert(testMap4.get("k1").get === "v1")
+ }
+ }
+
+ /** Test thread safety of a Scala mutable map. */
+ def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) {
+ def newMap() = hashMapConstructor
+ val name = newMap().getClass.getSimpleName
+ val testMap = newMap()
+ @volatile var error = false
+
+ def getRandomKey(m: mutable.Map[String, String]): Option[String] = {
+ val keys = m.keysIterator.toSeq
+ if (keys.nonEmpty) {
+ Some(keys(Random.nextInt(keys.size)))
+ } else {
+ None
+ }
+ }
+
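+ // Spawn 25 threads, each doing 1000 random put/get/remove operations; any exception sets the error flag.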
+ val threads = (1 to 25).map(i => new Thread() {
+ override def run() {
+ try {
+ for (j <- 1 to 1000) {
+ Random.nextInt(3) match {
+ case 0 =>
+ testMap(Random.nextString(10)) = Random.nextDouble().toString // put
+ case 1 =>
+ getRandomKey(testMap).map(testMap.get) // get
+ case 2 =>
+ getRandomKey(testMap).map(testMap.remove) // remove
+ }
+ }
+ } catch {
+ case t: Throwable =>
+ error = true
+ throw t
+ }
+ }
+ })
+
+ test(name + " - threading safety test") {
+ threads.foreach(_.start())
+ threads.foreach(_.join())
+ assert(!error)
+ }
+ }
+}
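Taken together, the new suite exercises both cleanup paths of the map. A minimal usage sketch, assuming the method names shown above (clearOldValues, clearNullValues) and that the caller sits inside a Spark package where the class is visible:

    // Sketch: track per-ID metadata without keeping values alive forever.
    val metadata = new TimeStampedWeakValueHashMap[Long, AnyRef]()
    metadata(1L) = new Object
    metadata(2L) = new Object

    // Periodic cleanup, e.g. driven by a MetadataCleaner:
    val threshTime = System.currentTimeMillis - 3600 * 1000 // drop entries older than one hour
    metadata.clearOldValues(threshTime)  // time-based cleanup
    metadata.clearNullValues()           // drop entries whose values have been garbage collected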
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index d48b51aa69..d043200f71 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -341,9 +341,11 @@ abstract class DStream[T: ClassTag] (
*/
private[streaming] def clearMetadata(time: Time) {
val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
+ logDebug("Clearing references to old RDDs: [" +
+ oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]")
generatedRDDs --= oldRDDs.keys
if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) {
- logDebug("Unpersisting old RDDs: " + oldRDDs.keys.mkString(", "))
+ logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", "))
oldRDDs.values.foreach(_.unpersist(false))
}
logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +