aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2013-11-02 22:45:15 -0700
committerReynold Xin <rxin@apache.org>2013-11-02 22:45:15 -0700
commitda6bb0aedd5121d9e3b92031dcc0884a9682da05 (patch)
tree818928e20787661304af22153683ab987c0df491 /core/src/main/scala/org
parent3e7df8f6c6edec9e71c6e416de86aa1a2cefc176 (diff)
parent41ead7a74533ffdd208a4ba2f7cd38945b4343ec (diff)
downloadspark-da6bb0aedd5121d9e3b92031dcc0884a9682da05.tar.gz
spark-da6bb0aedd5121d9e3b92031dcc0884a9682da05.tar.bz2
spark-da6bb0aedd5121d9e3b92031dcc0884a9682da05.zip
Merge branch 'master' into hash1
Diffstat (limited to 'core/src/main/scala/org')
-rw-r--r--core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala17
-rw-r--r--core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala168
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala132
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala35
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java10
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java5
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function.java3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function2.java3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function3.java36
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java6
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala34
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala1058
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala410
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala54
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala247
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala601
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala41
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala (renamed from core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala)17
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala46
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManager.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/package.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala77
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala128
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Stage.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala44
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala)26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala)43
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala66
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockId.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockInfo.scala97
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala556
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala133
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala184
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskStore.scala263
-rw-r--r--core/src/main/scala/org/apache/spark/storage/FileSegment.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala59
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala86
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala105
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala32
77 files changed, 2001 insertions, 3303 deletions
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
index f87460039b..0c47afae54 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
@@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
+private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
- val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
+ "org.apache.hadoop.mapred.JobContext")
+ val ctor = klass.getDeclaredConstructor(classOf[JobConf],
+ classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
+ "org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
index 93180307fa..32429f01ac 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
-import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
+import org.apache.hadoop.conf.Configuration
+private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
+ val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
- // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
+ // First, attempt to use the old-style constructor that takes a boolean isMap
+ // (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
- classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ classOf[Int], classOf[Int])
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
- // failed, look for the new ctor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
- val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
+ // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
+ val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ .asInstanceOf[Class[Enum[_]]]
+ val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
+ taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1e3f1ebfaf..5e465fa22c 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -20,13 +20,11 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
-import akka.remote._
import akka.util.Duration
@@ -40,11 +38,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+ extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
- sender ! tracker.getSerializedLocations(shuffleId)
+ sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -60,22 +59,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
- private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
+ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
- private var epoch: Long = 0
- private val epochLock = new java.lang.Object
+ protected var epoch: Long = 0
+ protected val epochLock = new java.lang.Object
- // Cache a serialized version of the output statuses for each shuffle to send them out faster
- var cacheEpoch = epoch
- private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
-
- val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
+ private val metadataCleaner =
+ new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
- def askTracker(message: Any): Any = {
+ private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
@@ -86,50 +82,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
- def communicate(message: Any) {
+ private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
- def registerShuffle(shuffleId: Int, numMaps: Int) {
- if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
- throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
- }
- }
-
- def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
- var array = mapStatuses(shuffleId)
- array.synchronized {
- array(mapId) = status
- }
- }
-
- def registerMapOutputs(
- shuffleId: Int,
- statuses: Array[MapStatus],
- changeEpoch: Boolean = false) {
- mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
- if (changeEpoch) {
- incrementEpoch()
- }
- }
-
- def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
- var arrayOpt = mapStatuses.get(shuffleId)
- if (arrayOpt.isDefined && arrayOpt.get != null) {
- var array = arrayOpt.get
- array.synchronized {
- if (array(mapId) != null && array(mapId).location == bmAddress) {
- array(mapId) = null
- }
- }
- incrementEpoch()
- } else {
- throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
- }
- }
-
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@@ -168,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
- fetchedStatuses = deserializeStatuses(fetchedBytes)
+ fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@@ -194,9 +152,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
- private def cleanup(cleanupTime: Long) {
+ protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
- cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@@ -206,15 +163,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
- // Called on master to increment the epoch number
- def incrementEpoch() {
- epochLock.synchronized {
- epoch += 1
- logDebug("Increasing epoch to " + epoch)
- }
- }
-
- // Called on master or workers to get current epoch number
+ // Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@@ -228,14 +177,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
- // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
- mapStatuses.clear()
epoch = newEpoch
+ mapStatuses.clear()
+ }
+ }
+ }
+}
+
+private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
+
+ // Cache a serialized version of the output statuses for each shuffle to send them out faster
+ private var cacheEpoch = epoch
+ private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
+
+ def registerShuffle(shuffleId: Int, numMaps: Int) {
+ if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+ throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+ }
+ }
+
+ def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
+ val array = mapStatuses(shuffleId)
+ array.synchronized {
+ array(mapId) = status
+ }
+ }
+
+ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
+ mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
+ if (changeEpoch) {
+ incrementEpoch()
+ }
+ }
+
+ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
+ val arrayOpt = mapStatuses.get(shuffleId)
+ if (arrayOpt.isDefined && arrayOpt.get != null) {
+ val array = arrayOpt.get
+ array.synchronized {
+ if (array(mapId) != null && array(mapId).location == bmAddress) {
+ array(mapId) = null
+ }
}
+ incrementEpoch()
+ } else {
+ throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
- def getSerializedLocations(shuffleId: Int): Array[Byte] = {
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
+ }
+ }
+
+ def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@@ -253,7 +250,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
- val bytes = serializeStatuses(statuses)
+ val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@@ -261,13 +258,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
- return bytes
+ bytes
+ }
+
+ protected override def cleanup(cleanupTime: Long) {
+ super.cleanup(cleanupTime)
+ cachedSerializedStatuses.clearOldValues(cleanupTime)
}
+ override def stop() {
+ super.stop()
+ cachedSerializedStatuses.clear()
+ }
+
+ override def updateEpoch(newEpoch: Long) {
+ // This might be called on the MapOutputTrackerMaster if we're running in local mode.
+ }
+}
+
+private[spark] object MapOutputTracker {
+ private val LOG_BASE = 1.1
+
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
- private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
+ def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@@ -278,18 +293,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
- // Opposite of serializeStatuses.
- def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
+ // Opposite of serializeMapStatuses.
+ def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
- objIn.readObject().
- // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
- // comment this out - nulls could be due to missing location ?
- asInstanceOf[Array[MapStatus]] // .filter( _ != null )
+ objIn.readObject().asInstanceOf[Array[MapStatus]]
}
-}
-
-private[spark] object MapOutputTracker {
- private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f674cb397f..880b49e8ef 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -51,25 +51,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.LocalSparkCluster
+import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
- ClusterScheduler}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+ SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util._
-import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
-import scala.Some
+import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
+import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
+import org.apache.spark.ui.SparkUI
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
+ TimeStampedHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -125,7 +119,7 @@ class SparkContext(
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
- // Initalize the Spark UI
+ // Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
@@ -161,8 +155,10 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """spark://(.*)""".r
- //Regular expression for connection to Mesos cluster
- val MESOS_REGEX = """(mesos://.*)""".r
+ // Regular expression for connection to Mesos cluster
+ val MESOS_REGEX = """mesos://(.*)""".r
+ // Regular expression for connection to Simr cluster
+ val SIMR_REGEX = """simr://(.*)""".r
master match {
case "local" =>
@@ -181,6 +177,12 @@ class SparkContext(
scheduler.initialize(backend)
scheduler
+ case SIMR_REGEX(simrUrl) =>
+ val scheduler = new ClusterScheduler(this)
+ val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
+ scheduler.initialize(backend)
+ scheduler
+
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
@@ -213,25 +215,24 @@ class SparkContext(
throw new SparkException("YARN mode not available ?", th)
}
}
- val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
- case _ =>
- if (MESOS_REGEX.findFirstIn(master).isEmpty) {
- logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
- }
+ case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
} else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
}
scheduler.initialize(backend)
scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
}
}
taskScheduler.start()
@@ -244,7 +245,7 @@ class SparkContext(
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val env = SparkEnv.get
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
@@ -254,8 +255,10 @@ class SparkContext(
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
- for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
- conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
+ Utils.getSystemProperties.foreach { case (key, value) =>
+ if (key.startsWith("spark.hadoop.")) {
+ conf.set(key.substring("spark.hadoop.".length), value)
+ }
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
@@ -288,15 +291,46 @@ class SparkContext(
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
+ @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+ setJobGroup("", value)
+ }
+
+ /**
+ * Assigns a group id to all the jobs started by this thread until the group id is set to a
+ * different value or cleared.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark actions or jobs.
+ * Application programmers can use this method to group all those jobs together and give a
+ * group description. Once set, the Spark web UI will associate such jobs with this group.
+ *
+ * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
+ * running jobs in this group. For example,
+ * {{{
+ * // In the main thread:
+ * sc.setJobGroup("some_job_to_cancel", "some job description")
+ * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ *
+ * // In a separate thread:
+ * sc.cancelJobGroup("some_job_to_cancel")
+ * }}}
+ */
+ def setJobGroup(groupId: String, description: String) {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ }
+
+ /** Clear the job group id and its description. */
+ def clearJobGroup() {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
}
// Post init
taskScheduler.postStartHook()
- val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
- val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+ private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
+ private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -347,7 +381,7 @@ class SparkContext(
minSplits: Int = defaultMinSplits
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
- SparkEnv.get.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -557,7 +591,8 @@ class SparkContext(
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
- case _ => path
+ case "local" => "file:" + uri.getPath
+ case _ => path
}
addedFiles(key) = System.currentTimeMillis
@@ -651,12 +686,11 @@ class SparkContext(
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
- * filesystems), or an HTTP, HTTPS or FTP URI.
+ * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
if (path == null) {
- logWarning("null specified as parameter to addJar",
- new SparkException("null specified as parameter to addJar"))
+ logWarning("null specified as parameter to addJar")
} else {
var key = ""
if (path.contains("\\")) {
@@ -665,8 +699,9 @@ class SparkContext(
} else {
val uri = new URI(path)
key = uri.getScheme match {
+ // A JAR file which exists only on the driver node
case null | "file" =>
- if (env.hadoop.isYarnMode()) {
+ if (SparkHadoopUtil.get.isYarnMode()) {
// In order for this to work on yarn the user must specify the --addjars option to
// the client to upload the file into the distributed cache to make it show up in the
// current working directory.
@@ -682,6 +717,9 @@ class SparkContext(
} else {
env.httpFileServer.addJar(new File(uri.getPath))
}
+ // A JAR file which exists locally on every worker node
+ case "local" =>
+ "file:" + uri.getPath
case _ =>
path
}
@@ -867,13 +905,19 @@ class SparkContext(
callSite,
allowLocal = false,
resultHandler,
- null)
+ localProperties.get)
new SimpleFutureAction(waiter, resultFunc)
}
/**
- * Cancel all jobs that have been scheduled or are running.
+ * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
+ * for more information.
*/
+ def cancelJobGroup(groupId: String) {
+ dagScheduler.cancelJobGroup(groupId)
+ }
+
+ /** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}
@@ -895,9 +939,8 @@ class SparkContext(
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
- val env = SparkEnv.get
val path = new Path(dir)
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -934,7 +977,10 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
- val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 29968c273c..ff2df8fb6a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -25,13 +25,13 @@ import akka.remote.RemoteActorRefProvider
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
import org.apache.spark.network.ConnectionManager
import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.api.python.PythonWorkerFactory
+import com.google.common.collect.MapMaker
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -58,18 +58,9 @@ class SparkEnv (
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
- val hadoop = {
- val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
- if(yarnMode) {
- try {
- Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
- } catch {
- case th: Throwable => throw new SparkException("Unable to load YARN support", th)
- }
- } else {
- new SparkHadoopUtil
- }
- }
+ // A general, soft-reference map for metadata needed during HadoopRDD split computation
+ // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
+ private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
@@ -187,10 +178,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
- val mapOutputTracker = new MapOutputTracker()
+ val mapOutputTracker = if (isDriver) {
+ new MapOutputTrackerMaster()
+ } else {
+ new MapOutputTracker()
+ }
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
- new MapOutputTrackerActor(mapOutputTracker))
+ new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2bab9d6e3d..103a1c2051 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
+import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
-import java.io.IOException
import java.util.Date
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
-class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
+private[apache]
+class SparkHadoopWriter(@transient jobConf: JobConf)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
getOutputCommitter().setupTask(getTaskContext())
- writer = getOutputFormat().getRecordWriter(
- fs, conf.value, outputName, Reporter.NULL)
+ writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
- if (writer!=null) {
- //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
+ if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
}
+private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 51584d686d..cae983ed4c 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
class TaskContext(
- private[spark] val stageId: Int,
+ val stageId: Int,
val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 5fd1fab580..043cb183ba 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -48,6 +48,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
+
// first() has to be overriden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
@@ -81,6 +94,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
+
+ /**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index a6518abf45..2142fd7327 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -65,6 +65,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
+ */
+ def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
+
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -95,6 +108,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
@@ -598,4 +622,15 @@ object JavaPairRDD {
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
+
+
+ /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
+ def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
+ implicit val cmk: ClassManifest[K] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val cmv: ClassManifest[V] =
+ implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ new JavaPairRDD[K, V](rdd.rdd)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index eec58abdd6..3b359a8fd6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -41,9 +41,17 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ * This method blocks until all blocks are deleted.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
+ /**
+ * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
+ *
+ * @param blocking Whether to block until all blocks are deleted.
+ */
+ def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
+
// Transformations (return a new RDD)
/**
@@ -74,6 +82,17 @@ JavaRDDLike[T, JavaRDD[T]] {
rdd.coalesce(numPartitions, shuffle)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
+
+ /**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
index 4830067f7a..3e85052cd0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
-public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
+public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
-
- public abstract Iterable<Double> call(T t);
-
- @Override
- public final Iterable<Double> apply(T t) { return call(t); }
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
index db34cd190a..5e9b8c48b8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java
@@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
-import scala.runtime.AbstractFunction1;
-
import java.io.Serializable;
/**
@@ -29,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
-
- public abstract Double call(T t) throws Exception;
+ // Intentionally left blank
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
index 158539a846..2dfda8b09a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- @throws(classOf[Exception])
- def call(x: T) : java.lang.Iterable[R]
-
def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
index 5ef6a814f5..528e1c0a7c 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- @throws(classOf[Exception])
- def call(a: A, b:B) : java.lang.Iterable[C]
-
def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
index b9070cfd83..ce368ee01b 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@@ -30,8 +29,6 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
- public abstract R call(T t) throws Exception;
-
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
index d4c9154869..44ad559d48 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
@@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction2;
import java.io.Serializable;
@@ -29,8 +28,6 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
- public abstract R call(T1 t1, T2 t2) throws Exception;
-
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
new file mode 100644
index 0000000000..ac6178924a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import scala.reflect.ClassManifest;
+import scala.reflect.ClassManifest$;
+import scala.runtime.AbstractFunction2;
+
+import java.io.Serializable;
+
+/**
+ * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
+ */
+public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
+ implements Serializable {
+
+ public ClassManifest<R> returnType() {
+ return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ }
+}
+
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
index c0e5544b7d..6d76a8f970 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@@ -34,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
- public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
-
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
index 40480fe8e8..ede7ceefb5 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
-import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@@ -29,12 +28,9 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
-public abstract class PairFunction<T, K, V>
- extends WrappedFunction1<T, Tuple2<K, V>>
+public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
- public abstract Tuple2<K, V> call(T t) throws Exception;
-
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
new file mode 100644
index 0000000000..d314dbdf1d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function
+
+import scala.runtime.AbstractFunction3
+
+/**
+ * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
+ * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
+ * isn't marked to allow that).
+ */
+private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
+ extends AbstractFunction3[T1, T2, T3, R] {
+ @throws(classOf[Exception])
+ def call(t1: T1, t2: T2, t3: T3): R
+
+ final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
+}
+
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 1f8ad688a6..12b4d94a56 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
-class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
deleted file mode 100644
index b6c484bfe1..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1058 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.Utils
-
-private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
- extends Broadcast[T](id)
- with Logging
- with Serializable {
-
- def value = value_
-
- def blockId = BroadcastBlockId(id)
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var hasBlocksBitVector: BitSet = null
- @transient var numCopiesSent: Array[Int] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = new AtomicInteger(0)
-
- // Used ONLY by driver to track how many unique blocks have been sent out
- @transient var sentBlocks = new AtomicInteger(0)
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
-
- // Used only in driver
- @transient var guideMR: GuideMultipleRequests = null
-
- // Used only in Workers
- @transient var ttGuide: TalkToGuide = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks.set(variableInfo.totalBlocks)
-
- // Guide has all the blocks
- hasBlocksBitVector = new BitSet(totalBlocks)
- hasBlocksBitVector.set(0, totalBlocks)
-
- // Guide still hasn't sent any block
- numCopiesSent = new Array[Int](totalBlocks)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val driverSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- hasBlocksBitVector.synchronized {
- driverSource.hasBlocksBitVector = hasBlocksBitVector
- }
-
- // In the beginning, this is the only known source to Guide
- listOfSources += driverSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- // Start local ServeMultipleRequests thread first
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- // Initialize variables in the worker node. Driver sends everything as 0/null
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- hasBlocksBitVector = null
- numCopiesSent = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = new AtomicInteger(0)
-
- listenPortLock = new Object
- totalBlocksLock = new Object
-
- serveMR = null
- ttGuide = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- listOfSources = ListBuffer[SourceInfo]()
-
- stopBroadcast = false
- }
-
- private def getLocalSourceInfo: SourceInfo = {
- // Wait till hostName and listenPort are OK
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Wait till totalBlocks and totalBytes are OK
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- var localSourceInfo = SourceInfo(
- hostAddress, listenPort, totalBlocks, totalBytes)
-
- localSourceInfo.hasBlocks = hasBlocks.get
-
- hasBlocksBitVector.synchronized {
- localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
- }
-
- return localSourceInfo
- }
-
- // Add new SourceInfo to the listOfSources. Update if it exists already.
- // Optimizing just by OR-ing the BitVectors was BAD for performance
- private def addToListOfSources(newSourceInfo: SourceInfo) {
- listOfSources.synchronized {
- if (listOfSources.contains(newSourceInfo)) {
- listOfSources = listOfSources - newSourceInfo
- }
- listOfSources += newSourceInfo
- }
- }
-
- private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) {
- newSourceInfos.foreach { newSourceInfo =>
- addToListOfSources(newSourceInfo)
- }
- }
-
- class TalkToGuide(gInfo: SourceInfo)
- extends Thread with Logging {
- override def run() {
-
- // Keep exchaning information until all blocks have been received
- while (hasBlocks.get < totalBlocks) {
- talkOnce
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
- }
-
- // Talk one more time to let the Guide know of reception completion
- talkOnce
- }
-
- // Connect to Guide and send this worker's information
- private def talkOnce {
- var clientSocketToGuide: Socket = null
- var oosGuide: ObjectOutputStream = null
- var oisGuide: ObjectInputStream = null
-
- clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
- oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream)
- oosGuide.flush()
- oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream)
-
- // Send local information
- oosGuide.writeObject(getLocalSourceInfo)
- oosGuide.flush()
-
- // Receive source information from Guide
- var suitableSources =
- oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logDebug("Received suitableSources from Driver " + suitableSources)
-
- addToListOfSources(suitableSources)
-
- oisGuide.close()
- oosGuide.close()
- clientSocketToGuide.close()
- }
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Setup initial states of variables
- totalBlocks = gInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- hasBlocksBitVector = new BitSet(totalBlocks)
- numCopiesSent = new Array[Int](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = gInfo.totalBytes
-
- // Start ttGuide to periodically talk to the Guide
- var ttGuide = new TalkToGuide(gInfo)
- ttGuide.setDaemon(true)
- ttGuide.start()
- logInfo("TalkToGuide started...")
-
- // Start pController to run TalkToPeer threads
- var pcController = new PeerChatterController
- pcController.setDaemon(true)
- pcController.start()
- logInfo("PeerChatterController started...")
-
- // FIXME: Must fix this. This might never break if broadcast fails.
- // We should be able to break and send false. Also need to kill threads
- while (hasBlocks.get < totalBlocks) {
- Thread.sleep(MultiTracker.MaxKnockInterval)
- }
-
- return true
- }
-
- class PeerChatterController
- extends Thread with Logging {
- private var peersNowTalking = ListBuffer[SourceInfo]()
- // TODO: There is a possible bug with blocksInRequestBitVector when a
- // certain bit is NOT unset upon failure resulting in an infinite loop.
- private var blocksInRequestBitVector = new BitSet(totalBlocks)
-
- override def run() {
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- while (hasBlocks.get < totalBlocks) {
- var numThreadsToCreate = 0
- listOfSources.synchronized {
- numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
- threadPool.getActiveCount
- }
-
- while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
- var peerToTalkTo = pickPeerToTalkToRandom
-
- if (peerToTalkTo != null)
- logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
- else
- logDebug("No peer chosen...")
-
- if (peerToTalkTo != null) {
- threadPool.execute(new TalkToPeer(peerToTalkTo))
-
- // Add to peersNowTalking. Remove in the thread. We have to do this
- // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
- peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before starting some more threads
- Thread.sleep(MultiTracker.MinKnockInterval)
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- // Right now picking the one that has the most blocks this peer wants
- // Also picking peer randomly if no one has anything interesting
- private def pickPeerToTalkToRandom: SourceInfo = {
- var curPeer: SourceInfo = null
- var curMax = 0
-
- logDebug("Picking peers to talk to...")
-
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Select the peer that has the most blocks that this receiver does not
- peersNotInUse.foreach { eachSource =>
- var tempHasBlocksBitVector: BitSet = null
- hasBlocksBitVector.synchronized {
- tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
- tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size)
- tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector)
-
- if (tempHasBlocksBitVector.cardinality > curMax) {
- curPeer = eachSource
- curMax = tempHasBlocksBitVector.cardinality
- }
- }
-
- // Always picking randomly
- if (curPeer == null && peersNotInUse.size > 0) {
- // Pick uniformly the i'th required peer
- var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
-
- var peerIter = peersNotInUse.iterator
- curPeer = peerIter.next
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
- }
-
- return curPeer
- }
-
- // Picking peer with the weight of rare blocks it has
- private def pickPeerToTalkToRarestFirst: SourceInfo = {
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Count the number of copies of each block in the neighborhood
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // A block is considered rare if there are at most 2 copies of that block
- // This CONSTANT could be a function of the neighborhood size
- var rareBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
- rareBlocksIndices += i
- }
- }
-
- // Find peers with rare blocks
- var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]()
- var totalRareBlocks = 0
-
- peersNotInUse.foreach { eachPeer =>
- var hasRareBlocks = 0
- rareBlocksIndices.foreach { rareBlock =>
- if (eachPeer.hasBlocksBitVector.get(rareBlock)) {
- hasRareBlocks += 1
- }
- }
-
- if (hasRareBlocks > 0) {
- peersWithRareBlocks += ((eachPeer, hasRareBlocks))
- }
- totalRareBlocks += hasRareBlocks
- }
-
- // Select a peer from peersWithRareBlocks based on weight calculated from
- // unique rare blocks
- var selectedPeerToTalkTo: SourceInfo = null
-
- if (peersWithRareBlocks.size > 0) {
- // Sort the peers based on how many rare blocks they have
- peersWithRareBlocks.sortBy(_._2)
-
- var randomNumber = MultiTracker.ranGen.nextDouble
- var tempSum = 0.0
-
- var i = 0
- do {
- tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks)
- if (tempSum >= randomNumber) {
- selectedPeerToTalkTo = peersWithRareBlocks(i)._1
- }
- i += 1
- } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null)
- }
-
- if (selectedPeerToTalkTo == null) {
- selectedPeerToTalkTo = pickPeerToTalkToRandom
- }
-
- return selectedPeerToTalkTo
- }
-
- class TalkToPeer(peerToTalkTo: SourceInfo)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- override def run() {
- // TODO: There is a possible bug here regarding blocksInRequestBitVector
- var blockToAskFor = -1
-
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run() {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
-
- logInfo("TalkToPeer started... => " + peerToTalkTo)
-
- try {
- // Connect to the source
- peerSocketToSource =
- new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource =
- new ObjectInputStream(peerSocketToSource.getInputStream)
-
- // Receive latest SourceInfo from peerToTalkTo
- var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
- // Update listOfSources
- addToListOfSources(newPeerToTalkTo)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
-
- var keepReceiving = true
-
- while (hasBlocks.get < totalBlocks && keepReceiving) {
- blockToAskFor =
- pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector)
-
- // No block to request
- if (blockToAskFor < 0) {
- // Nothing to receive from newPeerToTalkTo
- keepReceiving = false
- } else {
- // Let other threads know that blockToAskFor is being requested
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor)
- }
-
- // Start with sending the blockID
- oosSource.writeObject(blockToAskFor)
- oosSource.flush()
-
- // CHANGED: Driver might send some other block than the one
- // requested to ensure fast spreading of all blocks.
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
-
- if (!hasBlocksBitVector.get(bcBlock.blockID)) {
- arrayOfBlocks(bcBlock.blockID) = bcBlock
-
- // Update the hasBlocksBitVector first
- hasBlocksBitVector.synchronized {
- hasBlocksBitVector.set(bcBlock.blockID)
- hasBlocks.getAndIncrement
- }
-
- // Some block(may NOT be blockToAskFor) has arrived.
- // In any case, blockToAskFor is not in request any more
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
-
- // Reset blockToAskFor to -1. Else it will be considered missing
- blockToAskFor = -1
- }
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logError("TalktoPeer had a " + e)
- // FIXME: Remove 'newPeerToTalkTo' from listOfSources
- // We probably should have the following in some form, but not
- // really here. This exception can happen if the sender just breaks connection
- // listOfSources.synchronized {
- // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
- // listOfSources = listOfSources - peerToTalkTo
- // }
- }
- } finally {
- // blockToAskFor != -1 => there was an exception
- if (blockToAskFor != -1) {
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
- }
-
- cleanUpConnections()
- }
- }
-
- // Right now it picks a block uniformly that this peer does not have
- private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Pick uniformly the i'th required block
- var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
- var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
-
- while (i > 0) {
- pickedBlockIndex =
- needBlocksBitVector.nextSetBit(pickedBlockIndex + 1)
- i -= 1
- }
-
- return pickedBlockIndex
- }
- }
-
- // Pick the block that seems to be the rarest across sources
- private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Count the number of copies for each block across all sources
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // Find the minimum
- var minVal = Integer.MAX_VALUE
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) {
- minVal = numCopiesPerBlock(i)
- }
- }
-
- // Find the blocks with the least copies that this peer does not have
- var minBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) {
- minBlocksIndices += i
- }
- }
-
- // Now select a random index from minBlocksIndices
- if (minBlocksIndices.size == 0) {
- return -1
- } else {
- // Pick uniformly the i'th index
- var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
- return minBlocksIndices(i)
- }
- }
- }
-
- private def cleanUpConnections() {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
-
- // Delete from peersNowTalking
- peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
- }
- }
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close the socket here; else, thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- listOfSources.foreach { sourceInfo =>
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Throw away whatever comes in
- gisSource.readObject.asInstanceOf[SourceInfo]
-
- // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
- gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sourceInfo: SourceInfo = null
- private var selectedSources: ListBuffer[SourceInfo] = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its information
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Select a suitable source and send it back to the worker
- selectedSources = selectSuitableSources(sourceInfo)
- logDebug("Sending selectedSources:" + selectedSources)
- oos.writeObject(selectedSources)
- oos.flush()
-
- // Add this source to the listOfSources
- addToListOfSources(sourceInfo)
- } catch {
- case e: Exception => {
- // Assuming exception caused by receiver failure: remove
- if (listOfSources != null) {
- listOfSources.synchronized { listOfSources -= sourceInfo }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Randomly select some sources to send back
- private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
- var selectedSources = ListBuffer[SourceInfo]()
-
- // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
- // then add skipSourceInfo to setOfCompletedSources. Return blank.
- if (skipSourceInfo.hasBlocks == totalBlocks) {
- setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
- return selectedSources
- }
-
- listOfSources.synchronized {
- if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
- selectedSources = listOfSources.clone
- } else {
- var picksLeft = MultiTracker.MaxPeersInGuideResponse
- var alreadyPicked = new BitSet(listOfSources.size)
-
- while (picksLeft > 0) {
- var i = -1
-
- do {
- i = MultiTracker.ranGen.nextInt(listOfSources.size)
- } while (alreadyPicked.get(i))
-
- var peerIter = listOfSources.iterator
- var curPeer = peerIter.next
-
- // Set the BitSet before i is decremented
- alreadyPicked.set(i)
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
-
- selectedSources += curPeer
-
- picksLeft = picksLeft - 1
- }
- }
- }
-
- // Remove the receiving source (if present)
- selectedSources = selectedSources - skipSourceInfo
-
- return selectedSources
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
- // Server at most MultiTracker.MaxChatSlots peers
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ServeSingleRequest is running")
-
- override def run() {
- try {
- // Send latest local SourceInfo to the receiver
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- oos.writeObject(getLocalSourceInfo)
- oos.flush()
-
- // Receive latest SourceInfo from the receiver
- var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- addToListOfSources(rxSourceInfo)
- }
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = MultiTracker.MaxChatBlocks
-
- while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
- // Receive which block to send
- var blockToSend = ois.readObject.asInstanceOf[Int]
-
- // If it is driver AND at least one copy of each block has not been
- // sent out already, MODIFY blockToSend
- if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
- blockToSend = sentBlocks.getAndIncrement
- }
-
- // Send the block
- sendBlock(blockToSend)
- rxSourceInfo.hasBlocksBitVector.set(blockToSend)
-
- numBlocksToSend -= 1
-
- // Receive latest SourceInfo from the receiver
- rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
- addToListOfSources(rxSourceInfo)
-
- curTime = System.currentTimeMillis
- // Revoke sending only if there is anyone waiting in the queue
- if (curTime - startTime >= MultiTracker.MaxChatTime &&
- threadPool.getQueue.size > 0) {
- keepSending = false
- }
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendBlock(blockToSend: Int) {
- try {
- oos.writeObject(arrayOfBlocks(blockToSend))
- oos.flush()
- } catch {
- case e: Exception => logError("sendBlock had a " + e)
- }
- logDebug("Sent block: " + blockToSend + " to " + clientSocket)
- }
- }
- }
-}
-
-private[spark] class BitTorrentBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new BitTorrentBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
deleted file mode 100644
index 21ec94659e..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
+++ /dev/null
@@ -1,410 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.Random
-
-import scala.collection.mutable.Map
-
-import org.apache.spark._
-import org.apache.spark.util.Utils
-
-private object MultiTracker
-extends Logging {
-
- // Tracker Messages
- val REGISTER_BROADCAST_TRACKER = 0
- val UNREGISTER_BROADCAST_TRACKER = 1
- val FIND_BROADCAST_TRACKER = 2
-
- // Map to keep track of guides of ongoing broadcasts
- var valueToGuideMap = Map[Long, SourceInfo]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var _isDriver = false
-
- private var stopBroadcast = false
-
- private var trackMV: TrackMultipleValues = null
-
- def initialize(__isDriver: Boolean) {
- synchronized {
- if (!initialized) {
- _isDriver = __isDriver
-
- if (isDriver) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
-
- // Set DriverHostAddress to the driver's IP address for the slaves to read
- System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
- }
-
- initialized = true
- }
- }
- }
-
- def stop() {
- stopBroadcast = true
- }
-
- // Load common parameters
- private var DriverHostAddress_ = System.getProperty(
- "spark.MultiTracker.DriverHostAddress", "")
- private var DriverTrackerPort_ = System.getProperty(
- "spark.broadcast.driverTrackerPort", "11111").toInt
- private var BlockSize_ = System.getProperty(
- "spark.broadcast.blockSize", "4096").toInt * 1024
- private var MaxRetryCount_ = System.getProperty(
- "spark.broadcast.maxRetryCount", "2").toInt
-
- private var TrackerSocketTimeout_ = System.getProperty(
- "spark.broadcast.trackerSocketTimeout", "50000").toInt
- private var ServerSocketTimeout_ = System.getProperty(
- "spark.broadcast.serverSocketTimeout", "10000").toInt
-
- private var MinKnockInterval_ = System.getProperty(
- "spark.broadcast.minKnockInterval", "500").toInt
- private var MaxKnockInterval_ = System.getProperty(
- "spark.broadcast.maxKnockInterval", "999").toInt
-
- // Load TreeBroadcast config params
- private var MaxDegree_ = System.getProperty(
- "spark.broadcast.maxDegree", "2").toInt
-
- // Load BitTorrentBroadcast config params
- private var MaxPeersInGuideResponse_ = System.getProperty(
- "spark.broadcast.maxPeersInGuideResponse", "4").toInt
-
- private var MaxChatSlots_ = System.getProperty(
- "spark.broadcast.maxChatSlots", "4").toInt
- private var MaxChatTime_ = System.getProperty(
- "spark.broadcast.maxChatTime", "500").toInt
- private var MaxChatBlocks_ = System.getProperty(
- "spark.broadcast.maxChatBlocks", "1024").toInt
-
- private var EndGameFraction_ = System.getProperty(
- "spark.broadcast.endGameFraction", "0.95").toDouble
-
- def isDriver = _isDriver
-
- // Common config params
- def DriverHostAddress = DriverHostAddress_
- def DriverTrackerPort = DriverTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- // TreeBroadcast configs
- def MaxDegree = MaxDegree_
-
- // BitTorrentBroadcast configs
- def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
-
- def MaxChatSlots = MaxChatSlots_
- def MaxChatTime = MaxChatTime_
- def MaxChatBlocks = MaxChatBlocks_
-
- def EndGameFraction = EndGameFraction_
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(DriverTrackerPort)
- logInfo("TrackMultipleValues started at " + serverSocket)
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(TrackerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- if (stopBroadcast) {
- logInfo("Stopping TrackMultipleValues...")
- }
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // First, read message type
- val messageType = ois.readObject.asInstanceOf[Int]
-
- if (messageType == REGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
- // Receive hostAddress and listenPort
- val gInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Add to the map
- valueToGuideMap.synchronized {
- valueToGuideMap += (id -> gInfo)
- }
-
- logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- // Remove from the map
- valueToGuideMap.synchronized {
- valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
- }
-
- logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == FIND_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- var gInfo =
- if (valueToGuideMap.contains(id)) valueToGuideMap(id)
- else SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
-
- // Send reply back
- oos.writeObject(gInfo)
- oos.flush()
- } else {
- throw new SparkException("Undefined messageType at TrackMultipleValues")
- }
- } catch {
- case e: Exception => {
- logError("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
-
- def getGuideInfo(variableLong: Long): SourceInfo = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out GuideInfo
- clientSocketToTracker =
- new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send messageType/intention
- oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
- oosTracker.flush()
-
- // Send Long and receive GuideInfo
- oosTracker.writeObject(variableLong)
- oosTracker.flush()
- gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
- } catch {
- case e: Exception => logError("getGuideInfo had a " + e)
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
-
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
-
- logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
- return gInfo
- }
-
- def registerBroadcast(id: Long, gInfo: SourceInfo) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(REGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Send this tracker's information
- oosST.writeObject(gInfo)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- def unregisterBroadcast(id: Long) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- // Helper method to convert an object to Array[BroadcastBlock]
- def blockifyObject[IN](obj: IN): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- oos.writeObject(obj)
- oos.close()
- baos.close()
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream(byteArray)
-
- var blockNum = (byteArray.length / BlockSize)
- if (byteArray.length % BlockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock](blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, BlockSize)) {
- val thisBlockSize = math.min(BlockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
-
- retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
- blockID += 1
- }
- bais.close()
-
- var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- // Helper method to convert Array[BroadcastBlock] to object
- def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
- totalBytes: Int,
- totalBlocks: Int): OUT = {
-
- var retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject(retByteArray)
- }
-
- private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
- }
- val retVal = in.readObject.asInstanceOf[OUT]
- in.close()
- return retVal
- }
-}
-
-private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
-extends Serializable
-
-private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
- totalBlocks: Int,
- totalBytes: Int)
-extends Serializable {
- @transient var hasBlocks = 0
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
deleted file mode 100644
index baa1fd6da4..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.util.BitSet
-
-import org.apache.spark._
-
-/**
- * Used to keep and pass around information of peers involved in a broadcast
- */
-private[spark] case class SourceInfo (hostAddress: String,
- listenPort: Int,
- totalBlocks: Int = SourceInfo.UnusedParam,
- totalBytes: Int = SourceInfo.UnusedParam)
-extends Comparable[SourceInfo] with Logging {
-
- var currentLeechers = 0
- var receptionFailed = false
-
- var hasBlocks = 0
- var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
-
- // Ascending sort based on leecher count
- def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
-}
-
-/**
- * Helper Object of SourceInfo for its constants
- */
-private[spark] object SourceInfo {
- // Broadcast has not started yet! Should never happen.
- val TxNotStartedRetry = -1
- // Broadcast has already finished. Try default mechanism.
- val TxOverGoToDefault = -3
- // Other constants
- val StopBroadcast = -2
- val UnusedParam = 0
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000000..073a0a5029
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def broadcastId = BroadcastBlockId(id)
+
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[TorrentBlock] = null
+ @transient var totalBlocks = -1
+ @transient var totalBytes = -1
+ @transient var hasBlocks = 0
+
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ hasBlocks = tInfo.totalBlocks
+
+ // Store meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ }
+
+ // Store individual pieces
+ for (i <- 0 until totalBlocks) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ }
+ }
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ val start = System.nanoTime
+ logInfo("Started reading broadcast variable " + id)
+
+ // Initialize @transient variables that will receive garbage values from the master.
+ resetWorkerVariables()
+
+ if (receiveBroadcast(id)) {
+ value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+
+ // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ // This creates a tradeoff between memory usage and latency.
+ // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+ // Remove arrayOfBlocks from memory once value_ is on local cache
+ resetWorkerVariables()
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ private def resetWorkerVariables() {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ // Receive meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ var attemptId = 10
+ while (attemptId > 0 && totalBlocks == -1) {
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(metaId) match {
+ case Some(x) =>
+ val tInfo = x.asInstanceOf[TorrentInfo]
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+ hasBlocks = 0
+
+ case None =>
+ Thread.sleep(500)
+ }
+ }
+ attemptId -= 1
+ }
+ if (totalBlocks == -1) {
+ return false
+ }
+
+ // Receive actual blocks
+ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+ for (pid <- recvOrder) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(pieceId) match {
+ case Some(x) =>
+ arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+ hasBlocks += 1
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ }
+
+ (hasBlocks == totalBlocks)
+ }
+
+}
+
+private object TorrentBroadcast
+extends Logging {
+
+ private var initialized = false
+
+ def initialize(_isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ initialized = false
+ }
+
+ val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
+
+ def blockifyObject[T](obj: T): TorrentInfo = {
+ val byteArray = Utils.serialize[T](obj)
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BLOCK_SIZE)
+ if (byteArray.length % BLOCK_SIZE != 0)
+ blockNum += 1
+
+ var retVal = new Array[TorrentBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
+ val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ tInfo.hasBlocks = blockNum
+
+ return tInfo
+ }
+
+ def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
+ }
+ Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+ }
+
+}
+
+private[spark] case class TorrentBlock(
+ blockID: Int,
+ byteArray: Array[Byte])
+ extends Serializable
+
+private[spark] case class TorrentInfo(
+ @transient arrayOfBlocks : Array[TorrentBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+ extends Serializable {
+
+ @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+ extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
deleted file mode 100644
index e6674d49a7..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
+++ /dev/null
@@ -1,601 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-
-import scala.collection.mutable.{ListBuffer, Set}
-
-import org.apache.spark._
-import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.Utils
-
-private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
-
- def value = value_
-
- def blockId = BroadcastBlockId(id)
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- listOfSources += masterSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- var clientSocketToDriver: Socket = null
- var oosDriver: ObjectOutputStream = null
- var oisDriver: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- // Connect to Driver and send this worker's Information
- clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
- oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
- oosDriver.flush()
- oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
-
- logDebug("Connected to Driver's guiding object")
-
- // Send local source information
- oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
- oosDriver.flush()
-
- // Receive source information from Driver
- var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = sourceInfo.totalBytes
-
- logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission(sourceInfo)
- val time = (System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Driver will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Driver
- oosDriver.writeObject(sourceInfo)
-
- if (oisDriver != null) {
- oisDriver.close()
- }
- if (oosDriver != null) {
- oosDriver.close()
- }
- if (clientSocketToDriver != null) {
- clientSocketToDriver.close()
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return (hasBlocks == totalBlocks)
- }
-
- /**
- * Tries to receive broadcast from the source and returns Boolean status.
- * This might be called multiple times to retry a defined number of times.
- */
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
-
- logDebug("Inside receiveSingleTransmission")
- logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush()
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
-
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
- }
- } catch {
- case e: Exception => logError("receiveSingleTransmission had a " + e)
- } finally {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close()
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close() the socket here; else, the thread will close() it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- var listIter = listOfSources.iterator
- while (listIter.hasNext) {
- var sourceInfo = listIter.next
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal
- gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
- private var thisWorkerInfo:SourceInfo = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its hostAddress and listenPort it will
- // be listening to. Other fields are invalid (SourceInfo.UnusedParam)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource(sourceInfo)
- logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject(selectedSourceInfo)
- oos.flush()
-
- // Add this new (if it can finish) source to the list of sources
- thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes)
- logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
- listOfSources += thisWorkerInfo
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in listOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert(listOfSources.contains(selectedSourceInfo))
-
- // Remove first
- // (Currently removing a source based on just one failure notification!)
- listOfSources = listOfSources - selectedSourceInfo
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources.synchronized {
- setOfCompletedSources += thisWorkerInfo
- }
-
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
- }
- } catch {
- case e: Exception => {
- // Remove failed worker from listOfSources and update leecherCount of
- // corresponding source worker
- listOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- listOfSources = listOfSources - selectedSourceInfo
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
-
- // Remove thisWorkerInfo
- if (listOfSources != null) {
- listOfSources = listOfSources - thisWorkerInfo
- }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Assuming the caller to have a synchronized block on listOfSources
- // Select one with the most leechers. This will level-wise fill the tree
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- var maxLeechers = -1
- var selectedSource: SourceInfo = null
-
- listOfSources.foreach { source =>
- if ((source.hostAddress != skipSourceInfo.hostAddress ||
- source.listenPort != skipSourceInfo.listenPort) &&
- source.currentLeechers < MultiTracker.MaxDegree &&
- source.currentLeechers > maxLeechers) {
- selectedSource = source
- maxLeechers = source.currentLeechers
- }
- }
-
- // Update leecher count
- selectedSource.currentLeechers += 1
- return selectedSource
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
-
- var threadPool = Utils.newDaemonCachedThreadPool()
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => { }
- }
-
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run() {
- try {
- logInfo("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- // If not a valid range, stop broadcast
- if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- sendObject
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendObject() {
- // Wait till receiving the SourceInfo from Driver
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized { hasBlocksLock.wait() }
- }
- try {
- oos.writeObject(arrayOfBlocks(i))
- oos.flush()
- } catch {
- case e: Exception => logError("sendObject had a " + e)
- }
- logDebug("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
-}
-
-private[spark] class TreeBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new TreeBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 993ba6bd3d..6bc846aa92 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,28 +17,47 @@
package org.apache.spark.deploy
-import com.google.common.collect.MapMaker
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import org.apache.spark.SparkException
/**
- * Contains util methods to interact with Hadoop from spark.
+ * Contains util methods to interact with Hadoop from Spark.
*/
+private[spark]
class SparkHadoopUtil {
- // A general, soft-reference map for metadata needed during HadoopRDD split computation
- // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
- private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
- // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
- // subsystems
+ /**
+ * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
+ * subsystems.
+ */
def newConfiguration(): Configuration = new Configuration()
- // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
- // cluster
+ /**
+ * Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ * cluster.
+ */
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
-
+}
+
+object SparkHadoopUtil {
+ private val hadoop = {
+ val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
+ if (yarnMode) {
+ try {
+ Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
+ } catch {
+ case th: Throwable => throw new SparkException("Unable to load YARN support", th)
+ }
+ } else {
+ new SparkHadoopUtil
+ }
+ }
+
+ def get: SparkHadoopUtil = {
+ hadoop
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index db0bea0472..80ff4c59cb 100644
--- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -24,11 +24,11 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClie
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
-private[spark] class StandaloneExecutorBackend(
+private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
hostPort: String,
@@ -80,6 +80,11 @@ private[spark] class StandaloneExecutorBackend(
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
+
+ case StopExecutor =>
+ logInfo("Driver commanded a shutdown")
+ context.stop(self)
+ context.system.shutdown()
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
@@ -87,7 +92,7 @@ private[spark] class StandaloneExecutorBackend(
}
}
-private[spark] object StandaloneExecutorBackend {
+private[spark] object CoarseGrainedExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Debug code
Utils.checkHost(hostname)
@@ -99,7 +104,7 @@ private[spark] object StandaloneExecutorBackend {
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
val actor = actorSystem.actorOf(
- Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
+ Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
@@ -107,7 +112,9 @@ private[spark] object StandaloneExecutorBackend {
def main(args: Array[String]) {
if (args.length < 4) {
//the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
- System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
+ System.err.println(
+ "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
+ "[<appid>]")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 20323ea038..b773346df3 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -74,30 +74,33 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
- // Make any thread terminations due to uncaught exceptions kill the entire
- // executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(
- new Thread.UncaughtExceptionHandler {
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ if (!isLocal) {
+ // Setup an uncaught exception handler for non-local mode.
+ // Make any thread terminations due to uncaught exceptions kill the entire
+ // executor process to avoid surprising stalls.
+ Thread.setDefaultUncaughtExceptionHandler(
+ new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(ExecutorExitCode.OOM)
+ } else {
+ System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ }
}
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
- }
- )
+ )
+ }
val executorSource = new ExecutorSource(this, executorId)
@@ -121,8 +124,7 @@ private[spark] class Executor(
}
// Start worker thread pool
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable], Utils.daemonThreadFactory)
+ val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index f311141148..0b4892f98f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable {
* Number of bytes written for a shuffle
*/
var shuffleBytesWritten: Long = _
+
+ /**
+ * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ */
+ var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index e15a839c4e..9c2fee4023 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
- implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+ implicit val futureExecContext = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 1586dff254..546d921067 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
+ override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
@@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
- return file.getAbsolutePath
+ return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index f132e2b735..70a5a8caff 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+package org.apache
+
/**
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index ccaaecb85b..d3033ea4a6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.rdd
import org.apache.spark._
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, BytesWritable}
@@ -83,7 +84,7 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val env = SparkEnv.get
val outputDir = new Path(path)
- val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
+ val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
@@ -122,7 +123,7 @@ private[spark] object CheckpointRDD extends Logging {
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val env = SparkEnv.get
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
@@ -145,7 +146,7 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
- val fs = path.getFileSystem(env.hadoop.newConfiguration())
+ val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index fad042c7ae..32901a508f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -29,6 +29,7 @@ import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
@@ -198,10 +199,10 @@ private[spark] object HadoopRDD {
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
*/
- def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key)
+ def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key)
- def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key)
+ def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key)
def putCachedMetadata(key: String, value: Any) =
- SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value)
+ SparkEnv.get.hadoopJobMetadata.put(key, value)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0355618e43..6e88be6f6a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -266,6 +266,19 @@ abstract class RDD[T: ClassManifest](
def distinct(): RDD[T] = distinct(partitions.size)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Can increase or decrease the level of parallelism in this RDD. Internally, this uses
+ * a shuffle to redistribute data.
+ *
+ * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
+ * which can avoid performing a shuffle.
+ */
+ def repartition(numPartitions: Int): RDD[T] = {
+ coalesce(numPartitions, true)
+ }
+
+ /**
* Return a new RDD that is reduced into `numPartitions` partitions.
*
* This results in a narrow dependency, e.g. if you go from 1000 partitions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7fb614402b..4cef0825dd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -52,23 +52,29 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
- mapOutputTracker: MapOutputTracker,
+ mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
- extends TaskSchedulerListener with Logging {
+ extends Logging {
def this(taskSched: TaskScheduler) {
- this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
+ this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+ SparkEnv.get.blockManager.master, SparkEnv.get)
}
- taskSched.setListener(this)
+ taskSched.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(BeginEvent(task, taskInfo))
}
+ // Called to report that a task has completed and results are being fetched remotely.
+ def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
+ eventQueue.put(GettingResultEvent(task, taskInfo))
+ }
+
// Called by TaskScheduler to report task completions or failures.
- override def taskEnded(
+ def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -79,18 +85,18 @@ class DAGScheduler(
}
// Called by TaskScheduler when an executor fails.
- override def executorLost(execId: String) {
+ def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
- override def executorGained(execId: String, host: String) {
+ def executorGained(execId: String, host: String) {
eventQueue.put(ExecutorGained(execId, host))
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
- override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -182,7 +188,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+ val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@@ -195,6 +201,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
+ numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@@ -207,9 +214,10 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+ val stage =
+ new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
- stageToInfos(stage) = StageInfo(stage)
+ stageToInfos(stage) = new StageInfo(stage)
stage
}
@@ -277,11 +285,6 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null): JobWaiter[U] =
{
- val jobId = nextJobId.getAndIncrement()
- if (partitions.size == 0) {
- return new JobWaiter[U](this, jobId, 0, resultHandler)
- }
-
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions).foreach { p =>
@@ -290,6 +293,11 @@ class DAGScheduler(
"Total number of partitions: " + maxPartitions)
}
+ val jobId = nextJobId.getAndIncrement()
+ if (partitions.size == 0) {
+ return new JobWaiter[U](this, jobId, 0, resultHandler)
+ }
+
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
@@ -342,6 +350,11 @@ class DAGScheduler(
eventQueue.put(JobCancelled(jobId))
}
+ def cancelJobGroup(groupId: String) {
+ logInfo("Asked to cancel job group " + groupId)
+ eventQueue.put(JobGroupCancelled(groupId))
+ }
+
/**
* Cancel all jobs that are running or waiting in the queue.
*/
@@ -356,7 +369,7 @@ class DAGScheduler(
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- val finalStage = newStage(rdd, None, jobId, Some(callSite))
+ val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -381,6 +394,17 @@ class DAGScheduler(
taskSched.cancelTasks(stage.id)
}
+ case JobGroupCancelled(groupId) =>
+ // Cancel all jobs belonging to this job group.
+ // First finds all active jobs with this group id, and then kill stages for them.
+ val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ .map(_.jobId)
+ if (!jobIds.isEmpty) {
+ running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+ }
+
case AllJobsCancelled =>
// Cancel all running jobs.
running.foreach { stage =>
@@ -396,6 +420,9 @@ class DAGScheduler(
case begin: BeginEvent =>
listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
+ case gettingResult: GettingResultEvent =>
+ listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
+
case completion: CompletionEvent =>
listenerBus.post(SparkListenerTaskEnd(
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
@@ -568,7 +595,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
- listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
+ listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@@ -589,9 +616,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
- if (!stage.submissionTime.isDefined) {
- stage.submissionTime = Some(System.currentTimeMillis())
- }
+ stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -613,12 +638,12 @@ class DAGScheduler(
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
- val serviceTime = stage.submissionTime match {
+ val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
- case _ => "Unkown"
+ case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.completionTime = Some(System.currentTimeMillis)
+ stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@@ -788,7 +813,7 @@ class DAGScheduler(
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
- failedStage.completionTime = Some(System.currentTimeMillis())
+ stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job aborted: " + reason)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index ee89bfb38d..708d221d60 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -46,11 +46,16 @@ private[scheduler] case class JobSubmitted(
private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler]
+case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
index 370ccd183c..1791ee660d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.deploy.SparkHadoopUtil
import scala.collection.immutable.Set
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.security.UserGroupInformation
@@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val conf = new JobConf(configuration)
- env.hadoop.addCredentials(conf)
+ SparkHadoopUtil.get.addCredentials(conf)
FileInputFormat.setInputPaths(conf, path)
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
@@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
- val env = SparkEnv.get
val jobConf = new JobConf(configuration)
- env.hadoop.addCredentials(jobConf)
+ SparkHadoopUtil.get.addCredentials(jobConf)
FileInputFormat.setInputPaths(jobConf, path)
val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 3628b1b078..12b0d74fb5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -24,56 +24,54 @@ import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
+import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
+ *
+ * @param logDirName The base directory for the log files.
+ */
class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
+
+ private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
+
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
+
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
// Create a folder for log files, the folder's name is the creation time of the jobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
+ if (!dir.exists() && !dir.mkdirs()) {
+ logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job, the file name is the jobID
protected def createLogWriter(jobID: Int) {
- try{
+ try {
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
}
// Close log file, and clean the stage relationship in stageIDToJobID
@@ -118,10 +116,9 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
}
rddList
}
@@ -161,29 +158,27 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
+ } else {
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
}
// Record task metrics into job log files
@@ -193,39 +188,32 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
+ stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
- "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 802791797a..24d97da6eb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -164,17 +164,19 @@ private[spark] class ShuffleMapTask(
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
+ var totalTime = 0L
val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
writer.commit()
- writer.close()
- val size = writer.size()
+ val size = writer.fileSegment().length
totalBytes += size
+ totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
new MapStatus(blockManager.blockManagerId, compressedSizes)
@@ -188,6 +190,7 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && buckets != null) {
+ buckets.writers.foreach(_.close())
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 466baf9913..a35081f7b1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -24,13 +24,16 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
-case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
-case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+case class SparkListenerTaskGettingResult(
+ task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
+
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
@@ -57,6 +60,12 @@ trait SparkListener {
def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
+ * Called when a task begins remotely fetching its result (will not be called for tasks that do
+ * not need to fetch the result remotely).
+ */
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+
+ /**
* Called when a task ends
*/
def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
@@ -80,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: StageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
- this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+ this.logInfo("Finished stage: " + stageCompleted.stage)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
//shuffle write
@@ -93,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging {
//runtime breakdown
- val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+ val runtimePcts = stageCompleted.stage.taskInfos.map{
case (info, metrics) => RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@@ -111,7 +120,7 @@ object StatsReportListener extends Logging {
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
- Distribution(stage.stageInfo.taskInfos.flatMap{
+ Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 4d3e4a17ba..d5824e7954 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index aa293dc6b3..7cb3fe46e5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
+ val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
@@ -49,11 +50,6 @@ private[spark] class Stage(
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
-
- /** When first task was submitted to scheduler. */
- var submissionTime: Option[Long] = None
- var completionTime: Option[Long] = None
-
private var nextAttemptId = 0
def isAvailable: Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index b6f11969e5..93599dfdc8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,9 +21,16 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
-case class StageInfo(
- val stage: Stage,
+class StageInfo(
+ stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
) {
- override def toString = stage.rdd.toString
+ val stageId = stage.id
+ /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
+ var submissionTime: Option[Long] = None
+ var completionTime: Option[Long] = None
+ val rddName = stage.rdd.name
+ val name = stage.name
+ val numPartitions = stage.numPartitions
+ val numTasks = stage.numTasks
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 1fe0d0e4e2..69b42e86ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- def run(attemptId: Long): T = {
+ final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 7c2a422aff..4bae26f3a6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -31,9 +31,25 @@ class TaskInfo(
val host: String,
val taskLocality: TaskLocality.TaskLocality) {
+ /**
+ * The time when the task started remotely getting the result. Will not be set if the
+ * task result was sent immediately when the task finished (as opposed to sending an
+ * IndirectTaskResult and later fetching the result from the block manager).
+ */
+ var gettingResultTime: Long = 0
+
+ /**
+ * The time when the task has completed successfully (including the time to remotely fetch
+ * results, if necessary).
+ */
var finishTime: Long = 0
+
var failed = false
+ def markGettingResult(time: Long = System.currentTimeMillis) {
+ gettingResultTime = time
+ }
+
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
@@ -43,6 +59,8 @@ class TaskInfo(
failed = true
}
+ def gettingResult: Boolean = gettingResultTime != 0
+
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
@@ -52,6 +70,8 @@ class TaskInfo(
def status: String = {
if (running)
"RUNNING"
+ else if (gettingResult)
+ "GET RESULT"
else if (failed)
"FAILED"
else if (successful)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 6a51efe8d6..10e0478108 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
* Each TaskScheduler schedulers task for a single SparkContext.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler through
- * the TaskSchedulerListener interface.
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler.
*/
private[spark] trait TaskScheduler {
@@ -48,8 +47,8 @@ private[spark] trait TaskScheduler {
// Cancel a stage.
def cancelTasks(stageId: Int)
- // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
- def setListener(listener: TaskSchedulerListener): Unit
+ // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setDAGScheduler(dagScheduler: DAGScheduler): Unit
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
deleted file mode 100644
index 593fa9fb93..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import scala.collection.mutable.Map
-
-import org.apache.spark.TaskEndReason
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * Interface for getting events back from the TaskScheduler.
- */
-private[spark] trait TaskSchedulerListener {
- // A task has started.
- def taskStarted(task: Task[_], taskInfo: TaskInfo)
-
- // A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
-
- // A node was added to the cluster.
- def executorGained(execId: String, host: String): Unit
-
- // A node was lost from the cluster.
- def executorLost(execId: String): Unit
-
- // The TaskScheduler wants to abort an entire task set.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 7a72ff0474..85033958ef 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -79,7 +79,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
private val executorIdToHost = new HashMap[String, String]
// Listener object to pass upcalls into
- var listener: TaskSchedulerListener = null
+ var dagScheduler: DAGScheduler = null
var backend: SchedulerBackend = null
@@ -94,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
def initialize(context: SchedulerBackend) {
@@ -297,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
+ def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ taskSetManager.handleTaskGettingResult(tid)
+ }
+
def handleSuccessfulTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,
@@ -397,9 +401,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
- // Call listener.executorLost without holding the lock on this to prevent deadlock
+ // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
@@ -418,7 +422,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
def executorGained(execId: String, host: String) {
- listener.executorGained(execId, host)
+ dagScheduler.executorGained(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 7bd3499300..ee47aaffca 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -415,11 +415,17 @@ private[spark] class ClusterTaskSetManager(
}
private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
+ }
+
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
}
/**
- * Marks the task as successful and notifies the listener that a task has ended.
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
@@ -429,7 +435,7 @@ private[spark] class ClusterTaskSetManager(
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.listener.taskEnded(
+ sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark successful and stop if all the tasks have succeeded.
@@ -445,7 +451,8 @@ private[spark] class ClusterTaskSetManager(
}
/**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
@@ -463,7 +470,7 @@ private[spark] class ClusterTaskSetManager(
reason.foreach {
case fetchFailed: FetchFailed =>
logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
successful(index) = true
tasksSuccessful += 1
sched.taskSetFinished(this)
@@ -472,11 +479,11 @@ private[spark] class ClusterTaskSetManager(
case TaskKilled =>
logWarning("Task %d was killed.".format(tid))
- sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
return
case ef: ExceptionFailure =>
- sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
val key = ef.description
val now = clock.getTime()
val (printFull, dupCount) = {
@@ -504,7 +511,7 @@ private[spark] class ClusterTaskSetManager(
case TaskResultLost =>
logWarning("Lost result for TID %s on host %s".format(tid, info.host))
- sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
case _ => {}
}
@@ -533,7 +540,7 @@ private[spark] class ClusterTaskSetManager(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
@@ -606,7 +613,7 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 12b2fd01c0..53316dae2a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -24,28 +24,28 @@ import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.util.{Utils, SerializableBuffer}
-private[spark] sealed trait StandaloneClusterMessage extends Serializable
+private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
-private[spark] object StandaloneClusterMessages {
+private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
- case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
+ case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage
- case class KillTask(taskId: Long, executor: String) extends StandaloneClusterMessage
+ case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
- extends StandaloneClusterMessage
+ extends CoarseGrainedClusterMessage
- case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
+ case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
// Executors to driver
case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
- extends StandaloneClusterMessage {
+ extends CoarseGrainedClusterMessage {
Utils.checkHostPort(hostPort, "Expected host port")
}
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
- data: SerializableBuffer) extends StandaloneClusterMessage
+ data: SerializableBuffer) extends CoarseGrainedClusterMessage
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
@@ -56,10 +56,14 @@ private[spark] object StandaloneClusterMessages {
}
// Internal messages in driver
- case object ReviveOffers extends StandaloneClusterMessage
+ case object ReviveOffers extends CoarseGrainedClusterMessage
- case object StopDriver extends StandaloneClusterMessage
+ case object StopDriver extends CoarseGrainedClusterMessage
- case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage
+ case object StopExecutor extends CoarseGrainedClusterMessage
+
+ case object StopExecutors extends CoarseGrainedClusterMessage
+
+ case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 08ee2182a2..70f3f88401 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -30,16 +30,19 @@ import akka.util.duration._
import org.apache.spark.{SparkException, Logging, TaskState}
import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.Utils
/**
- * A standalone scheduler backend, which waits for standalone executors to connect to it through
- * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
- * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
+ * A scheduler backend that waits for coarse grained executors to connect to it through Akka.
+ * This backend holds onto each executor for the duration of the Spark job rather than relinquishing
+ * executors whenever a task is done and asking the scheduler to launch a new executor for
+ * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the
+ * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode
+ * (spark.deploy.*).
*/
private[spark]
-class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
@@ -98,6 +101,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! true
context.stop(self)
+ case StopExecutors =>
+ logInfo("Asking each executor to shut down")
+ for (executor <- executorActor.values) {
+ executor ! StopExecutor
+ }
+ sender ! true
+
case RemoveExecutor(executorId, reason) =>
removeExecutor(executorId, reason)
sender ! true
@@ -162,16 +172,29 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
driverActor = actorSystem.actorOf(
- Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
+ Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ def stopExecutors() {
+ try {
+ if (driverActor != null) {
+ logInfo("Shutting down all executors")
+ val future = driverActor.ask(StopExecutors)(timeout)
+ Await.ready(future, timeout)
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Error asking standalone scheduler to shut down executors", e)
+ }
+ }
+
override def stop() {
try {
if (driverActor != null) {
val future = driverActor.ask(StopDriver)(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
}
} catch {
case e: Exception =>
@@ -194,7 +217,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
def removeExecutor(executorId: String, reason: String) {
try {
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
- Await.result(future, timeout)
+ Await.ready(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error notifying standalone scheduler's driver actor", e)
@@ -202,6 +225,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
-private[spark] object StandaloneSchedulerBackend {
- val ACTOR_NAME = "StandaloneScheduler"
+private[spark] object CoarseGrainedSchedulerBackend {
+ val ACTOR_NAME = "CoarseGrainedScheduler"
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
new file mode 100644
index 0000000000..d78bdbaa7a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{Path, FileSystem}
+import org.apache.spark.{Logging, SparkContext}
+
+private[spark] class SimrSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext,
+ driverFilePath: String)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ val tmpPath = new Path(driverFilePath + "_tmp")
+ val filePath = new Path(driverFilePath)
+
+ val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
+
+ override def start() {
+ super.start()
+
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+
+ logInfo("Writing to HDFS file: " + driverFilePath)
+ logInfo("Writing Akka address: " + driverUrl)
+
+ // Create temporary file to prevent race condition where executors get empty driverUrl file
+ val temp = fs.create(tmpPath, true)
+ temp.writeUTF(driverUrl)
+ temp.writeInt(maxCores)
+ temp.close()
+
+ // "Atomic" rename
+ fs.rename(tmpPath, filePath)
+ }
+
+ override def stop() {
+ val conf = new Configuration()
+ val fs = FileSystem.get(conf)
+ fs.delete(new Path(driverFilePath), false)
+ super.stopExecutors()
+ super.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index cb88159b8d..cefa970bb9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -28,7 +28,7 @@ private[spark] class SparkDeploySchedulerBackend(
sc: SparkContext,
masters: Array[String],
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with ClientListener
with Logging {
@@ -44,10 +44,10 @@ private[spark] class SparkDeploySchedulerBackend(
// The endpoint for executors to talk to us
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command(
- "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
+ "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(null)
val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
"http://" + sc.ui.appUIAddress)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
index feec8ecfe4..2064d97b49 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -24,33 +24,16 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends Logging {
- private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
- private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
- private val getTaskResultExecutor = new ThreadPoolExecutor(
- MIN_THREADS,
- MAX_THREADS,
- 0L,
- TimeUnit.SECONDS,
- new LinkedBlockingDeque[Runnable],
- new ResultResolverThreadFactory)
-
- class ResultResolverThreadFactory extends ThreadFactory {
- private var counter = 0
- private var PREFIX = "Result resolver thread"
-
- override def newThread(r: Runnable): Thread = {
- val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
- counter += 1
- thread.setDaemon(true)
- return thread
- }
- }
+ private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
+ private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
+ THREADS, "Result resolver thread")
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
@@ -67,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) =>
logDebug("Fetching indirect task result for TID %s".format(tid))
+ scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 8f2eef9a53..300fe693f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,13 +30,14 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
* onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever
* a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the
- * StandaloneBackend mechanism. This class is useful for lower and more predictable latency.
+ * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable
+ * latency.
*
* Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to
* remove this.
@@ -46,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend(
sc: SparkContext,
master: String,
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with MScheduler
with Logging {
@@ -122,20 +123,20 @@ private[spark] class CoarseMesosSchedulerBackend(
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val uri = System.getProperty("spark.executor.uri")
if (uri == null) {
val runScript = new File(sparkHome, "spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
+ "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
- basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d"
+ .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
return command.build()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index b445260d1b..2699f0b33e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -81,7 +81,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
val env = SparkEnv.get
val attemptId = new AtomicInteger
- var listener: TaskSchedulerListener = null
+ var dagScheduler: DAGScheduler = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
@@ -114,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
override def submitTasks(taskSet: TaskSet) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index f72e77d40f..53bf78267e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -133,7 +133,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -148,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
}
result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
+ result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
@@ -165,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
+ sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
@@ -174,9 +175,9 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
- taskSet.id, index, 4, reason.description)
+ taskSet.id, index, MAX_TASK_FAILURES, reason.description)
decreaseRunningTasks(runningTasks)
- sched.listener.taskSetFailed(taskSet, errorMessage)
+ sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
@@ -184,7 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
override def error(message: String) {
- sched.listener.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message)
sched.taskSetFinished(this)
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index c7efc67a4a..7156d855d8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
- def isBroadcast = isInstanceOf[BroadcastBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
override def toString = name
override def hashCode = name.hashCode
@@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
+private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+ def name = broadcastId.name + "_" + hType
+}
+
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def name = "taskresult_" + taskId
}
@@ -72,6 +76,7 @@ private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
+ val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
@@ -84,6 +89,8 @@ private[spark] object BlockId {
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
+ case BROADCAST_HELPER(broadcastId, hType) =>
+ BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
new file mode 100644
index 0000000000..dbe0bda615
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.ConcurrentHashMap
+
+private[storage] trait BlockInfo {
+ def level: StorageLevel
+ def tellMaster: Boolean
+ // To save space, 'pending' and 'failed' are encoded as special sizes:
+ @volatile var size: Long = BlockInfo.BLOCK_PENDING
+ private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
+ private def failed: Boolean = size == BlockInfo.BLOCK_FAILED
+ private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this)
+
+ setInitThread()
+
+ private def setInitThread() {
+ // Set current thread as init thread - waitForReady will not block this thread
+ // (in case there is non trivial initialization which ends up calling waitForReady as part of
+ // initialization itself)
+ BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread())
+ }
+
+ /**
+ * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
+ * Return true if the block is available, false otherwise.
+ */
+ def waitForReady(): Boolean = {
+ if (pending && initThread != Thread.currentThread()) {
+ synchronized {
+ while (pending) this.wait()
+ }
+ }
+ !failed
+ }
+
+ /** Mark this BlockInfo as ready (i.e. block is finished writing) */
+ def markReady(sizeInBytes: Long) {
+ require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes)
+ assert (pending)
+ size = sizeInBytes
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+
+ /** Mark this BlockInfo as ready but failed */
+ def markFailure() {
+ assert (pending)
+ size = BlockInfo.BLOCK_FAILED
+ BlockInfo.blockInfoInitThreads.remove(this)
+ synchronized {
+ this.notifyAll()
+ }
+ }
+}
+
+private object BlockInfo {
+ // initThread is logically a BlockInfo field, but we store it here because
+ // it's only needed while this block is in the 'pending' state and we want
+ // to minimize BlockInfo's memory footprint.
+ private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread]
+
+ private val BLOCK_PENDING: Long = -1L
+ private val BLOCK_FAILED: Long = -2L
+}
+
+// All shuffle blocks have the same `level` and `tellMaster` properties,
+// so we can save space by not storing them in each instance:
+private[storage] class ShuffleBlockInfo extends BlockInfo {
+ // These need to be defined using 'def' instead of 'val' in order for
+ // the compiler to eliminate the fields:
+ def level: StorageLevel = StorageLevel.DISK_ONLY
+ def tellMaster: Boolean = false
+}
+
+private[storage] class BlockInfoImpl(val level: StorageLevel, val tellMaster: Boolean)
+ extends BlockInfo {
+ // Intentionally left blank
+} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 801f88a3db..76d537f8e8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,14 +20,15 @@ package org.apache.spark.storage
import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.collection.mutable.{HashMap, ArrayBuffer}
+import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import akka.dispatch.{Await, Future}
import akka.util.Duration
import akka.util.duration._
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@@ -45,74 +46,20 @@ private[spark] class BlockManager(
maxMemory: Long)
extends Logging {
- private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
- @volatile var pending: Boolean = true
- @volatile var size: Long = -1L
- @volatile var initThread: Thread = null
- @volatile var failed = false
-
- setInitThread()
-
- private def setInitThread() {
- // Set current thread as init thread - waitForReady will not block this thread
- // (in case there is non trivial initialization which ends up calling waitForReady as part of
- // initialization itself)
- this.initThread = Thread.currentThread()
- }
-
- /**
- * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
- * Return true if the block is available, false otherwise.
- */
- def waitForReady(): Boolean = {
- if (initThread != Thread.currentThread() && pending) {
- synchronized {
- while (pending) this.wait()
- }
- }
- !failed
- }
-
- /** Mark this BlockInfo as ready (i.e. block is finished writing) */
- def markReady(sizeInBytes: Long) {
- assert (pending)
- size = sizeInBytes
- initThread = null
- failed = false
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
-
- /** Mark this BlockInfo as ready but failed */
- def markFailure() {
- assert (pending)
- size = 0
- initThread = null
- failed = true
- initThread = null
- pending = false
- synchronized {
- this.notifyAll()
- }
- }
- }
-
val shuffleBlockManager = new ShuffleBlockManager(this)
+ val diskBlockManager = new DiskBlockManager(
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: DiskStore =
- new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
- if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@@ -269,7 +216,7 @@ private[spark] class BlockManager(
}
/**
- * Actually send a UpdateBlockInfo message. Returns the mater's response,
+ * Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
@@ -320,89 +267,14 @@ private[spark] class BlockManager(
*/
def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
- val info = blockInfo.get(blockId).orNull
- if (info != null) {
- info.synchronized {
-
- // In the another thread is writing the block, wait for it to become ready.
- if (!info.waitForReady()) {
- // If we get here, the block write failed.
- logWarning("Block " + blockId + " was marked as failure.")
- return None
- }
-
- val level = info.level
- logDebug("Level for block " + blockId + " is " + level)
-
- // Look for the block in memory
- if (level.useMemory) {
- logDebug("Getting block " + blockId + " from memory")
- memoryStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- logDebug("Block " + blockId + " not found in memory")
- }
- }
-
- // Look for block on disk, potentially loading it back into memory if required
- if (level.useDisk) {
- logDebug("Getting block " + blockId + " from disk")
- if (level.useMemory && level.deserialized) {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- // Put the block back in memory before returning it
- // TODO: Consider creating a putValues that also takes in a iterator ?
- val elements = new ArrayBuffer[Any]
- elements ++= iterator
- memoryStore.putValues(blockId, elements, level, true).data match {
- case Left(iterator2) =>
- return Some(iterator2)
- case _ =>
- throw new Exception("Memory store did not return back an iterator")
- }
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else if (level.useMemory && !level.deserialized) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- // Put a copy of the block back in memory before returning it. Note that we can't
- // put the ByteBuffer returned by the disk store as that's a memory-mapped file.
- // The use of rewind assumes this.
- assert (0 == bytes.position())
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- bytes.rewind()
- return Some(dataDeserialize(blockId, bytes))
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
- }
- }
- } else {
- logDebug("Block " + blockId + " not registered locally")
- }
- return None
+ doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
}
/**
* Get block from the local block manager as serialized bytes.
*/
def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
- // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
logDebug("Getting local block " + blockId + " as bytes")
-
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
if (blockId.isShuffle) {
@@ -413,12 +285,15 @@ private[spark] class BlockManager(
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
+ doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+ private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = {
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
- // In the another thread is writing the block, wait for it to become ready.
+ // If another thread is writing the block, wait for it to become ready.
if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure.")
@@ -431,62 +306,104 @@ private[spark] class BlockManager(
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
- memoryStore.getBytes(blockId) match {
- case Some(bytes) =>
- return Some(bytes)
+ val result = if (asValues) {
+ memoryStore.getValues(blockId)
+ } else {
+ memoryStore.getBytes(blockId)
+ }
+ result match {
+ case Some(values) =>
+ return Some(values)
case None =>
logDebug("Block " + blockId + " not found in memory")
}
}
- // Look for block on disk
+ // Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- assert (0 == bytes.position())
- if (level.useMemory) {
- if (level.deserialized) {
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- // The memory store will hang onto the ByteBuffer, so give it a copy instead of
- // the memory-mapped file buffer we got from the disk store
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- }
- }
- bytes.rewind()
- return Some(bytes)
+ logDebug("Getting block " + blockId + " from disk")
+ val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
+ case Some(bytes) => bytes
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
+ assert (0 == bytes.position())
+
+ if (!level.useMemory) {
+ // If the block shouldn't be stored in memory, we can just return it:
+ if (asValues) {
+ return Some(dataDeserialize(blockId, bytes))
+ } else {
+ return Some(bytes)
+ }
+ } else {
+ // Otherwise, we also have to store something in the memory store:
+ if (!level.deserialized || !asValues) {
+ // We'll store the bytes in memory if the block's storage level includes
+ // "memory serialized", or if it should be cached as objects in memory
+ // but we only requested its serialized bytes:
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ bytes.rewind()
+ }
+ if (!asValues) {
+ return Some(bytes)
+ } else {
+ val values = dataDeserialize(blockId, bytes)
+ if (level.deserialized) {
+ // Cache the values before returning them:
+ // TODO: Consider creating a putValues that also takes in a iterator?
+ val valuesBuffer = new ArrayBuffer[Any]
+ valuesBuffer ++= values
+ memoryStore.putValues(blockId, valuesBuffer, level, true).data match {
+ case Left(values2) =>
+ return Some(values2)
+ case _ =>
+ throw new Exception("Memory store did not return back an iterator")
+ }
+ } else {
+ return Some(values)
+ }
+ }
+ }
}
}
} else {
logDebug("Block " + blockId + " not registered locally")
}
- return None
+ None
}
/**
* Get block from remote block managers.
*/
def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
logDebug("Getting remote block " + blockId)
- // Get locations of block
- val locations = master.getLocations(blockId)
+ doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
+ }
+
+ /**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ logDebug("Getting remote block " + blockId + " as bytes")
+ doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
- // Get block from remote locations
+ private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = {
+ require(blockId != null, "BlockId is null")
+ val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
- return Some(dataDeserialize(blockId, data))
+ if (asValues) {
+ return Some(dataDeserialize(blockId, data))
+ } else {
+ return Some(data)
+ }
}
logDebug("The value of block " + blockId + " is null")
}
@@ -495,31 +412,6 @@ private[spark] class BlockManager(
}
/**
- * Get block from remote block managers as serialized bytes.
- */
- def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
- // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
- // refactored.
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- logDebug("Getting remote block " + blockId + " as bytes")
-
- val locations = master.getLocations(blockId)
- for (loc <- locations) {
- logDebug("Getting remote block " + blockId + " from " + loc)
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
- if (data != null) {
- return Some(data)
- }
- logDebug("The value of block " + blockId + " is null")
- }
- logDebug("Block " + blockId + " not found")
- return None
- }
-
- /**
* Get a block from the block manager (either local or remote).
*/
def get(blockId: BlockId): Option[Iterator[Any]] = {
@@ -566,16 +458,22 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
+ * The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+ val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
+ val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
writer.registerCloseEventHandler(() => {
- val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
+ if (shuffleBlockManager.consolidateShuffleFiles) {
+ diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
+ }
+ val myInfo = new ShuffleBlockInfo()
blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.size())
+ myInfo.markReady(writer.fileSegment().length)
})
writer
}
@@ -584,23 +482,30 @@ private[spark] class BlockManager(
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
- tellMaster: Boolean = true) : Long = {
+ tellMaster: Boolean = true) : Long = {
+ require(values != null, "Values is null")
+ doPut(blockId, Left(values), level, tellMaster)
+ }
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (values == null) {
- throw new IllegalArgumentException("Values is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
+ /**
+ * Put a new block of serialized bytes to the block manager.
+ */
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
+ tellMaster: Boolean = true) {
+ require(bytes != null, "Bytes is null")
+ doPut(blockId, Right(bytes), level, tellMaster)
+ }
+
+ private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+ level: StorageLevel, tellMaster: Boolean = true): Long = {
+ require(blockId != null, "BlockId is null")
+ require(level != null && level.isValid, "StorageLevel is null or invalid")
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = {
- val tinfo = new BlockInfo(level, tellMaster)
+ val tinfo = new BlockInfoImpl(level, tellMaster)
// Do atomically !
val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
@@ -610,7 +515,8 @@ private[spark] class BlockManager(
return oldBlockOpt.get.size
}
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ // TODO: So the block info exists - but previous attempt to load it (?) failed.
+ // What do we do now ? Retry on it ?
oldBlockOpt.get
} else {
tinfo
@@ -619,10 +525,10 @@ private[spark] class BlockManager(
val startTimeMs = System.currentTimeMillis
- // If we need to replicate the data, we'll want access to the values, but because our
- // put will read the whole iterator, there will be no values left. For the case where
- // the put serializes data, we'll remember the bytes, above; but for the case where it
- // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
+ // If we're storing values and we need to replicate the data, we'll want access to the values,
+ // but because our put will read the whole iterator, there will be no values left. For the
+ // case where the put serializes data, we'll remember the bytes, above; but for the case where
+ // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
// Ditto for the bytes after the put
@@ -631,30 +537,51 @@ private[spark] class BlockManager(
// Size of the block in bytes (to return to caller)
var size = 0L
+ // If we're storing bytes, then initiate the replication before storing them locally.
+ // This is faster as data is already serialized and ready to send.
+ val replicationFuture = if (data.isRight && level.replication > 1) {
+ val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ Future {
+ replicate(blockId, bufferView, level)
+ }
+ } else {
+ null
+ }
+
myInfo.synchronized {
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will later
- // drop it to disk if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
+ data match {
+ case Left(values) => {
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
+ }
}
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
+ case Right(bytes) => {
+ bytes.rewind()
+ // Store it only in memory at first, even if useDisk is also set to true
+ (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
+ size = bytes.limit
}
}
@@ -679,125 +606,39 @@ private[spark] class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
- // Replicate block if required
+ // Either we're storing bytes and we asynchronously started replication, or we're storing
+ // values and need to serialize and replicate them now:
if (level.replication > 1) {
- val remoteStartTime = System.currentTimeMillis
- // Serialize the block if not already done
- if (bytesAfterPut == null) {
- if (valuesAfterPut == null) {
- throw new SparkException(
- "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
- }
- bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
- }
- replicate(blockId, bytesAfterPut, level)
- logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
- }
- BlockManager.dispose(bytesAfterPut)
-
- return size
- }
-
-
- /**
- * Put a new block of serialized bytes to the block manager.
- */
- def putBytes(
- blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
-
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (bytes == null) {
- throw new IllegalArgumentException("Bytes is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
-
- // Remember the block's storage level so that we can correctly drop it to disk if it needs
- // to be dropped right after it got put into memory. Note, however, that other threads will
- // not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = {
- val tinfo = new BlockInfo(level, tellMaster)
- // Do atomically !
- val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
-
- if (oldBlockOpt.isDefined) {
- if (oldBlockOpt.get.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
- oldBlockOpt.get
- } else {
- tinfo
- }
- }
-
- val startTimeMs = System.currentTimeMillis
-
- // Initiate the replication before storing it locally. This is faster as
- // data is already serialized and ready for sending
- val replicationFuture = if (level.replication > 1) {
- val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
- Future {
- replicate(blockId, bufferView, level)
- }
- } else {
- null
- }
-
- myInfo.synchronized {
- logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
- + " to get into synchronized block")
-
- var marked = false
- try {
- if (level.useMemory) {
- // Store it only in memory at first, even if useDisk is also set to true
- bytes.rewind()
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- bytes.rewind()
- diskStore.putBytes(blockId, bytes, level)
- }
-
- // assert (0 == bytes.position(), "" + bytes)
-
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- marked = true
- myInfo.markReady(bytes.limit)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
- }
- } finally {
- // If we failed at putting the block to memory/disk, notify other possible readers
- // that it has failed, and then remove it from the block info map.
- if (! marked) {
- // Note that the remove must happen before markFailure otherwise another thread
- // could've inserted a new BlockInfo before we remove it.
- blockInfo.remove(blockId)
- myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed")
+ data match {
+ case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case Left(values) => {
+ val remoteStartTime = System.currentTimeMillis
+ // Serialize the block if not already done
+ if (bytesAfterPut == null) {
+ if (valuesAfterPut == null) {
+ throw new SparkException(
+ "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ }
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ }
+ replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " +
+ Utils.getUsedTimeMs(remoteStartTime))
}
}
}
- // If replication had started, then wait for it to finish
- if (level.replication > 1) {
- Await.ready(replicationFuture, Duration.Inf)
- }
+ BlockManager.dispose(bytesAfterPut)
if (level.replication > 1) {
- logDebug("PutBytes for block " + blockId + " with replication took " +
+ logDebug("Put for block " + blockId + " with replication took " +
Utils.getUsedTimeMs(startTimeMs))
} else {
- logDebug("PutBytes for block " + blockId + " without replication took " +
+ logDebug("Put for block " + blockId + " without replication took " +
Utils.getUsedTimeMs(startTimeMs))
}
+
+ size
}
/**
@@ -922,34 +763,20 @@ private[spark] class BlockManager(
private def dropOldNonBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping non broadcast blocks older than " + cleanupTime)
- val iterator = blockInfo.internalMap.entrySet().iterator()
- while (iterator.hasNext) {
- val entry = iterator.next()
- val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime && !id.isBroadcast) {
- info.synchronized {
- val level = info.level
- if (level.useMemory) {
- memoryStore.remove(id)
- }
- if (level.useDisk) {
- diskStore.remove(id)
- }
- iterator.remove()
- logInfo("Dropped block " + id)
- }
- reportBlockStatus(id, info)
- }
- }
+ dropOldBlocks(cleanupTime, !_.isBroadcast)
}
private def dropOldBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping broadcast blocks older than " + cleanupTime)
+ dropOldBlocks(cleanupTime, _.isBroadcast)
+ }
+
+ private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime && id.isBroadcast) {
+ if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
if (level.useMemory) {
@@ -987,13 +814,24 @@ private[spark] class BlockManager(
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
+ /** Serializes into a stream. */
+ def dataSerializeStream(
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer) {
+ val byteStream = new FastBufferedOutputStream(outputStream)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a byte buffer. */
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 633230c0a8..f8cf14b503 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -227,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- if (id.executorId == "<driver>" && !isLocal) {
- // Got a register message from the master node; don't register it
- } else if (!blockManagerInfo.contains(id)) {
+ if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 951503019f..3a65e55733 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._
* An actor to take commands from the master to execute options. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
+private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 2a67800c45..32d2dd0694 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,6 +17,13 @@
package org.apache.spark.storage
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -59,7 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
def write(value: Any)
/**
- * Size of the valid writes, in bytes.
+ * Returns the file segment of committed data that this Writer has written.
+ */
+ def fileSegment(): FileSegment
+
+ /**
+ * Cumulative time spent performing blocking writes, in ns.
*/
- def size(): Long
+ def timeWriting(): Long
+}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
+ extends BlockObjectWriter(blockId)
+ with Logging
+{
+
+ /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+ def timeWriting = _timeWriting
+ private var _timeWriting = 0L
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ _timeWriting += (System.nanoTime() - start)
+ }
+
+ def write(i: Int): Unit = callWithTiming(out.write(i))
+ override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+ override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ }
+
+ private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+ /** The file channel, used for repositioning / truncating the file. */
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var fos: FileOutputStream = null
+ private var ts: TimeTrackingOutputStream = null
+ private var objOut: SerializationStream = null
+ private var initialPosition = 0L
+ private var lastValidPosition = 0L
+ private var initialized = false
+ private var _timeWriting = 0L
+
+ override def open(): BlockObjectWriter = {
+ fos = new FileOutputStream(file, true)
+ ts = new TimeTrackingOutputStream(fos)
+ channel = fos.getChannel()
+ initialPosition = channel.position
+ lastValidPosition = initialPosition
+ bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ objOut.flush()
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ _timeWriting += System.nanoTime() - start
+ }
+ objOut.close()
+
+ _timeWriting += ts.timeWriting
+
+ channel = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ }
+ // Invoke the close callback handler.
+ super.close()
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def fileSegment(): FileSegment = {
+ val bytesWritten = lastValidPosition - initialPosition
+ new FileSegment(file, initialPosition, bytesWritten)
+ }
+
+ // Only valid if called after close()
+ override def timeWriting() = _timeWriting
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
new file mode 100644
index 0000000000..bcb58ad946
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private var shuffleSender : ShuffleSender = null
+
+ // Stores only Blocks which have been specifically mapped to segments of files
+ // (rather than the default, which maps a Block to a whole file).
+ // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
+ private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
+
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
+
+ addShutdownHook()
+
+ /**
+ * Creates a logical mapping from the given BlockId to a segment of a file.
+ * This will cause any accesses of the logical BlockId to be directed to the specified
+ * physical location.
+ */
+ def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
+ blockToFileSegmentMap.put(blockId, fileSegment)
+ }
+
+ /**
+ * Returns the phyiscal file segment in which the given BlockId is located.
+ * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+ * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ */
+ def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
+ blockToFileSegmentMap.get(blockId).get
+ } else {
+ val file = getFile(blockId.name)
+ new FileSegment(file, 0, file.length())
+ }
+ }
+
+ /**
+ * Simply returns a File to place the given Block into. This does not physically create the file.
+ * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
+ * a unique filename.
+ */
+ def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
+ val actualFilename = if (filename == "") blockId.name else filename
+ val file = getFile(actualFilename)
+ if (!allowAppending && file.exists()) {
+ throw new IllegalStateException(
+ "Attempted to create file that already exists: " + actualFilename)
+ }
+ file
+ }
+
+ private def getFile(filename: String): File = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, filename)
+ }
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ blockToFileSegmentMap.clearOldValues(cleanupTime)
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ shuffleSender = new ShuffleSender(port, this)
+ logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+ shuffleSender.port
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index b7ca61e938..a3c496f9e0 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,120 +17,25 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
-import java.util.{Random, Date}
-import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
-import org.apache.spark.network.netty.ShuffleSender
-import org.apache.spark.network.netty.PathResolver
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
-private class DiskStore(blockManager: BlockManager, rootDirs: String)
+private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- extends BlockObjectWriter(blockId) {
-
- private val f: File = createFile(blockId /*, allowAppendExisting */)
-
- // The file channel, used for repositioning / truncating the file.
- private var channel: FileChannel = null
- private var bs: OutputStream = null
- private var objOut: SerializationStream = null
- private var lastValidPosition = 0L
- private var initialized = false
-
- override def open(): DiskBlockObjectWriter = {
- val fos = new FileOutputStream(f, true)
- channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
- objOut = serializer.newInstance().serializeStream(bs)
- initialized = true
- this
- }
-
- override def close() {
- if (initialized) {
- objOut.close()
- channel = null
- bs = null
- objOut = null
- }
- // Invoke the close callback handler.
- super.close()
- }
-
- override def isOpen: Boolean = objOut != null
-
- // Flush the partial writes, and set valid length to be the length of the entire file.
- // Return the number of bytes written for this commit.
- override def commit(): Long = {
- if (initialized) {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
- } else {
- // lastValidPosition is zero if stream is uninitialized
- lastValidPosition
- }
- }
-
- override def revertPartialWrites() {
- if (initialized) {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
- }
- }
-
- override def write(value: Any) {
- if (!initialized) {
- open()
- }
- objOut.writeObject(value)
- }
-
- override def size(): Long = lastValidPosition
- }
-
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
-
- private var shuffleSender : ShuffleSender = null
- // Create one local directory for each path mentioned in spark.local.dir; then, inside this
- // directory, create multiple subdirectories that we will hash files into, in order to avoid
- // having really large inodes at the top level.
- private val localDirs: Array[File] = createLocalDirs()
- private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
-
- addShutdownHook()
-
- def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId, serializer, bufferSize)
- }
-
override def getSize(blockId: BlockId): Long = {
- getFile(blockId).length()
+ diskManager.getBlockLocation(blockId).length
}
override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
@@ -139,27 +44,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val channel = new RandomAccessFile(file, "rw").getChannel()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
- }
-
- private def getFileBytes(file: File): ByteBuffer = {
- val length = file.length()
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, 0, length)
- } finally {
- channel.close()
- }
-
- buffer
+ file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
@@ -171,21 +64,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(blockId,
- new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
- objOut.writeAll(values.iterator)
- objOut.close()
- val length = file.length()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val outputStream = new FileOutputStream(file)
+ blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(length), timeTaken))
+ file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val buffer = getFileBytes(file)
+ val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@@ -193,13 +83,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val bytes = getFileBytes(file)
- Some(bytes)
+ val segment = diskManager.getBlockLocation(blockId)
+ val channel = new RandomAccessFile(segment.file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buffer)
}
override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
@@ -211,118 +106,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def remove(blockId: BlockId): Boolean = {
- val file = getFile(blockId)
- if (file.exists()) {
+ val fileSegment = diskManager.getBlockLocation(blockId)
+ val file = fileSegment.file
+ if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
+ if (fileSegment.length < file.length()) {
+ logWarning("Could not delete block associated with only a part of a file: " + blockId)
+ }
false
}
}
override def contains(blockId: BlockId): Boolean = {
- getFile(blockId).exists()
- }
-
- private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
- val file = getFile(blockId)
- if (!allowAppendExisting && file.exists()) {
- // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
- // was rescheduled on the same machine as the old task.
- logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
- file.delete()
- }
- file
- }
-
- private def getFile(blockId: BlockId): File = {
- logDebug("Getting file for block " + blockId)
-
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
-
- new File(subDir, blockId.name)
- }
-
- private def createLocalDirs(): Array[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, "spark-local-" + localDirId)
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
- }
- }
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created local directory at " + localDir)
- localDir
- }
- }
-
- private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
- if (shuffleSender != null) {
- shuffleSender.stop()
- }
- }
- })
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- val pResolver = new PathResolver {
- override def getAbsolutePath(blockIdString: String): String = {
- val blockId = BlockId(blockIdString)
- if (!blockId.isShuffle) null
- else DiskStore.this.getFile(blockId).getAbsolutePath
- }
- }
- shuffleSender = new ShuffleSender(port, pResolver)
- logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
- shuffleSender.port
+ val file = diskManager.getBlockLocation(blockId).file
+ file.exists()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
new file mode 100644
index 0000000000..555486830a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+/**
+ * References a particular segment of a file (potentially the entire file),
+ * based off an offset and a length.
+ */
+private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f39fcd87fb..066e45a12b 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,12 +17,13 @@
package org.apache.spark.storage
-import org.apache.spark.serializer.Serializer
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.spark.serializer.Serializer
private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
-
+class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
private[spark]
trait ShuffleBlocks {
@@ -30,24 +31,66 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup)
}
+/**
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
+ * per reducer.
+ *
+ * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
+ * it releases them for another task.
+ * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
+ * - shuffleId: The unique id given to the entire shuffle stage.
+ * - bucketId: The id of the output partition (i.e., reducer id)
+ * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ * time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+ // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
+ // TODO: Remove this once the shuffle file consolidation feature is stable.
+ val consolidateShuffleFiles =
+ System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+
+ var nextFileId = new AtomicInteger(0)
+ val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleBlocks {
// Get a group of writers for a map task.
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val fileId = getUnusedFileId()
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ if (consolidateShuffleFiles) {
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
+ } else {
+ blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize)
+ }
}
- new ShuffleWriterGroup(mapId, writers)
+ new ShuffleWriterGroup(mapId, fileId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ override def releaseWriters(group: ShuffleWriterGroup) {
+ recycleFileId(group.fileId)
}
}
}
+
+ private def getUnusedFileId(): Int = {
+ val fileId = unusedFileIds.poll()
+ if (fileId == null) nextFileId.getAndIncrement() else fileId
+ }
+
+ private def recycleFileId(fileId: Int) {
+ if (consolidateShuffleFiles) {
+ unusedFileIds.add(fileId)
+ }
+ }
+
+ private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+ "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
new file mode 100644
index 0000000000..7dcadc3805
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -0,0 +1,86 @@
+package org.apache.spark.storage
+
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{CountDownLatch, Executors}
+
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+
+/**
+ * Utility for micro-benchmarking shuffle write performance.
+ *
+ * Writes simulated shuffle output from several threads and records the observed throughput.
+ */
+object StoragePerfTester {
+ def main(args: Array[String]) = {
+ /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
+ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g"))
+
+ /** Number of map tasks. All tasks execute concurrently. */
+ val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8)
+
+ /** Number of reduce splits for each map task. */
+ val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500)
+
+ val recordLength = 1000 // ~1KB records
+ val totalRecords = dataSizeMb * 1000
+ val recordsPerMap = totalRecords / numMaps
+
+ val writeData = "1" * recordLength
+ val executor = Executors.newFixedThreadPool(numMaps)
+
+ System.setProperty("spark.shuffle.compress", "false")
+ System.setProperty("spark.shuffle.sync", "true")
+
+ // This is only used to instantiate a BlockManager. All thread scheduling is done manually.
+ val sc = new SparkContext("local[4]", "Write Tester")
+ val blockManager = sc.env.blockManager
+
+ def writeOutputBytes(mapId: Int, total: AtomicLong) = {
+ val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
+ new KryoSerializer())
+ val buckets = shuffle.acquireWriters(mapId)
+ for (i <- 1 to recordsPerMap) {
+ buckets.writers(i % numOutputSplits).write(writeData)
+ }
+ buckets.writers.map {w =>
+ w.commit()
+ total.addAndGet(w.fileSegment().length)
+ w.close()
+ }
+
+ shuffle.releaseWriters(buckets)
+ }
+
+ val start = System.currentTimeMillis()
+ val latch = new CountDownLatch(numMaps)
+ val totalBytes = new AtomicLong()
+ for (task <- 1 to numMaps) {
+ executor.submit(new Runnable() {
+ override def run() = {
+ try {
+ writeOutputBytes(task, totalBytes)
+ latch.countDown()
+ } catch {
+ case e: Exception =>
+ println("Exception in child thread: " + e + " " + e.getMessage)
+ System.exit(1)
+ }
+ }
+ })
+ }
+ latch.await()
+ val end = System.currentTimeMillis()
+ val time = (end - start) / 1000.0
+ val bytesPerSecond = totalBytes.get() / time
+ val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong
+
+ System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits))
+ System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile)))
+ System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong)))
+
+ executor.shutdown()
+ sc.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
index b39c0e9769..ca5a28625b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
val now = System.currentTimeMillis()
var activeTime = 0L
- for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
+ for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
activeTime += t.timeRunning(now)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index eb3b4e8522..6b854740d6 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
- val stageToPool = new HashMap[Stage, String]()
- val stageToDescription = new HashMap[Stage, String]()
- val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
+ val stageIdToPool = new HashMap[Int, String]()
+ val stageIdToDescription = new HashMap[Int, String]()
+ val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()
- val activeStages = HashSet[Stage]()
- val completedStages = ListBuffer[Stage]()
- val failedStages = ListBuffer[Stage]()
+ val activeStages = HashSet[StageInfo]()
+ val completedStages = ListBuffer[StageInfo]()
+ val failedStages = ListBuffer[StageInfo]()
// Total metrics reflect metrics only for completed tasks
var totalTime = 0L
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
- val stageToTime = HashMap[Int, Long]()
- val stageToShuffleRead = HashMap[Int, Long]()
- val stageToShuffleWrite = HashMap[Int, Long]()
- val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
- val stageToTasksComplete = HashMap[Int, Int]()
- val stageToTasksFailed = HashMap[Int, Int]()
- val stageToTaskInfos =
+ val stageIdToTime = HashMap[Int, Long]()
+ val stageIdToShuffleRead = HashMap[Int, Long]()
+ val stageIdToShuffleWrite = HashMap[Int, Long]()
+ val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
+ val stageIdToTasksComplete = HashMap[Int, Int]()
+ val stageIdToTasksFailed = HashMap[Int, Int]()
+ val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
- val stage = stageCompleted.stageInfo.stage
- poolToActiveStages(stageToPool(stage)) -= stage
+ val stage = stageCompleted.stage
+ poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
completedStages += stage
trimIfNecessary(completedStages)
}
/** If stages is too large, remove and garbage collect old stages */
- def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
+ def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > RETAINED_STAGES) {
val toRemove = RETAINED_STAGES / 10
stages.takeRight(toRemove).foreach( s => {
- stageToTaskInfos.remove(s.id)
- stageToTime.remove(s.id)
- stageToShuffleRead.remove(s.id)
- stageToShuffleWrite.remove(s.id)
- stageToTasksActive.remove(s.id)
- stageToTasksComplete.remove(s.id)
- stageToTasksFailed.remove(s.id)
- stageToPool.remove(s)
- if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
+ stageIdToTaskInfos.remove(s.stageId)
+ stageIdToTime.remove(s.stageId)
+ stageIdToShuffleRead.remove(s.stageId)
+ stageIdToShuffleWrite.remove(s.stageId)
+ stageIdToTasksActive.remove(s.stageId)
+ stageIdToTasksComplete.remove(s.stageId)
+ stageIdToTasksFailed.remove(s.stageId)
+ stageIdToPool.remove(s.stageId)
+ if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
})
stages.trimEnd(toRemove)
}
@@ -95,63 +95,69 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
- stageToPool(stage) = poolName
+ stageIdToPool(stage.stageId) = poolName
val description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
- description.map(d => stageToDescription(stage) = d)
+ description.map(d => stageIdToDescription(stage.stageId) = d)
- val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
+ val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive += taskStart.taskInfo
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList += ((taskStart.taskInfo, None, None))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
-
+
+ override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
+ = synchronized {
+ // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
+ // stageToTaskInfos already has the updated status.
+ }
+
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
- val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+ val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
- stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+ stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
(Some(e), e.metrics)
case _ =>
- stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+ stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
- stageToTime.getOrElseUpdate(sid, 0L)
+ stageIdToTime.getOrElseUpdate(sid, 0L)
val time = metrics.map(m => m.executorRunTime).getOrElse(0)
- stageToTime(sid) += time
+ stageIdToTime(sid) += time
totalTime += time
- stageToShuffleRead.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
s.remoteBytesRead).getOrElse(0L)
- stageToShuffleRead(sid) += shuffleRead
+ stageIdToShuffleRead(sid) += shuffleRead
totalShuffleRead += shuffleRead
- stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+ stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
s.shuffleBytesWritten).getOrElse(0L)
- stageToShuffleWrite(sid) += shuffleWrite
+ stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
- val taskList = stageToTaskInfos.getOrElse(
+ val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
taskList += ((taskEnd.taskInfo, metrics, failureInfo))
- stageToTaskInfos(sid) = taskList
+ stageIdToTaskInfos(sid) = taskList
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@@ -159,10 +165,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
case end: SparkListenerJobEnd =>
end.jobResult match {
case JobFailed(ex, Some(stage)) =>
- activeStages -= stage
- poolToActiveStages(stageToPool(stage)) -= stage
- failedStages += stage
- trimIfNecessary(failedStages)
+ /* If two jobs share a stage we could get this failure message twice. So we first
+ * check whether we've already retired this stage. */
+ val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
+ stageInfo.foreach {s =>
+ activeStages -= s
+ poolToActiveStages(stageIdToPool(stage.id)) -= s
+ failedStages += s
+ trimIfNecessary(failedStages)
+ }
case _ =>
}
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 06810d8dbc..cfeeccda41 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
-import org.apache.spark.scheduler.{Schedulable, Stage}
+import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
- var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
+ var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
def toNodeSeq(): Seq[Node] = {
listener.synchronized {
@@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
}
}
- private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
+ private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]
): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
@@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
</table>
}
- private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
+ private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
: Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 163a3746ea..35b5d5fd59 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val stageId = request.getParameter("id").toInt
val now = System.currentTimeMillis()
- if (!listener.stageToTaskInfos.contains(stageId)) {
+ if (!listener.stageIdToTaskInfos.contains(stageId)) {
val content =
<div>
<h4>Summary Metrics</h4> No tasks have started yet
@@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
}
- val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+ val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val numCompleted = tasks.count(_._1.finished)
- val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
+ val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
- val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
+ val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
var activeTime = 0L
- listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
val summary =
<div>
<ul class="unstyled">
<li>
<strong>CPU time: </strong>
- {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
+ {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
<li>
@@ -83,10 +83,10 @@ private[spark] class StagePage(parent: JobProgressUI) {
</div>
val taskHeaders: Seq[String] =
- Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++
- Seq("GC Time") ++
+ Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
+ Seq("Duration", "GC Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
- {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
+ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
Seq("Errors")
val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
@@ -153,6 +153,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
<tr>
+ <td>{info.index}</td>
<td>{info.taskId}</td>
<td>{info.status}</td>
<td>{info.taskLocality}</td>
@@ -169,6 +170,8 @@ private[spark] class StagePage(parent: JobProgressUI) {
Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
}}
{if (shuffleWrite) {
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
<td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
}}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 07db8622da..d7d0441c38 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -22,13 +22,13 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
-import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
+import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
-private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
+private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) {
val listener = parent.listener
val dateFmt = parent.dateFmt
@@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
}
- private def stageRow(s: Stage): Seq[Node] = {
+ private def stageRow(s: StageInfo): Seq[Node] = {
val submissionTime = s.submissionTime match {
case Some(t) => dateFmt.format(new Date(t))
case None => "Unknown"
}
- val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
+ val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
+ val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
- val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
- val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
- val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
+ val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
+ val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
+ val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
- val totalTasks = s.numPartitions
+ val totalTasks = s.numTasks
- val poolName = listener.stageToPool.get(s)
+ val poolName = listener.stageIdToPool.get(s.stageId)
val nameLink =
- <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a>
- val description = listener.stageToDescription.get(s)
+ <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.name}</a>
+ val description = listener.stageIdToDescription.get(s.stageId)
.map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
val duration = s.submissionTime.map(t => finishTime - t)
<tr>
- <td>{s.id}</td>
+ <td>{s.stageId}</td>
{if (isFairScheduler) {
<td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}>
{poolName.get}</a></td>}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0ce1394c77..3f963727d9 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
- "ShuffleMapTask", "BlockManager", "BroadcastVars") {
+ "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f384875cc9..0c5c12b7a8 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -37,6 +37,7 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
import org.apache.spark.deploy.SparkHadoopUtil
import java.nio.ByteBuffer
import org.apache.spark.{SparkEnv, SparkException, Logging}
+import java.util.ConcurrentModificationException
/**
@@ -280,9 +281,8 @@ private[spark] object Utils extends Logging {
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
- val env = SparkEnv.get
val uri = new URI(url)
- val conf = env.hadoop.newConfiguration()
+ val conf = SparkHadoopUtil.get.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@@ -447,14 +447,17 @@ private[spark] object Utils extends Logging {
hostPortParseResults.get(hostPort)
}
- private[spark] val daemonThreadFactory: ThreadFactory =
- new ThreadFactoryBuilder().setDaemon(true).build()
+ private val daemonThreadFactoryBuilder: ThreadFactoryBuilder =
+ new ThreadFactoryBuilder().setDaemon(true)
/**
- * Wrapper over newCachedThreadPool.
+ * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor =
- Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
@@ -465,10 +468,13 @@ private[spark] object Utils extends Logging {
}
/**
- * Wrapper over newFixedThreadPool.
+ * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
- Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
private def listFilesSafely(file: File): Seq[File] = {
val files = file.listFiles()
@@ -813,4 +819,10 @@ private[spark] object Utils extends Logging {
// Nothing else to guard against ?
hashAbs
}
+
+ /** Returns a copy of the system properties that is thread-safe to iterator over. */
+ def getSystemProperties(): Map[String, String] = {
+ return System.getProperties().clone()
+ .asInstanceOf[java.util.Properties].toMap[String, String]
+ }
}