diff options
author | Reynold Xin <rxin@apache.org> | 2013-11-02 22:45:15 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2013-11-02 22:45:15 -0700 |
commit | da6bb0aedd5121d9e3b92031dcc0884a9682da05 (patch) | |
tree | 818928e20787661304af22153683ab987c0df491 /core/src/main/scala/org | |
parent | 3e7df8f6c6edec9e71c6e416de86aa1a2cefc176 (diff) | |
parent | 41ead7a74533ffdd208a4ba2f7cd38945b4343ec (diff) | |
download | spark-da6bb0aedd5121d9e3b92031dcc0884a9682da05.tar.gz spark-da6bb0aedd5121d9e3b92031dcc0884a9682da05.tar.bz2 spark-da6bb0aedd5121d9e3b92031dcc0884a9682da05.zip |
Merge branch 'master' into hash1
Diffstat (limited to 'core/src/main/scala/org')
77 files changed, 2001 insertions, 3303 deletions
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala index f87460039b..0c47afae54 100644 --- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala @@ -17,20 +17,29 @@ package org.apache.hadoop.mapred +private[apache] trait SparkHadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext"); - val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID]) + val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", + "org.apache.hadoop.mapred.JobContext") + val ctor = klass.getDeclaredConstructor(classOf[JobConf], + classOf[org.apache.hadoop.mapreduce.JobID]) ctor.newInstance(conf, jobId).asInstanceOf[JobContext] } def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext") + val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", + "org.apache.hadoop.mapred.TaskAttemptContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } - def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = { + def newTaskAttemptID( + jtIdentifier: String, + jobId: Int, + isMap: Boolean, + taskId: Int, + attemptId: Int) = { new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) } diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala index 93180307fa..32429f01ac 100644 --- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala @@ -17,9 +17,10 @@ package org.apache.hadoop.mapreduce -import org.apache.hadoop.conf.Configuration import java.lang.{Integer => JInteger, Boolean => JBoolean} +import org.apache.hadoop.conf.Configuration +private[apache] trait SparkHadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = { val klass = firstAvailableClass( @@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil { ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } - def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID"); + def newTaskAttemptID( + jtIdentifier: String, + jobId: Int, + isMap: Boolean, + taskId: Int, + attemptId: Int) = { + val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") try { - // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN) + // First, attempt to use the old-style constructor that takes a boolean isMap + // (not available in YARN) val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean], - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new - JInteger(attemptId)).asInstanceOf[TaskAttemptID] + classOf[Int], classOf[Int]) + ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), + new JInteger(attemptId)).asInstanceOf[TaskAttemptID] } catch { case exc: NoSuchMethodException => { - // failed, look for the new ctor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]] - val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE") + // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) + val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + .asInstanceOf[Class[Enum[_]]] + val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( + taskTypeClass, if(isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new - JInteger(attemptId)).asInstanceOf[TaskAttemptID] + ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), + new JInteger(attemptId)).asInstanceOf[TaskAttemptID] } } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1e3f1ebfaf..5e465fa22c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -20,13 +20,11 @@ package org.apache.spark import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import akka.actor._ import akka.dispatch._ import akka.pattern.ask -import akka.remote._ import akka.util.Duration @@ -40,11 +38,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { +private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster) + extends Actor with Logging { def receive = { case GetMapOutputStatuses(shuffleId: Int, requester: String) => logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) - sender ! tracker.getSerializedLocations(shuffleId) + sender ! tracker.getSerializedMapOutputStatuses(shuffleId) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") @@ -60,22 +59,19 @@ private[spark] class MapOutputTracker extends Logging { // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. - private var epoch: Long = 0 - private val epochLock = new java.lang.Object + protected var epoch: Long = 0 + protected val epochLock = new java.lang.Object - // Cache a serialized version of the output statuses for each shuffle to send them out faster - var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup) + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. - def askTracker(message: Any): Any = { + private def askTracker(message: Any): Any = { try { val future = trackerActor.ask(message)(timeout) return Await.result(future, timeout) @@ -86,50 +82,12 @@ private[spark] class MapOutputTracker extends Logging { } // Send a one-way message to the trackerActor, to which we expect it to reply with true. - def communicate(message: Any) { + private def communicate(message: Any) { if (askTracker(message) != true) { throw new SparkException("Error reply received from MapOutputTracker") } } - def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { - throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") - } - } - - def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - def registerMapOutputs( - shuffleId: Int, - statuses: Array[MapStatus], - changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) - if (changeEpoch) { - incrementEpoch() - } - } - - def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - var array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") - } - } - // Remembers which map output locations are currently being fetched on a worker private val fetching = new HashSet[Int] @@ -168,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging { try { val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) } finally { @@ -194,9 +152,8 @@ private[spark] class MapOutputTracker extends Logging { } } - private def cleanup(cleanupTime: Long) { + protected def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) } def stop() { @@ -206,15 +163,7 @@ private[spark] class MapOutputTracker extends Logging { trackerActor = null } - // Called on master to increment the epoch number - def incrementEpoch() { - epochLock.synchronized { - epoch += 1 - logDebug("Increasing epoch to " + epoch) - } - } - - // Called on master or workers to get current epoch number + // Called to get current epoch number def getEpoch: Long = { epochLock.synchronized { return epoch @@ -228,14 +177,62 @@ private[spark] class MapOutputTracker extends Logging { epochLock.synchronized { if (newEpoch > epoch) { logInfo("Updating epoch to " + newEpoch + " and clearing cache") - // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] - mapStatuses.clear() epoch = newEpoch + mapStatuses.clear() + } + } + } +} + +private[spark] class MapOutputTrackerMaster extends MapOutputTracker { + + // Cache a serialized version of the output statuses for each shuffle to send them out faster + private var cacheEpoch = epoch + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + + def registerShuffle(shuffleId: Int, numMaps: Int) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } + + def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { + val array = mapStatuses(shuffleId) + array.synchronized { + array(mapId) = status + } + } + + def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { + mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) + if (changeEpoch) { + incrementEpoch() + } + } + + def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + val arrayOpt = mapStatuses.get(shuffleId) + if (arrayOpt.isDefined && arrayOpt.get != null) { + val array = arrayOpt.get + array.synchronized { + if (array(mapId) != null && array(mapId).location == bmAddress) { + array(mapId) = null + } } + incrementEpoch() + } else { + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } - def getSerializedLocations(shuffleId: Int): Array[Byte] = { + def incrementEpoch() { + epochLock.synchronized { + epoch += 1 + logDebug("Increasing epoch to " + epoch) + } + } + + def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { var statuses: Array[MapStatus] = null var epochGotten: Long = -1 epochLock.synchronized { @@ -253,7 +250,7 @@ private[spark] class MapOutputTracker extends Logging { } // If we got here, we failed to find the serialized locations in the cache, so we pulled // out a snapshot of the locations as "locs"; let's serialize and return that - val bytes = serializeStatuses(statuses) + val bytes = MapOutputTracker.serializeMapStatuses(statuses) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working epochLock.synchronized { @@ -261,13 +258,31 @@ private[spark] class MapOutputTracker extends Logging { cachedSerializedStatuses(shuffleId) = bytes } } - return bytes + bytes + } + + protected override def cleanup(cleanupTime: Long) { + super.cleanup(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } + override def stop() { + super.stop() + cachedSerializedStatuses.clear() + } + + override def updateEpoch(newEpoch: Long) { + // This might be called on the MapOutputTrackerMaster if we're running in local mode. + } +} + +private[spark] object MapOutputTracker { + private val LOG_BASE = 1.1 + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) // Since statuses can be modified in parallel, sync on it @@ -278,18 +293,11 @@ private[spark] class MapOutputTracker extends Logging { out.toByteArray } - // Opposite of serializeStatuses. - def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { + // Opposite of serializeMapStatuses. + def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject(). - // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present - // comment this out - nulls could be due to missing location ? - asInstanceOf[Array[MapStatus]] // .filter( _ != null ) + objIn.readObject().asInstanceOf[Array[MapStatus]] } -} - -private[spark] object MapOutputTracker { - private val LOG_BASE = 1.1 // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If // any of the statuses is null (indicating a missing location due to a failed mapper), diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f674cb397f..880b49e8ef 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.generic.Growable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap @@ -51,25 +51,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.mesos.MesosNativeLibrary -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.LocalSparkCluster +import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, - ClusterScheduler} -import org.apache.spark.scheduler.local.LocalScheduler +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.storage.{StorageUtils, BlockManagerSource} -import org.apache.spark.ui.SparkUI -import org.apache.spark.util._ -import org.apache.spark.scheduler.StageInfo -import org.apache.spark.storage.RDDInfo -import org.apache.spark.storage.StorageStatus -import scala.Some +import org.apache.spark.scheduler.local.LocalScheduler import org.apache.spark.scheduler.StageInfo -import org.apache.spark.storage.RDDInfo -import org.apache.spark.storage.StorageStatus +import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, + TimeStampedHashMap, Utils} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -125,7 +119,7 @@ class SparkContext( private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup) - // Initalize the Spark UI + // Initialize the Spark UI private[spark] val ui = new SparkUI(this) ui.bind() @@ -161,8 +155,10 @@ class SparkContext( val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """spark://(.*)""".r - //Regular expression for connection to Mesos cluster - val MESOS_REGEX = """(mesos://.*)""".r + // Regular expression for connection to Mesos cluster + val MESOS_REGEX = """mesos://(.*)""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r master match { case "local" => @@ -181,6 +177,12 @@ class SparkContext( scheduler.initialize(backend) scheduler + case SIMR_REGEX(simrUrl) => + val scheduler = new ClusterScheduler(this) + val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) + scheduler.initialize(backend) + scheduler + case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. val memoryPerSlaveInt = memoryPerSlave.toInt @@ -213,25 +215,24 @@ class SparkContext( throw new SparkException("YARN mode not available ?", th) } } - val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) + val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem) scheduler.initialize(backend) scheduler - case _ => - if (MESOS_REGEX.findFirstIn(master).isEmpty) { - logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) - } + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) + new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) } else { - new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) + new MesosSchedulerBackend(scheduler, this, mesosUrl, appName) } scheduler.initialize(backend) scheduler + + case _ => + throw new SparkException("Could not parse Master URL: '" + master + "'") } } taskScheduler.start() @@ -244,7 +245,7 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { val env = SparkEnv.get - val conf = env.hadoop.newConfiguration() + val conf = SparkHadoopUtil.get.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { @@ -254,8 +255,10 @@ class SparkContext( conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { - conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) + Utils.getSystemProperties.foreach { case (key, value) => + if (key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), value) + } } val bufferSize = System.getProperty("spark.buffer.size", "65536") conf.set("io.file.buffer.size", bufferSize) @@ -288,15 +291,46 @@ class SparkContext( Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) /** Set a human readable description of the current job. */ + @deprecated("use setJobGroup", "0.8.1") def setJobDescription(value: String) { - setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) + setJobGroup("", value) + } + + /** + * Assigns a group id to all the jobs started by this thread until the group id is set to a + * different value or cleared. + * + * Often, a unit of execution in an application consists of multiple Spark actions or jobs. + * Application programmers can use this method to group all those jobs together and give a + * group description. Once set, the Spark web UI will associate such jobs with this group. + * + * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all + * running jobs in this group. For example, + * {{{ + * // In the main thread: + * sc.setJobGroup("some_job_to_cancel", "some job description") + * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() + * + * // In a separate thread: + * sc.cancelJobGroup("some_job_to_cancel") + * }}} + */ + def setJobGroup(groupId: String, description: String) { + setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description) + setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + } + + /** Clear the job group id and its description. */ + def clearJobGroup() { + setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null) + setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null) } // Post init taskScheduler.postStartHook() - val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this) - val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this) + private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this) + private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this) def initDriverMetrics() { SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) @@ -347,7 +381,7 @@ class SparkContext( minSplits: Int = defaultMinSplits ): RDD[(K, V)] = { // Add necessary security credentials to the JobConf before broadcasting it. - SparkEnv.get.hadoop.addCredentials(conf) + SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) } @@ -557,7 +591,8 @@ class SparkContext( val uri = new URI(path) val key = uri.getScheme match { case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case _ => path + case "local" => "file:" + uri.getPath + case _ => path } addedFiles(key) = System.currentTimeMillis @@ -651,12 +686,11 @@ class SparkContext( /** * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. */ def addJar(path: String) { if (path == null) { - logWarning("null specified as parameter to addJar", - new SparkException("null specified as parameter to addJar")) + logWarning("null specified as parameter to addJar") } else { var key = "" if (path.contains("\\")) { @@ -665,8 +699,9 @@ class SparkContext( } else { val uri = new URI(path) key = uri.getScheme match { + // A JAR file which exists only on the driver node case null | "file" => - if (env.hadoop.isYarnMode()) { + if (SparkHadoopUtil.get.isYarnMode()) { // In order for this to work on yarn the user must specify the --addjars option to // the client to upload the file into the distributed cache to make it show up in the // current working directory. @@ -682,6 +717,9 @@ class SparkContext( } else { env.httpFileServer.addJar(new File(uri.getPath)) } + // A JAR file which exists locally on every worker node + case "local" => + "file:" + uri.getPath case _ => path } @@ -867,13 +905,19 @@ class SparkContext( callSite, allowLocal = false, resultHandler, - null) + localProperties.get) new SimpleFutureAction(waiter, resultFunc) } /** - * Cancel all jobs that have been scheduled or are running. + * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] + * for more information. */ + def cancelJobGroup(groupId: String) { + dagScheduler.cancelJobGroup(groupId) + } + + /** Cancel all jobs that have been scheduled or are running. */ def cancelAllJobs() { dagScheduler.cancelAllJobs() } @@ -895,9 +939,8 @@ class SparkContext( * prevent accidental overriding of checkpoint files in the existing directory. */ def setCheckpointDir(dir: String, useExisting: Boolean = false) { - val env = SparkEnv.get val path = new Path(dir) - val fs = path.getFileSystem(env.hadoop.newConfiguration()) + val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) if (!useExisting) { if (fs.exists(path)) { throw new Exception("Checkpoint directory '" + path + "' already exists.") @@ -934,7 +977,10 @@ class SparkContext( * various Spark features. */ object SparkContext { - val SPARK_JOB_DESCRIPTION = "spark.job.description" + + private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" + + private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 29968c273c..ff2df8fb6a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -25,13 +25,13 @@ import akka.remote.RemoteActorRefProvider import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster} import org.apache.spark.network.ConnectionManager import org.apache.spark.serializer.{Serializer, SerializerManager} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.api.python.PythonWorkerFactory +import com.google.common.collect.MapMaker /** * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -58,18 +58,9 @@ class SparkEnv ( private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() - val hadoop = { - val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if(yarnMode) { - try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] - } catch { - case th: Throwable => throw new SparkException("Unable to load YARN support", th) - } - } else { - new SparkHadoopUtil - } - } + // A general, soft-reference map for metadata needed during HadoopRDD split computation + // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). + private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() def stop() { pythonWorkers.foreach { case(key, worker) => worker.stop() } @@ -187,10 +178,14 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - val mapOutputTracker = new MapOutputTracker() + val mapOutputTracker = if (isDriver) { + new MapOutputTrackerMaster() + } else { + new MapOutputTracker() + } mapOutputTracker.trackerActor = registerOrLookup( "MapOutputTracker", - new MapOutputTrackerActor(mapOutputTracker)) + new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])) val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 2bab9d6e3d..103a1c2051 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -17,14 +17,14 @@ package org.apache.hadoop.mapred -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path - +import java.io.IOException import java.text.SimpleDateFormat import java.text.NumberFormat -import java.io.IOException import java.util.Date +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.SerializableWritable @@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable * Saves the RDD using a JobConf, which should contain an output key class, an output value class, * a filename to write to, etc, exactly like in a Hadoop MapReduce job. */ -class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { +private[apache] +class SparkHadoopWriter(@transient jobConf: JobConf) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { private val now = new Date() private val conf = new SerializableWritable(jobConf) @@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH } getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter( - fs, conf.value, outputName, Reporter.NULL) + writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) } def write(key: AnyRef, value: AnyRef) { - if (writer!=null) { - //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")") + if (writer != null) { writer.write(key, value) } else { throw new IOException("Writer is null, open() has not been called") @@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH } } +private[apache] object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 51584d686d..cae983ed4c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics class TaskContext( - private[spark] val stageId: Int, + val stageId: Int, val partitionId: Int, val attemptId: Long, val runningLocally: Boolean = false, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 5fd1fab580..043cb183ba 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -48,6 +48,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav */ def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. + */ + def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist()) + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + */ + def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking)) + // first() has to be overriden here in order for its return type to be Double instead of Object. override def first(): Double = srdd.first() @@ -81,6 +94,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav fromRDD(srdd.coalesce(numPartitions, shuffle)) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions)) + + /** * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index a6518abf45..2142fd7327 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -65,6 +65,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. + */ + def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist()) + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + */ + def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking)) + // Transformations (return a new RDD) /** @@ -95,6 +108,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif fromRDD(rdd.coalesce(numPartitions, shuffle)) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions)) + + /** * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] = @@ -598,4 +622,15 @@ object JavaPairRDD { new JavaPairRDD[K, V](rdd) implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd + + + /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */ + def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = { + implicit val cmk: ClassManifest[K] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + implicit val cmv: ClassManifest[V] = + implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + new JavaPairRDD[K, V](rdd.rdd) + } + } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index eec58abdd6..3b359a8fd6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -41,9 +41,17 @@ JavaRDDLike[T, JavaRDD[T]] { /** * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. */ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist()) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + */ + def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking)) + // Transformations (return a new RDD) /** @@ -74,6 +82,17 @@ JavaRDDLike[T, JavaRDD[T]] { rdd.coalesce(numPartitions, shuffle) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) + + /** * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] = diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java index 4830067f7a..3e85052cd0 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -18,8 +18,6 @@ package org.apache.spark.api.java.function; -import scala.runtime.AbstractFunction1; - import java.io.Serializable; /** @@ -27,11 +25,7 @@ import java.io.Serializable; */ // DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is // overloaded for both FlatMapFunction and DoubleFlatMapFunction. -public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>> +public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>> implements Serializable { - - public abstract Iterable<Double> call(T t); - - @Override - public final Iterable<Double> apply(T t) { return call(t); } + // Intentionally left blank } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java index db34cd190a..5e9b8c48b8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java @@ -18,8 +18,6 @@ package org.apache.spark.api.java.function; -import scala.runtime.AbstractFunction1; - import java.io.Serializable; /** @@ -29,6 +27,5 @@ import java.io.Serializable; // are overloaded for both Function and DoubleFunction. public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double> implements Serializable { - - public abstract Double call(T t) throws Exception; + // Intentionally left blank } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala index 158539a846..2dfda8b09a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala @@ -21,8 +21,5 @@ package org.apache.spark.api.java.function * A function that returns zero or more output records from each input record. */ abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { - @throws(classOf[Exception]) - def call(x: T) : java.lang.Iterable[R] - def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala index 5ef6a814f5..528e1c0a7c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala @@ -21,8 +21,5 @@ package org.apache.spark.api.java.function * A function that takes two inputs and returns zero or more output records. */ abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { - @throws(classOf[Exception]) - def call(a: A, b:B) : java.lang.Iterable[C] - def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java index b9070cfd83..ce368ee01b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java @@ -19,7 +19,6 @@ package org.apache.spark.api.java.function; import scala.reflect.ClassManifest; import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction1; import java.io.Serializable; @@ -30,8 +29,6 @@ import java.io.Serializable; * when mapping RDDs of other types. */ public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable { - public abstract R call(T t) throws Exception; - public ClassManifest<R> returnType() { return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class); } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java index d4c9154869..44ad559d48 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java @@ -19,7 +19,6 @@ package org.apache.spark.api.java.function; import scala.reflect.ClassManifest; import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction2; import java.io.Serializable; @@ -29,8 +28,6 @@ import java.io.Serializable; public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R> implements Serializable { - public abstract R call(T1 t1, T2 t2) throws Exception; - public ClassManifest<R> returnType() { return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class); } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java new file mode 100644 index 0000000000..ac6178924a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import scala.reflect.ClassManifest; +import scala.reflect.ClassManifest$; +import scala.runtime.AbstractFunction2; + +import java.io.Serializable; + +/** + * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R. + */ +public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R> + implements Serializable { + + public ClassManifest<R> returnType() { + return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class); + } +} + diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java index c0e5544b7d..6d76a8f970 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -20,7 +20,6 @@ package org.apache.spark.api.java.function; import scala.Tuple2; import scala.reflect.ClassManifest; import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction1; import java.io.Serializable; @@ -34,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V> extends WrappedFunction1<T, Iterable<Tuple2<K, V>>> implements Serializable { - public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception; - public ClassManifest<K> keyType() { return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class); } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java index 40480fe8e8..ede7ceefb5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java @@ -20,7 +20,6 @@ package org.apache.spark.api.java.function; import scala.Tuple2; import scala.reflect.ClassManifest; import scala.reflect.ClassManifest$; -import scala.runtime.AbstractFunction1; import java.io.Serializable; @@ -29,12 +28,9 @@ import java.io.Serializable; */ // PairFunction does not extend Function because some UDF functions, like map, // are overloaded for both Function and PairFunction. -public abstract class PairFunction<T, K, V> - extends WrappedFunction1<T, Tuple2<K, V>> +public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>> implements Serializable { - public abstract Tuple2<K, V> call(T t) throws Exception; - public ClassManifest<K> keyType() { return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class); } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala new file mode 100644 index 0000000000..d314dbdf1d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function + +import scala.runtime.AbstractFunction3 + +/** + * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the + * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply + * isn't marked to allow that). + */ +private[spark] abstract class WrappedFunction3[T1, T2, T3, R] + extends AbstractFunction3[T1, T2, T3, R] { + @throws(classOf[Exception]) + def call(t1: T1, t2: T2, t3: T3): R + + final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3) +} + diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 1f8ad688a6..12b4d94a56 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. */ -class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) +private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala deleted file mode 100644 index b6c484bfe1..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala +++ /dev/null @@ -1,1058 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.Utils - -private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) - with Logging - with Serializable { - - def value = value_ - - def blockId = BroadcastBlockId(id) - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var hasBlocksBitVector: BitSet = null - @transient var numCopiesSent: Array[Int] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = new AtomicInteger(0) - - // Used ONLY by driver to track how many unique blocks have been sent out - @transient var sentBlocks = new AtomicInteger(0) - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - - // Used only in driver - @transient var guideMR: GuideMultipleRequests = null - - // Used only in Workers - @transient var ttGuide: TalkToGuide = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks.set(variableInfo.totalBlocks) - - // Guide has all the blocks - hasBlocksBitVector = new BitSet(totalBlocks) - hasBlocksBitVector.set(0, totalBlocks) - - // Guide still hasn't sent any block - numCopiesSent = new Array[Int](totalBlocks) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val driverSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - hasBlocksBitVector.synchronized { - driverSource.hasBlocksBitVector = hasBlocksBitVector - } - - // In the beginning, this is the only known source to Guide - listOfSources += driverSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - // Start local ServeMultipleRequests thread first - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - // Initialize variables in the worker node. Driver sends everything as 0/null - private def initializeWorkerVariables() { - arrayOfBlocks = null - hasBlocksBitVector = null - numCopiesSent = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = new AtomicInteger(0) - - listenPortLock = new Object - totalBlocksLock = new Object - - serveMR = null - ttGuide = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - listOfSources = ListBuffer[SourceInfo]() - - stopBroadcast = false - } - - private def getLocalSourceInfo: SourceInfo = { - // Wait till hostName and listenPort are OK - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Wait till totalBlocks and totalBytes are OK - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes) - - localSourceInfo.hasBlocks = hasBlocks.get - - hasBlocksBitVector.synchronized { - localSourceInfo.hasBlocksBitVector = hasBlocksBitVector - } - - return localSourceInfo - } - - // Add new SourceInfo to the listOfSources. Update if it exists already. - // Optimizing just by OR-ing the BitVectors was BAD for performance - private def addToListOfSources(newSourceInfo: SourceInfo) { - listOfSources.synchronized { - if (listOfSources.contains(newSourceInfo)) { - listOfSources = listOfSources - newSourceInfo - } - listOfSources += newSourceInfo - } - } - - private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) { - newSourceInfos.foreach { newSourceInfo => - addToListOfSources(newSourceInfo) - } - } - - class TalkToGuide(gInfo: SourceInfo) - extends Thread with Logging { - override def run() { - - // Keep exchaning information until all blocks have been received - while (hasBlocks.get < totalBlocks) { - talkOnce - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - } - - // Talk one more time to let the Guide know of reception completion - talkOnce - } - - // Connect to Guide and send this worker's information - private def talkOnce { - var clientSocketToGuide: Socket = null - var oosGuide: ObjectOutputStream = null - var oisGuide: ObjectInputStream = null - - clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) - oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) - oosGuide.flush() - oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) - - // Send local information - oosGuide.writeObject(getLocalSourceInfo) - oosGuide.flush() - - // Receive source information from Guide - var suitableSources = - oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Driver " + suitableSources) - - addToListOfSources(suitableSources) - - oisGuide.close() - oosGuide.close() - clientSocketToGuide.close() - } - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Setup initial states of variables - totalBlocks = gInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - hasBlocksBitVector = new BitSet(totalBlocks) - numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = gInfo.totalBytes - - // Start ttGuide to periodically talk to the Guide - var ttGuide = new TalkToGuide(gInfo) - ttGuide.setDaemon(true) - ttGuide.start() - logInfo("TalkToGuide started...") - - // Start pController to run TalkToPeer threads - var pcController = new PeerChatterController - pcController.setDaemon(true) - pcController.start() - logInfo("PeerChatterController started...") - - // FIXME: Must fix this. This might never break if broadcast fails. - // We should be able to break and send false. Also need to kill threads - while (hasBlocks.get < totalBlocks) { - Thread.sleep(MultiTracker.MaxKnockInterval) - } - - return true - } - - class PeerChatterController - extends Thread with Logging { - private var peersNowTalking = ListBuffer[SourceInfo]() - // TODO: There is a possible bug with blocksInRequestBitVector when a - // certain bit is NOT unset upon failure resulting in an infinite loop. - private var blocksInRequestBitVector = new BitSet(totalBlocks) - - override def run() { - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - while (hasBlocks.get < totalBlocks) { - var numThreadsToCreate = 0 - listOfSources.synchronized { - numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - - threadPool.getActiveCount - } - - while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { - var peerToTalkTo = pickPeerToTalkToRandom - - if (peerToTalkTo != null) - logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) - else - logDebug("No peer chosen...") - - if (peerToTalkTo != null) { - threadPool.execute(new TalkToPeer(peerToTalkTo)) - - // Add to peersNowTalking. Remove in the thread. We have to do this - // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before starting some more threads - Thread.sleep(MultiTracker.MinKnockInterval) - } - // Shutdown the thread pool - threadPool.shutdown() - } - - // Right now picking the one that has the most blocks this peer wants - // Also picking peer randomly if no one has anything interesting - private def pickPeerToTalkToRandom: SourceInfo = { - var curPeer: SourceInfo = null - var curMax = 0 - - logDebug("Picking peers to talk to...") - - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Select the peer that has the most blocks that this receiver does not - peersNotInUse.foreach { eachSource => - var tempHasBlocksBitVector: BitSet = null - hasBlocksBitVector.synchronized { - tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) - tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) - - if (tempHasBlocksBitVector.cardinality > curMax) { - curPeer = eachSource - curMax = tempHasBlocksBitVector.cardinality - } - } - - // Always picking randomly - if (curPeer == null && peersNotInUse.size > 0) { - // Pick uniformly the i'th required peer - var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) - - var peerIter = peersNotInUse.iterator - curPeer = peerIter.next - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - } - - return curPeer - } - - // Picking peer with the weight of rare blocks it has - private def pickPeerToTalkToRarestFirst: SourceInfo = { - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Count the number of copies of each block in the neighborhood - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // A block is considered rare if there are at most 2 copies of that block - // This CONSTANT could be a function of the neighborhood size - var rareBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { - rareBlocksIndices += i - } - } - - // Find peers with rare blocks - var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() - var totalRareBlocks = 0 - - peersNotInUse.foreach { eachPeer => - var hasRareBlocks = 0 - rareBlocksIndices.foreach { rareBlock => - if (eachPeer.hasBlocksBitVector.get(rareBlock)) { - hasRareBlocks += 1 - } - } - - if (hasRareBlocks > 0) { - peersWithRareBlocks += ((eachPeer, hasRareBlocks)) - } - totalRareBlocks += hasRareBlocks - } - - // Select a peer from peersWithRareBlocks based on weight calculated from - // unique rare blocks - var selectedPeerToTalkTo: SourceInfo = null - - if (peersWithRareBlocks.size > 0) { - // Sort the peers based on how many rare blocks they have - peersWithRareBlocks.sortBy(_._2) - - var randomNumber = MultiTracker.ranGen.nextDouble - var tempSum = 0.0 - - var i = 0 - do { - tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) - if (tempSum >= randomNumber) { - selectedPeerToTalkTo = peersWithRareBlocks(i)._1 - } - i += 1 - } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) - } - - if (selectedPeerToTalkTo == null) { - selectedPeerToTalkTo = pickPeerToTalkToRandom - } - - return selectedPeerToTalkTo - } - - class TalkToPeer(peerToTalkTo: SourceInfo) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - override def run() { - // TODO: There is a possible bug here regarding blocksInRequestBitVector - var blockToAskFor = -1 - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run() { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) - - logInfo("TalkToPeer started... => " + peerToTalkTo) - - try { - // Connect to the source - peerSocketToSource = - new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(peerSocketToSource.getInputStream) - - // Receive latest SourceInfo from peerToTalkTo - var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] - // Update listOfSources - addToListOfSources(newPeerToTalkTo) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - - var keepReceiving = true - - while (hasBlocks.get < totalBlocks && keepReceiving) { - blockToAskFor = - pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) - - // No block to request - if (blockToAskFor < 0) { - // Nothing to receive from newPeerToTalkTo - keepReceiving = false - } else { - // Let other threads know that blockToAskFor is being requested - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor) - } - - // Start with sending the blockID - oosSource.writeObject(blockToAskFor) - oosSource.flush() - - // CHANGED: Driver might send some other block than the one - // requested to ensure fast spreading of all blocks. - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") - - if (!hasBlocksBitVector.get(bcBlock.blockID)) { - arrayOfBlocks(bcBlock.blockID) = bcBlock - - // Update the hasBlocksBitVector first - hasBlocksBitVector.synchronized { - hasBlocksBitVector.set(bcBlock.blockID) - hasBlocks.getAndIncrement - } - - // Some block(may NOT be blockToAskFor) has arrived. - // In any case, blockToAskFor is not in request any more - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - - // Reset blockToAskFor to -1. Else it will be considered missing - blockToAskFor = -1 - } - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logError("TalktoPeer had a " + e) - // FIXME: Remove 'newPeerToTalkTo' from listOfSources - // We probably should have the following in some form, but not - // really here. This exception can happen if the sender just breaks connection - // listOfSources.synchronized { - // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) - // listOfSources = listOfSources - peerToTalkTo - // } - } - } finally { - // blockToAskFor != -1 => there was an exception - if (blockToAskFor != -1) { - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - } - - cleanUpConnections() - } - } - - // Right now it picks a block uniformly that this peer does not have - private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Pick uniformly the i'th required block - var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) - var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) - - while (i > 0) { - pickedBlockIndex = - needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) - i -= 1 - } - - return pickedBlockIndex - } - } - - // Pick the block that seems to be the rarest across sources - private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Count the number of copies for each block across all sources - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // Find the minimum - var minVal = Integer.MAX_VALUE - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) { - minVal = numCopiesPerBlock(i) - } - } - - // Find the blocks with the least copies that this peer does not have - var minBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { - minBlocksIndices += i - } - } - - // Now select a random index from minBlocksIndices - if (minBlocksIndices.size == 0) { - return -1 - } else { - // Pick uniformly the i'th index - var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) - return minBlocksIndices(i) - } - } - } - - private def cleanUpConnections() { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - // Delete from peersNowTalking - peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } - } - } - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - - // Shutdown the thread pool - threadPool.shutdown() - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - listOfSources.foreach { sourceInfo => - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Throw away whatever comes in - gisSource.readObject.asInstanceOf[SourceInfo] - - // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast - gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sourceInfo: SourceInfo = null - private var selectedSources: ListBuffer[SourceInfo] = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its information - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSources = selectSuitableSources(sourceInfo) - logDebug("Sending selectedSources:" + selectedSources) - oos.writeObject(selectedSources) - oos.flush() - - // Add this source to the listOfSources - addToListOfSources(sourceInfo) - } catch { - case e: Exception => { - // Assuming exception caused by receiver failure: remove - if (listOfSources != null) { - listOfSources.synchronized { listOfSources -= sourceInfo } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Randomly select some sources to send back - private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { - var selectedSources = ListBuffer[SourceInfo]() - - // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' - // then add skipSourceInfo to setOfCompletedSources. Return blank. - if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } - return selectedSources - } - - listOfSources.synchronized { - if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { - selectedSources = listOfSources.clone - } else { - var picksLeft = MultiTracker.MaxPeersInGuideResponse - var alreadyPicked = new BitSet(listOfSources.size) - - while (picksLeft > 0) { - var i = -1 - - do { - i = MultiTracker.ranGen.nextInt(listOfSources.size) - } while (alreadyPicked.get(i)) - - var peerIter = listOfSources.iterator - var curPeer = peerIter.next - - // Set the BitSet before i is decremented - alreadyPicked.set(i) - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - - selectedSources += curPeer - - picksLeft = picksLeft - 1 - } - } - } - - // Remove the receiving source (if present) - selectedSources = selectedSources - skipSourceInfo - - return selectedSources - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - // Server at most MultiTracker.MaxChatSlots peers - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ServeSingleRequest is running") - - override def run() { - try { - // Send latest local SourceInfo to the receiver - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(getLocalSourceInfo) - oos.flush() - - // Receive latest SourceInfo from the receiver - var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - addToListOfSources(rxSourceInfo) - } - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = MultiTracker.MaxChatBlocks - - while (!stopBroadcast && keepSending && numBlocksToSend > 0) { - // Receive which block to send - var blockToSend = ois.readObject.asInstanceOf[Int] - - // If it is driver AND at least one copy of each block has not been - // sent out already, MODIFY blockToSend - if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { - blockToSend = sentBlocks.getAndIncrement - } - - // Send the block - sendBlock(blockToSend) - rxSourceInfo.hasBlocksBitVector.set(blockToSend) - - numBlocksToSend -= 1 - - // Receive latest SourceInfo from the receiver - rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) - addToListOfSources(rxSourceInfo) - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= MultiTracker.MaxChatTime && - threadPool.getQueue.size > 0) { - keepSending = false - } - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendBlock(blockToSend: Int) { - try { - oos.writeObject(arrayOfBlocks(blockToSend)) - oos.flush() - } catch { - case e: Exception => logError("sendBlock had a " + e) - } - logDebug("Sent block: " + blockToSend + " to " + clientSocket) - } - } - } -} - -private[spark] class BitTorrentBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new BitTorrentBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala deleted file mode 100644 index 21ec94659e..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala +++ /dev/null @@ -1,410 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.Random - -import scala.collection.mutable.Map - -import org.apache.spark._ -import org.apache.spark.util.Utils - -private object MultiTracker -extends Logging { - - // Tracker Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - - // Map to keep track of guides of ongoing broadcasts - var valueToGuideMap = Map[Long, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var _isDriver = false - - private var stopBroadcast = false - - private var trackMV: TrackMultipleValues = null - - def initialize(__isDriver: Boolean) { - synchronized { - if (!initialized) { - _isDriver = __isDriver - - if (isDriver) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - - // Set DriverHostAddress to the driver's IP address for the slaves to read - System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) - } - - initialized = true - } - } - } - - def stop() { - stopBroadcast = true - } - - // Load common parameters - private var DriverHostAddress_ = System.getProperty( - "spark.MultiTracker.DriverHostAddress", "") - private var DriverTrackerPort_ = System.getProperty( - "spark.broadcast.driverTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty( - "spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxChatSlots_ = System.getProperty( - "spark.broadcast.maxChatSlots", "4").toInt - private var MaxChatTime_ = System.getProperty( - "spark.broadcast.maxChatTime", "500").toInt - private var MaxChatBlocks_ = System.getProperty( - "spark.broadcast.maxChatBlocks", "1024").toInt - - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble - - def isDriver = _isDriver - - // Common config params - def DriverHostAddress = DriverHostAddress_ - def DriverTrackerPort = DriverTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxChatSlots = MaxChatSlots_ - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(DriverTrackerPort) - logInfo("TrackMultipleValues started at " + serverSocket) - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - if (stopBroadcast) { - logInfo("Stopping TrackMultipleValues...") - } - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == REGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (id -> gInfo) - } - - logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) - } - - logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == FIND_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - var gInfo = - if (valueToGuideMap.contains(id)) valueToGuideMap(id) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logError("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - def getGuideInfo(variableLong: Long): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) - - var retriesLeft = MultiTracker.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send Long and receive GuideInfo - oosTracker.writeObject(variableLong) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => logError("getGuideInfo had a " + e) - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - - def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - // Helper method to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / BlockSize) - if (byteArray.length % BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, BlockSize)) { - val thisBlockSize = math.min(BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper method to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) -extends Serializable - -private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) -extends Serializable { - @transient var hasBlocks = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala deleted file mode 100644 index baa1fd6da4..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.util.BitSet - -import org.apache.spark._ - -/** - * Used to keep and pass around information of peers involved in a broadcast - */ -private[spark] case class SourceInfo (hostAddress: String, - listenPort: Int, - totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam) -extends Comparable[SourceInfo] with Logging { - - var currentLeechers = 0 - var receptionFailed = false - - var hasBlocks = 0 - var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) - - // Ascending sort based on leecher count - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) -} - -/** - * Helper Object of SourceInfo for its constants - */ -private[spark] object SourceInfo { - // Broadcast has not started yet! Should never happen. - val TxNotStartedRetry = -1 - // Broadcast has already finished. Try default mechanism. - val TxOverGoToDefault = -3 - // Other constants - val StopBroadcast = -2 - val UnusedParam = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala new file mode 100644 index 0000000000..073a0a5029 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.io._ + +import scala.math +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.util.Utils + + +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +extends Broadcast[T](id) with Logging with Serializable { + + def value = value_ + + def broadcastId = BroadcastBlockId(id) + + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + } + + @transient var arrayOfBlocks: Array[TorrentBlock] = null + @transient var totalBlocks = -1 + @transient var totalBytes = -1 + @transient var hasBlocks = 0 + + if (!isLocal) { + sendBroadcast() + } + + def sendBroadcast() { + var tInfo = TorrentBroadcast.blockifyObject(value_) + + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + hasBlocks = tInfo.totalBlocks + + // Store meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + } + + // Store individual pieces + for (i <- 0 until totalBlocks) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + } + } + } + + // Called by JVM when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(broadcastId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + val start = System.nanoTime + logInfo("Started reading broadcast variable " + id) + + // Initialize @transient variables that will receive garbage values from the master. + resetWorkerVariables() + + if (receiveBroadcast(id)) { + value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + + // Store the merged copy in cache so that the next worker doesn't need to rebuild it. + // This creates a tradeoff between memory usage and latency. + // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + + // Remove arrayOfBlocks from memory once value_ is on local cache + resetWorkerVariables() + } else { + logError("Reading broadcast variable " + id + " failed") + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } + } + } + + private def resetWorkerVariables() { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + } + + def receiveBroadcast(variableID: Long): Boolean = { + // Receive meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + var attemptId = 10 + while (attemptId > 0 && totalBlocks == -1) { + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(metaId) match { + case Some(x) => + val tInfo = x.asInstanceOf[TorrentInfo] + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + hasBlocks = 0 + + case None => + Thread.sleep(500) + } + } + attemptId -= 1 + } + if (totalBlocks == -1) { + return false + } + + // Receive actual blocks + val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) + for (pid <- recvOrder) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + hasBlocks += 1 + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + } + + (hasBlocks == totalBlocks) + } + +} + +private object TorrentBroadcast +extends Logging { + + private var initialized = false + + def initialize(_isDriver: Boolean) { + synchronized { + if (!initialized) { + initialized = true + } + } + } + + def stop() { + initialized = false + } + + val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024 + + def blockifyObject[T](obj: T): TorrentInfo = { + val byteArray = Utils.serialize[T](obj) + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / BLOCK_SIZE) + if (byteArray.length % BLOCK_SIZE != 0) + blockNum += 1 + + var retVal = new Array[TorrentBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { + val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new TorrentBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var tInfo = TorrentInfo(retVal, blockNum, byteArray.length) + tInfo.hasBlocks = blockNum + + return tInfo + } + + def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) + } + Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) + } + +} + +private[spark] case class TorrentBlock( + blockID: Int, + byteArray: Array[Byte]) + extends Serializable + +private[spark] case class TorrentInfo( + @transient arrayOfBlocks : Array[TorrentBlock], + totalBlocks: Int, + totalBytes: Int) + extends Serializable { + + @transient var hasBlocks = 0 +} + +private[spark] class TorrentBroadcastFactory + extends BroadcastFactory { + + def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala deleted file mode 100644 index e6674d49a7..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala +++ /dev/null @@ -1,601 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ - -import scala.collection.mutable.{ListBuffer, Set} - -import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.Utils - -private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId = BroadcastBlockId(id) - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - listOfSources += masterSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because Driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def initializeWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - var clientSocketToDriver: Socket = null - var oosDriver: ObjectOutputStream = null - var oisDriver: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = MultiTracker.MaxRetryCount - do { - // Connect to Driver and send this worker's Information - clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) - oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) - oosDriver.flush() - oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - - logDebug("Connected to Driver's guiding object") - - // Send local source information - oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) - oosDriver.flush() - - // Receive source information from Driver - var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = sourceInfo.totalBytes - - logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time = (System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Driver will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Driver - oosDriver.writeObject(sourceInfo) - - if (oisDriver != null) { - oisDriver.close() - } - if (oosDriver != null) { - oosDriver.close() - } - if (clientSocketToDriver != null) { - clientSocketToDriver.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return (hasBlocks == totalBlocks) - } - - /** - * Tries to receive broadcast from the source and returns Boolean status. - * This might be called multiple times to retry a defined number of times. - */ - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) - - logDebug("Inside receiveSingleTransmission") - logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } - } - } catch { - case e: Exception => logError("receiveSingleTransmission had a " + e) - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close() the socket here; else, the thread will close() it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - var listIter = listOfSources.iterator - while (listIter.hasNext) { - var sourceInfo = listIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal - gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid (SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new (if it can finish) source to the list of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes) - logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) - listOfSources += thisWorkerInfo - } - - // Wait till the whole transfer is done. Then receive and update source - // statistics in listOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(listOfSources.contains(selectedSourceInfo)) - - // Remove first - // (Currently removing a source based on just one failure notification!) - listOfSources = listOfSources - selectedSourceInfo - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - } - } catch { - case e: Exception => { - // Remove failed worker from listOfSources and update leecherCount of - // corresponding source worker - listOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - listOfSources = listOfSources - selectedSourceInfo - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - - // Remove thisWorkerInfo - if (listOfSources != null) { - listOfSources = listOfSources - thisWorkerInfo - } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Assuming the caller to have a synchronized block on listOfSources - // Select one with the most leechers. This will level-wise fill the tree - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - var maxLeechers = -1 - var selectedSource: SourceInfo = null - - listOfSources.foreach { source => - if ((source.hostAddress != skipSourceInfo.hostAddress || - source.listenPort != skipSourceInfo.listenPort) && - source.currentLeechers < MultiTracker.MaxDegree && - source.currentLeechers > maxLeechers) { - selectedSource = source - maxLeechers = source.currentLeechers - } - } - - // Update leecher count - selectedSource.currentLeechers += 1 - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - - var threadPool = Utils.newDaemonCachedThreadPool() - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { } - } - - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - - override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - // If not a valid range, stop broadcast - if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - sendObject - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Driver - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { hasBlocksLock.wait() } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => logError("sendObject had a " + e) - } - logDebug("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -private[spark] class TreeBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TreeBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 993ba6bd3d..6bc846aa92 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,28 +17,47 @@ package org.apache.spark.deploy -import com.google.common.collect.MapMaker - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf +import org.apache.spark.SparkException /** - * Contains util methods to interact with Hadoop from spark. + * Contains util methods to interact with Hadoop from Spark. */ +private[spark] class SparkHadoopUtil { - // A general, soft-reference map for metadata needed during HadoopRDD split computation - // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). - private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() - // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop - // subsystems + /** + * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * subsystems. + */ def newConfiguration(): Configuration = new Configuration() - // Add any user credentials to the job conf which are necessary for running on a secure Hadoop - // cluster + /** + * Add any user credentials to the job conf which are necessary for running on a secure Hadoop + * cluster. + */ def addCredentials(conf: JobConf) {} def isYarnMode(): Boolean = { false } - +} + +object SparkHadoopUtil { + private val hadoop = { + val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if (yarnMode) { + try { + Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] + } catch { + case th: Throwable => throw new SparkException("Unable to load YARN support", th) + } + } else { + new SparkHadoopUtil + } + } + + def get: SparkHadoopUtil = { + hadoop + } } diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index db0bea0472..80ff4c59cb 100644 --- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,11 +24,11 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClie import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} -private[spark] class StandaloneExecutorBackend( +private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, @@ -80,6 +80,11 @@ private[spark] class StandaloneExecutorBackend( case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => logError("Driver terminated or disconnected! Shutting down.") System.exit(1) + + case StopExecutor => + logInfo("Driver commanded a shutdown") + context.stop(self) + context.system.shutdown() } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { @@ -87,7 +92,7 @@ private[spark] class StandaloneExecutorBackend( } } -private[spark] object StandaloneExecutorBackend { +private[spark] object CoarseGrainedExecutorBackend { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { // Debug code Utils.checkHost(hostname) @@ -99,7 +104,7 @@ private[spark] object StandaloneExecutorBackend { val sparkHostPort = hostname + ":" + boundPort System.setProperty("spark.hostPort", sparkHostPort) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), + Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), name = "Executor") actorSystem.awaitTermination() } @@ -107,7 +112,9 @@ private[spark] object StandaloneExecutorBackend { def main(args: Array[String]) { if (args.length < 4) { //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors - System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]") + System.err.println( + "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " + + "[<appid>]") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 20323ea038..b773346df3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -74,30 +74,33 @@ private[spark] class Executor( private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) Thread.currentThread.setContextClassLoader(replClassLoader) - // Make any thread terminations due to uncaught exceptions kill the entire - // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler( - new Thread.UncaughtExceptionHandler { - override def uncaughtException(thread: Thread, exception: Throwable) { - try { - logError("Uncaught exception in thread " + thread, exception) - - // We may have been called from a shutdown hook. If so, we must not call System.exit(). - // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(ExecutorExitCode.OOM) - } else { - System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + if (!isLocal) { + // Setup an uncaught exception handler for non-local mode. + // Make any thread terminations due to uncaught exceptions kill the entire + // executor process to avoid surprising stalls. + Thread.setDefaultUncaughtExceptionHandler( + new Thread.UncaughtExceptionHandler { + override def uncaughtException(thread: Thread, exception: Throwable) { + try { + logError("Uncaught exception in thread " + thread, exception) + + // We may have been called from a shutdown hook. If so, we must not call System.exit(). + // (If we do, we will deadlock.) + if (!Utils.inShutdown()) { + if (exception.isInstanceOf[OutOfMemoryError]) { + System.exit(ExecutorExitCode.OOM) + } else { + System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + } } + } catch { + case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) + case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) } - } catch { - case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) - case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) } } - } - ) + ) + } val executorSource = new ExecutorSource(this, executorId) @@ -121,8 +124,7 @@ private[spark] class Executor( } // Start worker thread pool - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable], Utils.daemonThreadFactory) + val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") // Maintains the list of running tasks. private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index f311141148..0b4892f98f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for a shuffle */ var shuffleBytesWritten: Long = _ + + /** + * Time spent blocking on writes to disk or buffer cache, in nanoseconds. + */ + var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index e15a839c4e..9c2fee4023 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] - implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + implicit val futureExecContext = ExecutionContext.fromExecutor( + Utils.newDaemonCachedThreadPool("Connection manager future execution context")) private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 1586dff254..546d921067 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.spark.Logging import org.apache.spark.util.Utils -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.{BlockId, FileSegment} private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { @@ -54,8 +54,7 @@ private[spark] object ShuffleSender { val localDirs = args.drop(2).map(new File(_)) val pResovler = new PathResolver { - override def getAbsolutePath(blockIdString: String): String = { - val blockId = BlockId(blockIdString) + override def getBlockLocation(blockId: BlockId): FileSegment = { if (!blockId.isShuffle) { throw new Exception("Block " + blockId + " is not a shuffle block") } @@ -65,7 +64,7 @@ private[spark] object ShuffleSender { val subDirId = (hash / localDirs.length) % subDirsPerLocalDir val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) val file = new File(subDir, blockId.name) - return file.getAbsolutePath + return new FileSegment(file, 0, file.length()) } } val sender = new ShuffleSender(port, pResovler) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index f132e2b735..70a5a8caff 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -15,6 +15,8 @@ * limitations under the License. */ +package org.apache + /** * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index ccaaecb85b..d3033ea4a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.{NullWritable, BytesWritable} @@ -83,7 +84,7 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get val outputDir = new Path(path) - val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) + val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration()) val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) @@ -122,7 +123,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { val env = SparkEnv.get - val fs = path.getFileSystem(env.hadoop.newConfiguration()) + val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = env.serializer.newInstance() @@ -145,7 +146,7 @@ private[spark] object CheckpointRDD extends Logging { val sc = new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(env.hadoop.newConfiguration()) + val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index fad042c7ae..32901a508f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.NextIterator import org.apache.hadoop.conf.{Configuration, Configurable} @@ -198,10 +199,10 @@ private[spark] object HadoopRDD { * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. */ - def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key) + def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key) + def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key) def putCachedMetadata(key: String, value: Any) = - SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value) + SparkEnv.get.hadoopJobMetadata.put(key, value) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0355618e43..6e88be6f6a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -266,6 +266,19 @@ abstract class RDD[T: ClassManifest]( def distinct(): RDD[T] = distinct(partitions.size) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): RDD[T] = { + coalesce(numPartitions, true) + } + + /** * Return a new RDD that is reduced into `numPartitions` partitions. * * This results in a narrow dependency, e.g. if you go from 1000 partitions diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7fb614402b..4cef0825dd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -52,23 +52,29 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class DAGScheduler( taskSched: TaskScheduler, - mapOutputTracker: MapOutputTracker, + mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv) - extends TaskSchedulerListener with Logging { + extends Logging { def this(taskSched: TaskScheduler) { - this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) + this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + SparkEnv.get.blockManager.master, SparkEnv.get) } - taskSched.setListener(this) + taskSched.setDAGScheduler(this) // Called by TaskScheduler to report task's starting. - override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventQueue.put(BeginEvent(task, taskInfo)) } + // Called to report that a task has completed and results are being fetched remotely. + def taskGettingResult(task: Task[_], taskInfo: TaskInfo) { + eventQueue.put(GettingResultEvent(task, taskInfo)) + } + // Called by TaskScheduler to report task completions or failures. - override def taskEnded( + def taskEnded( task: Task[_], reason: TaskEndReason, result: Any, @@ -79,18 +85,18 @@ class DAGScheduler( } // Called by TaskScheduler when an executor fails. - override def executorLost(execId: String) { + def executorLost(execId: String) { eventQueue.put(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - override def executorGained(execId: String, host: String) { + def executorGained(execId: String, host: String) { eventQueue.put(ExecutorGained(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - override def taskSetFailed(taskSet: TaskSet, reason: String) { + def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } @@ -182,7 +188,7 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId) + val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -195,6 +201,7 @@ class DAGScheduler( */ private def newStage( rdd: RDD[_], + numTasks: Int, shuffleDep: Option[ShuffleDependency[_,_]], jobId: Int, callSite: Option[String] = None) @@ -207,9 +214,10 @@ class DAGScheduler( mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) } val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) + val stage = + new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage - stageToInfos(stage) = StageInfo(stage) + stageToInfos(stage) = new StageInfo(stage) stage } @@ -277,11 +285,6 @@ class DAGScheduler( resultHandler: (Int, U) => Unit, properties: Properties = null): JobWaiter[U] = { - val jobId = nextJobId.getAndIncrement() - if (partitions.size == 0) { - return new JobWaiter[U](this, jobId, 0, resultHandler) - } - // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions).foreach { p => @@ -290,6 +293,11 @@ class DAGScheduler( "Total number of partitions: " + maxPartitions) } + val jobId = nextJobId.getAndIncrement() + if (partitions.size == 0) { + return new JobWaiter[U](this, jobId, 0, resultHandler) + } + assert(partitions.size > 0) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) @@ -342,6 +350,11 @@ class DAGScheduler( eventQueue.put(JobCancelled(jobId)) } + def cancelJobGroup(groupId: String) { + logInfo("Asked to cancel job group " + groupId) + eventQueue.put(JobGroupCancelled(groupId)) + } + /** * Cancel all jobs that are running or waiting in the queue. */ @@ -356,7 +369,7 @@ class DAGScheduler( private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - val finalStage = newStage(rdd, None, jobId, Some(callSite)) + val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + @@ -381,6 +394,17 @@ class DAGScheduler( taskSched.cancelTasks(stage.id) } + case JobGroupCancelled(groupId) => + // Cancel all jobs belonging to this job group. + // First finds all active jobs with this group id, and then kill stages for them. + val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + .map(_.jobId) + if (!jobIds.isEmpty) { + running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage => + taskSched.cancelTasks(stage.id) + } + } + case AllJobsCancelled => // Cancel all running jobs. running.foreach { stage => @@ -396,6 +420,9 @@ class DAGScheduler( case begin: BeginEvent => listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) + case gettingResult: GettingResultEvent => + listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo)) + case completion: CompletionEvent => listenerBus.post(SparkListenerTaskEnd( completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) @@ -568,7 +595,7 @@ class DAGScheduler( // must be run listener before possible NotSerializableException // should be "StageSubmitted" first and then "JobEnded" - listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties)) + listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties)) if (tasks.size > 0) { // Preemptively serialize a task to make sure it can be serialized. We are catching this @@ -589,9 +616,7 @@ class DAGScheduler( logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - if (!stage.submissionTime.isDefined) { - stage.submissionTime = Some(System.currentTimeMillis()) - } + stageToInfos(stage).submissionTime = Some(System.currentTimeMillis()) } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -613,12 +638,12 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage) = { - val serviceTime = stage.submissionTime match { + val serviceTime = stageToInfos(stage).submissionTime match { case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) - case _ => "Unkown" + case _ => "Unknown" } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.completionTime = Some(System.currentTimeMillis) + stageToInfos(stage).completionTime = Some(System.currentTimeMillis()) listenerBus.post(StageCompleted(stageToInfos(stage))) running -= stage } @@ -788,7 +813,7 @@ class DAGScheduler( */ private def abortStage(failedStage: Stage, reason: String) { val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq - failedStage.completionTime = Some(System.currentTimeMillis()) + stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) val error = new SparkException("Job aborted: " + reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index ee89bfb38d..708d221d60 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -46,11 +46,16 @@ private[scheduler] case class JobSubmitted( private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent +private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent + private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] +case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 370ccd183c..1791ee660d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil import scala.collection.immutable.Set import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.security.UserGroupInformation @@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // This method does not expect failures, since validate has already passed ... private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get val conf = new JobConf(configuration) - env.hadoop.addCredentials(conf) + SparkHadoopUtil.get.addCredentials(conf) FileInputFormat.setInputPaths(conf, path) val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = @@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // This method does not expect failures, since validate has already passed ... private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get val jobConf = new JobConf(configuration) - env.hadoop.addCredentials(jobConf) + SparkHadoopUtil.get.addCredentials(jobConf) FileInputFormat.setInputPaths(jobConf, path) val instance: org.apache.hadoop.mapred.InputFormat[_, _] = diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3628b1b078..12b0d74fb5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -24,56 +24,54 @@ import java.text.SimpleDateFormat import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
+import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
+ *
+ * @param logDirName The base directory for the log files.
+ */
class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
+
+ private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
+
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
+
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
// Create a folder for log files, the folder's name is the creation time of the jobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
+ if (!dir.exists() && !dir.mkdirs()) {
+ logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job, the file name is the jobID
protected def createLogWriter(jobID: Int) {
- try{
+ try {
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
}
// Close log file, and clean the stage relationship in stageIDToJobID
@@ -118,10 +116,9 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
}
rddList
}
@@ -161,29 +158,27 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
+ } else {
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
}
// Record task metrics into job log files
@@ -193,39 +188,32 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging { " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
+ stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
- "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
+ stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+ stageCompleted.stage.stageId))
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 802791797a..24d97da6eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -164,17 +164,19 @@ private[spark] class ShuffleMapTask( // Commit the writes. Get the size of each bucket block (total block size). var totalBytes = 0L + var totalTime = 0L val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => writer.commit() - writer.close() - val size = writer.size() + val size = writer.fileSegment().length totalBytes += size + totalTime += writer.timeWriting() MapOutputTracker.compressSize(size) } // Update shuffle metrics. val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) new MapStatus(blockManager.blockManagerId, compressedSizes) @@ -188,6 +190,7 @@ private[spark] class ShuffleMapTask( } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && buckets != null) { + buckets.writers.foreach(_.close()) shuffle.releaseWriters(buckets) } // Execute the callbacks on task completion. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 466baf9913..a35081f7b1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -24,13 +24,16 @@ import org.apache.spark.executor.TaskMetrics sealed trait SparkListenerEvents -case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties) +case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties) extends SparkListenerEvents -case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents +case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents +case class SparkListenerTaskGettingResult( + task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents + case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents @@ -57,6 +60,12 @@ trait SparkListener { def onTaskStart(taskStart: SparkListenerTaskStart) { } /** + * Called when a task begins remotely fetching its result (will not be called for tasks that do + * not need to fetch the result remotely). + */ + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + + /** * Called when a task ends */ def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } @@ -80,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: StageCompleted) { import org.apache.spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) + this.logInfo("Finished stage: " + stageCompleted.stage) showMillisDistribution("task runtime:", (info, _) => Some(info.duration)) //shuffle write @@ -93,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging { //runtime breakdown - val runtimePcts = stageCompleted.stageInfo.taskInfos.map{ + val runtimePcts = stageCompleted.stage.taskInfos.map{ case (info, metrics) => RuntimePercentage(info.duration, metrics) } showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%") @@ -111,7 +120,7 @@ object StatsReportListener extends Logging { val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = { - Distribution(stage.stageInfo.taskInfos.flatMap{ + Distribution(stage.stage.taskInfos.flatMap { case ((info,metric)) => getMetric(info, metric)}) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 4d3e4a17ba..d5824e7954 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging { sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index aa293dc6b3..7cb3fe46e5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId private[spark] class Stage( val id: Int, val rdd: RDD[_], + val numTasks: Int, val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, @@ -49,11 +50,6 @@ private[spark] class Stage( val numPartitions = rdd.partitions.size val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 - - /** When first task was submitted to scheduler. */ - var submissionTime: Option[Long] = None - var completionTime: Option[Long] = None - private var nextAttemptId = 0 def isAvailable: Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index b6f11969e5..93599dfdc8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -21,9 +21,16 @@ import scala.collection._ import org.apache.spark.executor.TaskMetrics -case class StageInfo( - val stage: Stage, +class StageInfo( + stage: Stage, val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]() ) { - override def toString = stage.rdd.toString + val stageId = stage.id + /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ + var submissionTime: Option[Long] = None + var completionTime: Option[Long] = None + val rddName = stage.rdd.name + val name = stage.name + val numPartitions = stage.numPartitions + val numTasks = stage.numTasks } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1fe0d0e4e2..69b42e86ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream */ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { - def run(attemptId: Long): T = { + final def run(attemptId: Long): T = { context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) if (_killed) { kill() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 7c2a422aff..4bae26f3a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -31,9 +31,25 @@ class TaskInfo( val host: String, val taskLocality: TaskLocality.TaskLocality) { + /** + * The time when the task started remotely getting the result. Will not be set if the + * task result was sent immediately when the task finished (as opposed to sending an + * IndirectTaskResult and later fetching the result from the block manager). + */ + var gettingResultTime: Long = 0 + + /** + * The time when the task has completed successfully (including the time to remotely fetch + * results, if necessary). + */ var finishTime: Long = 0 + var failed = false + def markGettingResult(time: Long = System.currentTimeMillis) { + gettingResultTime = time + } + def markSuccessful(time: Long = System.currentTimeMillis) { finishTime = time } @@ -43,6 +59,8 @@ class TaskInfo( failed = true } + def gettingResult: Boolean = gettingResultTime != 0 + def finished: Boolean = finishTime != 0 def successful: Boolean = finished && !failed @@ -52,6 +70,8 @@ class TaskInfo( def status: String = { if (running) "RUNNING" + else if (gettingResult) + "GET RESULT" else if (failed) "FAILED" else if (successful) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 6a51efe8d6..10e0478108 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * Each TaskScheduler schedulers task for a single SparkContext. * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler through - * the TaskSchedulerListener interface. + * are failures, and mitigating stragglers. They return events to the DAGScheduler. */ private[spark] trait TaskScheduler { @@ -48,8 +47,8 @@ private[spark] trait TaskScheduler { // Cancel a stage. def cancelTasks(stageId: Int) - // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. - def setListener(listener: TaskSchedulerListener): Unit + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. + def setDAGScheduler(dagScheduler: DAGScheduler): Unit // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala deleted file mode 100644 index 593fa9fb93..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import scala.collection.mutable.Map - -import org.apache.spark.TaskEndReason -import org.apache.spark.executor.TaskMetrics - -/** - * Interface for getting events back from the TaskScheduler. - */ -private[spark] trait TaskSchedulerListener { - // A task has started. - def taskStarted(task: Task[_], taskInfo: TaskInfo) - - // A task has finished or failed. - def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit - - // A node was added to the cluster. - def executorGained(execId: String, host: String): Unit - - // A node was lost from the cluster. - def executorLost(execId: String): Unit - - // The TaskScheduler wants to abort an entire task set. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 7a72ff0474..85033958ef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -79,7 +79,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) private val executorIdToHost = new HashMap[String, String] // Listener object to pass upcalls into - var listener: TaskSchedulerListener = null + var dagScheduler: DAGScheduler = null var backend: SchedulerBackend = null @@ -94,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } def initialize(context: SchedulerBackend) { @@ -297,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Update the DAGScheduler without holding a lock on this, since that can deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { + taskSetManager.handleTaskGettingResult(tid) + } + def handleSuccessfulTask( taskSetManager: ClusterTaskSetManager, tid: Long, @@ -397,9 +401,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) logError("Lost an executor " + executorId + " (already removed): " + reason) } } - // Call listener.executorLost without holding the lock on this to prevent deadlock + // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } } @@ -418,7 +422,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } def executorGained(execId: String, host: String) { - listener.executorGained(execId, host) + dagScheduler.executorGained(execId, host) } def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 7bd3499300..ee47aaffca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -415,11 +415,17 @@ private[spark] class ClusterTaskSetManager( } private def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) + } + + def handleTaskGettingResult(tid: Long) = { + val info = taskInfos(tid) + info.markGettingResult() + sched.dagScheduler.taskGettingResult(tasks(info.index), info) } /** - * Marks the task as successful and notifies the listener that a task has ended. + * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { val info = taskInfos(tid) @@ -429,7 +435,7 @@ private[spark] class ClusterTaskSetManager( if (!successful(index)) { logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.listener.taskEnded( + sched.dagScheduler.taskEnded( tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) // Mark successful and stop if all the tasks have succeeded. @@ -445,7 +451,8 @@ private[spark] class ClusterTaskSetManager( } /** - * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener. + * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the + * DAG Scheduler. */ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { val info = taskInfos(tid) @@ -463,7 +470,7 @@ private[spark] class ClusterTaskSetManager( reason.foreach { case fetchFailed: FetchFailed => logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) successful(index) = true tasksSuccessful += 1 sched.taskSetFinished(this) @@ -472,11 +479,11 @@ private[spark] class ClusterTaskSetManager( case TaskKilled => logWarning("Task %d was killed.".format(tid)) - sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) return case ef: ExceptionFailure => - sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) val key = ef.description val now = clock.getTime() val (printFull, dupCount) = { @@ -504,7 +511,7 @@ private[spark] class ClusterTaskSetManager( case TaskResultLost => logWarning("Lost result for TID %s on host %s".format(tid, info.host)) - sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) case _ => {} } @@ -533,7 +540,7 @@ private[spark] class ClusterTaskSetManager( failed = true causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message) removeAllRunningTasks() sched.taskSetFinished(this) } @@ -606,7 +613,7 @@ private[spark] class ClusterTaskSetManager( addPendingTask(index) // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 12b2fd01c0..53316dae2a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -24,28 +24,28 @@ import org.apache.spark.scheduler.TaskDescription import org.apache.spark.util.{Utils, SerializableBuffer} -private[spark] sealed trait StandaloneClusterMessage extends Serializable +private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable -private[spark] object StandaloneClusterMessages { +private[spark] object CoarseGrainedClusterMessages { // Driver to executors - case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage + case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String) extends StandaloneClusterMessage + case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) - extends StandaloneClusterMessage + extends CoarseGrainedClusterMessage - case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage + case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) - extends StandaloneClusterMessage { + extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends StandaloneClusterMessage + data: SerializableBuffer) extends CoarseGrainedClusterMessage object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ @@ -56,10 +56,14 @@ private[spark] object StandaloneClusterMessages { } // Internal messages in driver - case object ReviveOffers extends StandaloneClusterMessage + case object ReviveOffers extends CoarseGrainedClusterMessage - case object StopDriver extends StandaloneClusterMessage + case object StopDriver extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage + case object StopExecutor extends CoarseGrainedClusterMessage + + case object StopExecutors extends CoarseGrainedClusterMessage + + case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 08ee2182a2..70f3f88401 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -30,16 +30,19 @@ import akka.util.duration._ import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.scheduler.TaskDescription -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.Utils /** - * A standalone scheduler backend, which waits for standalone executors to connect to it through - * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained - * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*). + * A scheduler backend that waits for coarse grained executors to connect to it through Akka. + * This backend holds onto each executor for the duration of the Spark job rather than relinquishing + * executors whenever a task is done and asking the scheduler to launch a new executor for + * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the + * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode + * (spark.deploy.*). */ private[spark] -class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -98,6 +101,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor sender ! true context.stop(self) + case StopExecutors => + logInfo("Asking each executor to shut down") + for (executor <- executorActor.values) { + executor ! StopExecutor + } + sender ! true + case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) sender ! true @@ -162,16 +172,29 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) + Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) } private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + def stopExecutors() { + try { + if (driverActor != null) { + logInfo("Shutting down all executors") + val future = driverActor.ask(StopExecutors)(timeout) + Await.ready(future, timeout) + } + } catch { + case e: Exception => + throw new SparkException("Error asking standalone scheduler to shut down executors", e) + } + } + override def stop() { try { if (driverActor != null) { val future = driverActor.ask(StopDriver)(timeout) - Await.result(future, timeout) + Await.ready(future, timeout) } } catch { case e: Exception => @@ -194,7 +217,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor def removeExecutor(executorId: String, reason: String) { try { val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.result(future, timeout) + Await.ready(future, timeout) } catch { case e: Exception => throw new SparkException("Error notifying standalone scheduler's driver actor", e) @@ -202,6 +225,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } -private[spark] object StandaloneSchedulerBackend { - val ACTOR_NAME = "StandaloneScheduler" +private[spark] object CoarseGrainedSchedulerBackend { + val ACTOR_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala new file mode 100644 index 0000000000..d78bdbaa7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.{Logging, SparkContext} + +private[spark] class SimrSchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext, + driverFilePath: String) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + with Logging { + + val tmpPath = new Path(driverFilePath + "_tmp") + val filePath = new Path(driverFilePath) + + val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt + + override def start() { + super.start() + + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + val conf = new Configuration() + val fs = FileSystem.get(conf) + + logInfo("Writing to HDFS file: " + driverFilePath) + logInfo("Writing Akka address: " + driverUrl) + + // Create temporary file to prevent race condition where executors get empty driverUrl file + val temp = fs.create(tmpPath, true) + temp.writeUTF(driverUrl) + temp.writeInt(maxCores) + temp.close() + + // "Atomic" rename + fs.rename(tmpPath, filePath) + } + + override def stop() { + val conf = new Configuration() + val fs = FileSystem.get(conf) + fs.delete(new Path(driverFilePath), false) + super.stopExecutors() + super.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index cb88159b8d..cefa970bb9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -28,7 +28,7 @@ private[spark] class SparkDeploySchedulerBackend( sc: SparkContext, masters: Array[String], appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with ClientListener with Logging { @@ -44,10 +44,10 @@ private[spark] class SparkDeploySchedulerBackend( // The endpoint for executors to talk to us val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command( - "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) + "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(null) val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, "http://" + sc.ui.appUIAddress) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index feec8ecfe4..2064d97b49 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -24,33 +24,16 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.Utils /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. */ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) extends Logging { - private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt - private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt - private val getTaskResultExecutor = new ThreadPoolExecutor( - MIN_THREADS, - MAX_THREADS, - 0L, - TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable], - new ResultResolverThreadFactory) - - class ResultResolverThreadFactory extends ThreadFactory { - private var counter = 0 - private var PREFIX = "Result resolver thread" - - override def newThread(r: Runnable): Thread = { - val thread = new Thread(r, "%s-%s".format(PREFIX, counter)) - counter += 1 - thread.setDaemon(true) - return thread - } - } + private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt + private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + THREADS, "Result resolver thread") protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { @@ -67,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche case directResult: DirectTaskResult[_] => directResult case IndirectTaskResult(blockId) => logDebug("Fetching indirect task result for TID %s".format(tid)) + scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) if (!serializedTaskResult.isDefined) { /* We won't be able to get the task result if the machine that ran the task failed diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 8f2eef9a53..300fe693f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,13 +30,14 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend} +import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the - * StandaloneBackend mechanism. This class is useful for lower and more predictable latency. + * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable + * latency. * * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to * remove this. @@ -46,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( sc: SparkContext, master: String, appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with MScheduler with Logging { @@ -122,20 +123,20 @@ private[spark] class CoarseMesosSchedulerBackend( val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ACTOR_NAME) val uri = System.getProperty("spark.executor.uri") if (uri == null) { val runScript = new File(sparkHome, "spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( + "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( - "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d" + .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } return command.build() diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index b445260d1b..2699f0b33e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -81,7 +81,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val env = SparkEnv.get val attemptId = new AtomicInteger - var listener: TaskSchedulerListener = null + var dagScheduler: DAGScheduler = null // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. @@ -114,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") } - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } override def submitTasks(taskSet: TaskSet) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala index f72e77d40f..53bf78267e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala @@ -133,7 +133,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) } def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { @@ -148,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } } result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info, + result.metrics) numFinished += 1 decreaseRunningTasks(1) finished(index) = true @@ -165,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas decreaseRunningTasks(1) val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( serializedData, getClass.getClassLoader) - sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) + sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) if (!finished(index)) { copiesRunning(index) -= 1 numFailures(index) += 1 @@ -174,9 +175,9 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas reason.className, reason.description, locs.mkString("\n"))) if (numFailures(index) > MAX_TASK_FAILURES) { val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( - taskSet.id, index, 4, reason.description) + taskSet.id, index, MAX_TASK_FAILURES, reason.description) decreaseRunningTasks(runningTasks) - sched.listener.taskSetFailed(taskSet, errorMessage) + sched.dagScheduler.taskSetFailed(taskSet, errorMessage) // need to delete failed Taskset from schedule queue sched.taskSetFinished(this) } @@ -184,7 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } override def error(message: String) { - sched.listener.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message) sched.taskSetFinished(this) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index c7efc67a4a..7156d855d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId { def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD = isInstanceOf[RDDBlockId] def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] override def toString = name override def hashCode = name.hashCode @@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { def name = "broadcast_" + broadcastId } +private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { + def name = broadcastId.name + "_" + hType +} + private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { def name = "taskresult_" + taskId } @@ -72,6 +76,7 @@ private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)".r + val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -84,6 +89,8 @@ private[spark] object BlockId { ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId) => BroadcastBlockId(broadcastId.toLong) + case BROADCAST_HELPER(broadcastId, hType) => + BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala new file mode 100644 index 0000000000..dbe0bda615 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.concurrent.ConcurrentHashMap + +private[storage] trait BlockInfo { + def level: StorageLevel + def tellMaster: Boolean + // To save space, 'pending' and 'failed' are encoded as special sizes: + @volatile var size: Long = BlockInfo.BLOCK_PENDING + private def pending: Boolean = size == BlockInfo.BLOCK_PENDING + private def failed: Boolean = size == BlockInfo.BLOCK_FAILED + private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this) + + setInitThread() + + private def setInitThread() { + // Set current thread as init thread - waitForReady will not block this thread + // (in case there is non trivial initialization which ends up calling waitForReady as part of + // initialization itself) + BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread()) + } + + /** + * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). + * Return true if the block is available, false otherwise. + */ + def waitForReady(): Boolean = { + if (pending && initThread != Thread.currentThread()) { + synchronized { + while (pending) this.wait() + } + } + !failed + } + + /** Mark this BlockInfo as ready (i.e. block is finished writing) */ + def markReady(sizeInBytes: Long) { + require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes) + assert (pending) + size = sizeInBytes + BlockInfo.blockInfoInitThreads.remove(this) + synchronized { + this.notifyAll() + } + } + + /** Mark this BlockInfo as ready but failed */ + def markFailure() { + assert (pending) + size = BlockInfo.BLOCK_FAILED + BlockInfo.blockInfoInitThreads.remove(this) + synchronized { + this.notifyAll() + } + } +} + +private object BlockInfo { + // initThread is logically a BlockInfo field, but we store it here because + // it's only needed while this block is in the 'pending' state and we want + // to minimize BlockInfo's memory footprint. + private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread] + + private val BLOCK_PENDING: Long = -1L + private val BLOCK_FAILED: Long = -2L +} + +// All shuffle blocks have the same `level` and `tellMaster` properties, +// so we can save space by not storing them in each instance: +private[storage] class ShuffleBlockInfo extends BlockInfo { + // These need to be defined using 'def' instead of 'val' in order for + // the compiler to eliminate the fields: + def level: StorageLevel = StorageLevel.DISK_ONLY + def tellMaster: Boolean = false +} + +private[storage] class BlockInfoImpl(val level: StorageLevel, val tellMaster: Boolean) + extends BlockInfo { + // Intentionally left blank +}
\ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 801f88a3db..76d537f8e8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,14 +20,15 @@ package org.apache.spark.storage import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} +import scala.collection.mutable.{HashMap, ArrayBuffer} +import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} import akka.dispatch.{Await, Future} import akka.util.Duration import akka.util.duration._ -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec @@ -45,74 +46,20 @@ private[spark] class BlockManager( maxMemory: Long) extends Logging { - private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - @volatile var pending: Boolean = true - @volatile var size: Long = -1L - @volatile var initThread: Thread = null - @volatile var failed = false - - setInitThread() - - private def setInitThread() { - // Set current thread as init thread - waitForReady will not block this thread - // (in case there is non trivial initialization which ends up calling waitForReady as part of - // initialization itself) - this.initThread = Thread.currentThread() - } - - /** - * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). - * Return true if the block is available, false otherwise. - */ - def waitForReady(): Boolean = { - if (initThread != Thread.currentThread() && pending) { - synchronized { - while (pending) this.wait() - } - } - !failed - } - - /** Mark this BlockInfo as ready (i.e. block is finished writing) */ - def markReady(sizeInBytes: Long) { - assert (pending) - size = sizeInBytes - initThread = null - failed = false - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - - /** Mark this BlockInfo as ready but failed */ - def markFailure() { - assert (pending) - size = 0 - initThread = null - failed = true - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - } - val shuffleBlockManager = new ShuffleBlockManager(this) + val diskBlockManager = new DiskBlockManager( + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: DiskStore = - new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. private val nettyPort: Int = { val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt - if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } val connectionManager = new ConnectionManager(0) @@ -269,7 +216,7 @@ private[spark] class BlockManager( } /** - * Actually send a UpdateBlockInfo message. Returns the mater's response, + * Actually send a UpdateBlockInfo message. Returns the master's response, * which will be true if the block was successfully recorded and false if * the slave needs to re-register. */ @@ -320,89 +267,14 @@ private[spark] class BlockManager( */ def getLocal(blockId: BlockId): Option[Iterator[Any]] = { logDebug("Getting local block " + blockId) - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - - // In the another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure.") - return None - } - - val level = info.level - logDebug("Level for block " + blockId + " is " + level) - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - logDebug("Block " + blockId + " not found in memory") - } - } - - // Look for block on disk, potentially loading it back into memory if required - if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") - if (level.useMemory && level.deserialized) { - diskStore.getValues(blockId) match { - case Some(iterator) => - // Put the block back in memory before returning it - // TODO: Consider creating a putValues that also takes in a iterator ? - val elements = new ArrayBuffer[Any] - elements ++= iterator - memoryStore.putValues(blockId, elements, level, true).data match { - case Left(iterator2) => - return Some(iterator2) - case _ => - throw new Exception("Memory store did not return back an iterator") - } - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else if (level.useMemory && !level.deserialized) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - // Put a copy of the block back in memory before returning it. Note that we can't - // put the ByteBuffer returned by the disk store as that's a memory-mapped file. - // The use of rewind assumes this. - assert (0 == bytes.position()) - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - bytes.rewind() - return Some(dataDeserialize(blockId, bytes)) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else { - diskStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - } - } - } else { - logDebug("Block " + blockId + " not registered locally") - } - return None + doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } /** * Get block from the local block manager as serialized bytes. */ def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { - // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow logDebug("Getting local block " + blockId + " as bytes") - // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work if (blockId.isShuffle) { @@ -413,12 +285,15 @@ private[spark] class BlockManager( throw new Exception("Block " + blockId + " not found on disk, though it should be") } } + doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] + } + private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = { val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { - // In the another thread is writing the block, wait for it to become ready. + // If another thread is writing the block, wait for it to become ready. if (!info.waitForReady()) { // If we get here, the block write failed. logWarning("Block " + blockId + " was marked as failure.") @@ -431,62 +306,104 @@ private[spark] class BlockManager( // Look for the block in memory if (level.useMemory) { logDebug("Getting block " + blockId + " from memory") - memoryStore.getBytes(blockId) match { - case Some(bytes) => - return Some(bytes) + val result = if (asValues) { + memoryStore.getValues(blockId) + } else { + memoryStore.getBytes(blockId) + } + result match { + case Some(values) => + return Some(values) case None => logDebug("Block " + blockId + " not found in memory") } } - // Look for block on disk + // Look for block on disk, potentially storing it back into memory if required: if (level.useDisk) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - assert (0 == bytes.position()) - if (level.useMemory) { - if (level.deserialized) { - memoryStore.putBytes(blockId, bytes, level) - } else { - // The memory store will hang onto the ByteBuffer, so give it a copy instead of - // the memory-mapped file buffer we got from the disk store - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - } - } - bytes.rewind() - return Some(bytes) + logDebug("Getting block " + blockId + " from disk") + val bytes: ByteBuffer = diskStore.getBytes(blockId) match { + case Some(bytes) => bytes case None => throw new Exception("Block " + blockId + " not found on disk, though it should be") } + assert (0 == bytes.position()) + + if (!level.useMemory) { + // If the block shouldn't be stored in memory, we can just return it: + if (asValues) { + return Some(dataDeserialize(blockId, bytes)) + } else { + return Some(bytes) + } + } else { + // Otherwise, we also have to store something in the memory store: + if (!level.deserialized || !asValues) { + // We'll store the bytes in memory if the block's storage level includes + // "memory serialized", or if it should be cached as objects in memory + // but we only requested its serialized bytes: + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + memoryStore.putBytes(blockId, copyForMemory, level) + bytes.rewind() + } + if (!asValues) { + return Some(bytes) + } else { + val values = dataDeserialize(blockId, bytes) + if (level.deserialized) { + // Cache the values before returning them: + // TODO: Consider creating a putValues that also takes in a iterator? + val valuesBuffer = new ArrayBuffer[Any] + valuesBuffer ++= values + memoryStore.putValues(blockId, valuesBuffer, level, true).data match { + case Left(values2) => + return Some(values2) + case _ => + throw new Exception("Memory store did not return back an iterator") + } + } else { + return Some(values) + } + } + } } } } else { logDebug("Block " + blockId + " not registered locally") } - return None + None } /** * Get block from remote block managers. */ def getRemote(blockId: BlockId): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } logDebug("Getting remote block " + blockId) - // Get locations of block - val locations = master.getLocations(blockId) + doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] + } + + /** + * Get block from remote block managers as serialized bytes. + */ + def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { + logDebug("Getting remote block " + blockId + " as bytes") + doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] + } - // Get block from remote locations + private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = { + require(blockId != null, "BlockId is null") + val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug("Getting remote block " + blockId + " from " + loc) val data = BlockManagerWorker.syncGetBlock( GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { - return Some(dataDeserialize(blockId, data)) + if (asValues) { + return Some(dataDeserialize(blockId, data)) + } else { + return Some(data) + } } logDebug("The value of block " + blockId + " is null") } @@ -495,31 +412,6 @@ private[spark] class BlockManager( } /** - * Get block from remote block managers as serialized bytes. - */ - def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { - // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be - // refactored. - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - logDebug("Getting remote block " + blockId + " as bytes") - - val locations = master.getLocations(blockId) - for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) - if (data != null) { - return Some(data) - } - logDebug("The value of block " + blockId + " is null") - } - logDebug("Block " + blockId + " not found") - return None - } - - /** * Get a block from the block manager (either local or remote). */ def get(blockId: BlockId): Option[Iterator[Any]] = { @@ -566,16 +458,22 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. + * The Block will be appended to the File specified by filename. * This is currently used for writing shuffle files out. Callers should handle error * cases. */ - def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int) + def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) + val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) + val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true) + val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream) writer.registerCloseEventHandler(() => { - val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) + if (shuffleBlockManager.consolidateShuffleFiles) { + diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment()) + } + val myInfo = new ShuffleBlockInfo() blockInfo.put(blockId, myInfo) - myInfo.markReady(writer.size()) + myInfo.markReady(writer.fileSegment().length) }) writer } @@ -584,23 +482,30 @@ private[spark] class BlockManager( * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, - tellMaster: Boolean = true) : Long = { + tellMaster: Boolean = true) : Long = { + require(values != null, "Values is null") + doPut(blockId, Left(values), level, tellMaster) + } - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (values == null) { - throw new IllegalArgumentException("Values is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } + /** + * Put a new block of serialized bytes to the block manager. + */ + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, + tellMaster: Boolean = true) { + require(bytes != null, "Bytes is null") + doPut(blockId, Right(bytes), level, tellMaster) + } + + private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer], + level: StorageLevel, tellMaster: Boolean = true): Long = { + require(blockId != null, "BlockId is null") + require(level != null && level.isValid, "StorageLevel is null or invalid") // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will // not be able to get() this block until we call markReady on its BlockInfo. val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) + val tinfo = new BlockInfoImpl(level, tellMaster) // Do atomically ! val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) @@ -610,7 +515,8 @@ private[spark] class BlockManager( return oldBlockOpt.get.size } - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + // TODO: So the block info exists - but previous attempt to load it (?) failed. + // What do we do now ? Retry on it ? oldBlockOpt.get } else { tinfo @@ -619,10 +525,10 @@ private[spark] class BlockManager( val startTimeMs = System.currentTimeMillis - // If we need to replicate the data, we'll want access to the values, but because our - // put will read the whole iterator, there will be no values left. For the case where - // the put serializes data, we'll remember the bytes, above; but for the case where it - // doesn't, such as deserialized storage, let's rely on the put returning an Iterator. + // If we're storing values and we need to replicate the data, we'll want access to the values, + // but because our put will read the whole iterator, there will be no values left. For the + // case where the put serializes data, we'll remember the bytes, above; but for the case where + // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. var valuesAfterPut: Iterator[Any] = null // Ditto for the bytes after the put @@ -631,30 +537,51 @@ private[spark] class BlockManager( // Size of the block in bytes (to return to caller) var size = 0L + // If we're storing bytes, then initiate the replication before storing them locally. + // This is faster as data is already serialized and ready to send. + val replicationFuture = if (data.isRight && level.replication > 1) { + val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper + Future { + replicate(blockId, bufferView, level) + } + } else { + null + } + myInfo.synchronized { logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") var marked = false try { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will later - // drop it to disk if the memory store can't hold it. - val res = memoryStore.putValues(blockId, values, level, true) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator + data match { + case Left(values) => { + if (level.useMemory) { + // Save it just to memory first, even if it also has useDisk set to true; we will + // drop it to disk later if the memory store can't hold it. + val res = memoryStore.putValues(blockId, values, level, true) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case Left(newIterator) => valuesAfterPut = newIterator + } + } else { + // Save directly to disk. + // Don't get back the bytes unless we replicate them. + val askForBytes = level.replication > 1 + val res = diskStore.putValues(blockId, values, level, askForBytes) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => + } + } } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - val res = diskStore.putValues(blockId, values, level, askForBytes) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => + case Right(bytes) => { + bytes.rewind() + // Store it only in memory at first, even if useDisk is also set to true + (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level) + size = bytes.limit } } @@ -679,125 +606,39 @@ private[spark] class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - // Replicate block if required + // Either we're storing bytes and we asynchronously started replication, or we're storing + // values and need to serialize and replicate them now: if (level.replication > 1) { - val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") - } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) - } - replicate(blockId, bytesAfterPut, level) - logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) - } - BlockManager.dispose(bytesAfterPut) - - return size - } - - - /** - * Put a new block of serialized bytes to the block manager. - */ - def putBytes( - blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { - - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (bytes == null) { - throw new IllegalArgumentException("Bytes is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? - oldBlockOpt.get - } else { - tinfo - } - } - - val startTimeMs = System.currentTimeMillis - - // Initiate the replication before storing it locally. This is faster as - // data is already serialized and ready for sending - val replicationFuture = if (level.replication > 1) { - val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper - Future { - replicate(blockId, bufferView, level) - } - } else { - null - } - - myInfo.synchronized { - logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - var marked = false - try { - if (level.useMemory) { - // Store it only in memory at first, even if useDisk is also set to true - bytes.rewind() - memoryStore.putBytes(blockId, bytes, level) - } else { - bytes.rewind() - diskStore.putBytes(blockId, bytes, level) - } - - // assert (0 == bytes.position(), "" + bytes) - - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. - marked = true - myInfo.markReady(bytes.limit) - if (tellMaster) { - reportBlockStatus(blockId, myInfo) - } - } finally { - // If we failed at putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (! marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - myInfo.markFailure() - logWarning("Putting block " + blockId + " failed") + data match { + case Right(bytes) => Await.ready(replicationFuture, Duration.Inf) + case Left(values) => { + val remoteStartTime = System.currentTimeMillis + // Serialize the block if not already done + if (bytesAfterPut == null) { + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytesAfterPut = dataSerialize(blockId, valuesAfterPut) + } + replicate(blockId, bytesAfterPut, level) + logDebug("Put block " + blockId + " remotely took " + + Utils.getUsedTimeMs(remoteStartTime)) } } } - // If replication had started, then wait for it to finish - if (level.replication > 1) { - Await.ready(replicationFuture, Duration.Inf) - } + BlockManager.dispose(bytesAfterPut) if (level.replication > 1) { - logDebug("PutBytes for block " + blockId + " with replication took " + + logDebug("Put for block " + blockId + " with replication took " + Utils.getUsedTimeMs(startTimeMs)) } else { - logDebug("PutBytes for block " + blockId + " without replication took " + + logDebug("Put for block " + blockId + " without replication took " + Utils.getUsedTimeMs(startTimeMs)) } + + size } /** @@ -922,34 +763,20 @@ private[spark] class BlockManager( private def dropOldNonBroadcastBlocks(cleanupTime: Long) { logInfo("Dropping non broadcast blocks older than " + cleanupTime) - val iterator = blockInfo.internalMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime && !id.isBroadcast) { - info.synchronized { - val level = info.level - if (level.useMemory) { - memoryStore.remove(id) - } - if (level.useDisk) { - diskStore.remove(id) - } - iterator.remove() - logInfo("Dropped block " + id) - } - reportBlockStatus(id, info) - } - } + dropOldBlocks(cleanupTime, !_.isBroadcast) } private def dropOldBroadcastBlocks(cleanupTime: Long) { logInfo("Dropping broadcast blocks older than " + cleanupTime) + dropOldBlocks(cleanupTime, _.isBroadcast) + } + + private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { val iterator = blockInfo.internalMap.entrySet().iterator() while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime && id.isBroadcast) { + if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level if (level.useMemory) { @@ -987,13 +814,24 @@ private[spark] class BlockManager( if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } + /** Serializes into a stream. */ + def dataSerializeStream( + blockId: BlockId, + outputStream: OutputStream, + values: Iterator[Any], + serializer: Serializer = defaultSerializer) { + val byteStream = new FastBufferedOutputStream(outputStream) + val ser = serializer.newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + } + + /** Serializes into a byte buffer. */ def dataSerialize( blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) - val ser = serializer.newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + dataSerializeStream(blockId, byteStream, values, serializer) byteStream.trim() ByteBuffer.wrap(byteStream.array) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 633230c0a8..f8cf14b503 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -227,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - if (id.executorId == "<driver>" && !isLocal) { - // Got a register message from the master node; don't register it - } else if (!blockManagerInfo.contains(id)) { + if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => // A block manager of the same executor already exists. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 951503019f..3a65e55733 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._ * An actor to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. */ +private[storage] class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { override def receive = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 2a67800c45..32d2dd0694 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -17,6 +17,13 @@ package org.apache.spark.storage +import java.io.{FileOutputStream, File, OutputStream} +import java.nio.channels.FileChannel + +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +import org.apache.spark.Logging +import org.apache.spark.serializer.{SerializationStream, Serializer} /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -59,7 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) { def write(value: Any) /** - * Size of the valid writes, in bytes. + * Returns the file segment of committed data that this Writer has written. + */ + def fileSegment(): FileSegment + + /** + * Cumulative time spent performing blocking writes, in ns. */ - def size(): Long + def timeWriting(): Long +} + +/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */ +class DiskBlockObjectWriter( + blockId: BlockId, + file: File, + serializer: Serializer, + bufferSize: Int, + compressStream: OutputStream => OutputStream) + extends BlockObjectWriter(blockId) + with Logging +{ + + /** Intercepts write calls and tracks total time spent writing. Not thread safe. */ + private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { + def timeWriting = _timeWriting + private var _timeWriting = 0L + + private def callWithTiming(f: => Unit) = { + val start = System.nanoTime() + f + _timeWriting += (System.nanoTime() - start) + } + + def write(i: Int): Unit = callWithTiming(out.write(i)) + override def write(b: Array[Byte]) = callWithTiming(out.write(b)) + override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) + } + + private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean + + /** The file channel, used for repositioning / truncating the file. */ + private var channel: FileChannel = null + private var bs: OutputStream = null + private var fos: FileOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var objOut: SerializationStream = null + private var initialPosition = 0L + private var lastValidPosition = 0L + private var initialized = false + private var _timeWriting = 0L + + override def open(): BlockObjectWriter = { + fos = new FileOutputStream(file, true) + ts = new TimeTrackingOutputStream(fos) + channel = fos.getChannel() + initialPosition = channel.position + lastValidPosition = initialPosition + bs = compressStream(new FastBufferedOutputStream(ts, bufferSize)) + objOut = serializer.newInstance().serializeStream(bs) + initialized = true + this + } + + override def close() { + if (initialized) { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + val start = System.nanoTime() + fos.getFD.sync() + _timeWriting += System.nanoTime() - start + } + objOut.close() + + _timeWriting += ts.timeWriting + + channel = null + bs = null + fos = null + ts = null + objOut = null + } + // Invoke the close callback handler. + super.close() + } + + override def isOpen: Boolean = objOut != null + + override def commit(): Long = { + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } + } + + override def revertPartialWrites() { + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } + } + + override def write(value: Any) { + if (!initialized) { + open() + } + objOut.writeObject(value) + } + + override def fileSegment(): FileSegment = { + val bytesWritten = lastValidPosition - initialPosition + new FileSegment(file, initialPosition, bytesWritten) + } + + // Only valid if called after close() + override def timeWriting() = _timeWriting } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala new file mode 100644 index 0000000000..bcb58ad946 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.File +import java.text.SimpleDateFormat +import java.util.{Date, Random} +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.Logging +import org.apache.spark.executor.ExecutorExitCode +import org.apache.spark.network.netty.{PathResolver, ShuffleSender} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} + +/** + * Creates and maintains the logical mapping between logical blocks and physical on-disk + * locations. By default, one block is mapped to one file with a name given by its BlockId. + * However, it is also possible to have a block map to only a segment of a file, by calling + * mapBlockToFileSegment(). + * + * @param rootDirs The directories to use for storing block files. Data will be hashed among these. + */ +private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging { + + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + + // Create one local directory for each path mentioned in spark.local.dir; then, inside this + // directory, create multiple subdirectories that we will hash files into, in order to avoid + // having really large inodes at the top level. + private val localDirs: Array[File] = createLocalDirs() + private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + private var shuffleSender : ShuffleSender = null + + // Stores only Blocks which have been specifically mapped to segments of files + // (rather than the default, which maps a Block to a whole file). + // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks. + private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment] + + val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup) + + addShutdownHook() + + /** + * Creates a logical mapping from the given BlockId to a segment of a file. + * This will cause any accesses of the logical BlockId to be directed to the specified + * physical location. + */ + def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) { + blockToFileSegmentMap.put(blockId, fileSegment) + } + + /** + * Returns the phyiscal file segment in which the given BlockId is located. + * If the BlockId has been mapped to a specific FileSegment, that will be returned. + * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly. + */ + def getBlockLocation(blockId: BlockId): FileSegment = { + if (blockToFileSegmentMap.internalMap.containsKey(blockId)) { + blockToFileSegmentMap.get(blockId).get + } else { + val file = getFile(blockId.name) + new FileSegment(file, 0, file.length()) + } + } + + /** + * Simply returns a File to place the given Block into. This does not physically create the file. + * If filename is given, that file will be used. Otherwise, we will use the BlockId to get + * a unique filename. + */ + def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = { + val actualFilename = if (filename == "") blockId.name else filename + val file = getFile(actualFilename) + if (!allowAppending && file.exists()) { + throw new IllegalStateException( + "Attempted to create file that already exists: " + actualFilename) + } + file + } + + private def getFile(filename: String): File = { + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = Utils.nonNegativeHash(filename) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + + // Create the subdirectory if it doesn't already exist + var subDir = subDirs(dirId)(subDirId) + if (subDir == null) { + subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + newDir.mkdir() + subDirs(dirId)(subDirId) = newDir + newDir + } + } + } + + new File(subDir, filename) + } + + private def createLocalDirs(): Array[File] = { + logDebug("Creating local directories at root dirs '" + rootDirs + "'") + val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + rootDirs.split(",").map { rootDir => + var foundLocalDir = false + var localDir: File = null + var localDirId: String = null + var tries = 0 + val rand = new Random() + while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { + tries += 1 + try { + localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) + localDir = new File(rootDir, "spark-local-" + localDirId) + if (!localDir.exists) { + foundLocalDir = localDir.mkdirs() + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) + } + } + if (!foundLocalDir) { + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + + " attempts to create local dir in " + rootDir) + System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) + } + logInfo("Created local directory at " + localDir) + localDir + } + } + + private def cleanup(cleanupTime: Long) { + blockToFileSegmentMap.clearOldValues(cleanupTime) + } + + private def addShutdownHook() { + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { + override def run() { + logDebug("Shutdown hook called") + localDirs.foreach { localDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case t: Throwable => + logError("Exception while deleting local spark dir: " + localDir, t) + } + } + + if (shuffleSender != null) { + shuffleSender.stop() + } + } + }) + } + + private[storage] def startShuffleBlockSender(port: Int): Int = { + shuffleSender = new ShuffleSender(port, this) + logInfo("Created ShuffleSender binding to port : " + shuffleSender.port) + shuffleSender.port + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index b7ca61e938..a3c496f9e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,120 +17,25 @@ package org.apache.spark.storage -import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} +import java.io.{FileOutputStream, RandomAccessFile} import java.nio.ByteBuffer -import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode -import java.util.{Random, Date} -import java.text.SimpleDateFormat import scala.collection.mutable.ArrayBuffer -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.serializer.{Serializer, SerializationStream} import org.apache.spark.Logging -import org.apache.spark.network.netty.ShuffleSender -import org.apache.spark.network.netty.PathResolver +import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils /** * Stores BlockManager blocks on disk. */ -private class DiskStore(blockManager: BlockManager, rootDirs: String) +private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int) - extends BlockObjectWriter(blockId) { - - private val f: File = createFile(blockId /*, allowAppendExisting */) - - // The file channel, used for repositioning / truncating the file. - private var channel: FileChannel = null - private var bs: OutputStream = null - private var objOut: SerializationStream = null - private var lastValidPosition = 0L - private var initialized = false - - override def open(): DiskBlockObjectWriter = { - val fos = new FileOutputStream(f, true) - channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true - this - } - - override def close() { - if (initialized) { - objOut.close() - channel = null - bs = null - objOut = null - } - // Invoke the close callback handler. - super.close() - } - - override def isOpen: Boolean = objOut != null - - // Flush the partial writes, and set valid length to be the length of the entire file. - // Return the number of bytes written for this commit. - override def commit(): Long = { - if (initialized) { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition - } - } - - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) - } - } - - override def write(value: Any) { - if (!initialized) { - open() - } - objOut.writeObject(value) - } - - override def size(): Long = lastValidPosition - } - - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - - private var shuffleSender : ShuffleSender = null - // Create one local directory for each path mentioned in spark.local.dir; then, inside this - // directory, create multiple subdirectories that we will hash files into, in order to avoid - // having really large inodes at the top level. - private val localDirs: Array[File] = createLocalDirs() - private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - - addShutdownHook() - - def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int) - : BlockObjectWriter = { - new DiskBlockObjectWriter(blockId, serializer, bufferSize) - } - override def getSize(blockId: BlockId): Long = { - getFile(blockId).length() + diskManager.getBlockLocation(blockId).length } override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { @@ -139,27 +44,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis - val file = createFile(blockId) - val channel = new RandomAccessFile(file, "rw").getChannel() + val file = diskManager.createBlockFile(blockId, allowAppending = false) + val channel = new FileOutputStream(file).getChannel() while (bytes.remaining > 0) { channel.write(bytes) } channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime))) - } - - private def getFileBytes(file: File): ByteBuffer = { - val length = file.length() - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = try { - channel.map(MapMode.READ_ONLY, 0, length) - } finally { - channel.close() - } - - buffer + file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime))) } override def putValues( @@ -171,21 +64,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) logDebug("Attempting to write values for block " + blockId) val startTime = System.currentTimeMillis - val file = createFile(blockId) - val fileOut = blockManager.wrapForCompression(blockId, - new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) - objOut.writeAll(values.iterator) - objOut.close() - val length = file.length() + val file = diskManager.createBlockFile(blockId, allowAppending = false) + val outputStream = new FileOutputStream(file) + blockManager.dataSerializeStream(blockId, outputStream, values.iterator) + val length = file.length val timeTaken = System.currentTimeMillis - startTime logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(length), timeTaken)) + file.getName, Utils.bytesToString(length), timeTaken)) if (returnValues) { // Return a byte buffer for the contents of the file - val buffer = getFileBytes(file) + val buffer = getBytes(blockId).get PutResult(length, Right(buffer)) } else { PutResult(length, null) @@ -193,13 +83,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } override def getBytes(blockId: BlockId): Option[ByteBuffer] = { - val file = getFile(blockId) - val bytes = getFileBytes(file) - Some(bytes) + val segment = diskManager.getBlockLocation(blockId) + val channel = new RandomAccessFile(segment.file, "r").getChannel() + val buffer = try { + channel.map(MapMode.READ_ONLY, segment.offset, segment.length) + } finally { + channel.close() + } + Some(buffer) } override def getValues(blockId: BlockId): Option[Iterator[Any]] = { - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } /** @@ -211,118 +106,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) } override def remove(blockId: BlockId): Boolean = { - val file = getFile(blockId) - if (file.exists()) { + val fileSegment = diskManager.getBlockLocation(blockId) + val file = fileSegment.file + if (file.exists() && file.length() == fileSegment.length) { file.delete() } else { + if (fileSegment.length < file.length()) { + logWarning("Could not delete block associated with only a part of a file: " + blockId) + } false } } override def contains(blockId: BlockId): Boolean = { - getFile(blockId).exists() - } - - private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) - if (!allowAppendExisting && file.exists()) { - // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task. - logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") - file.delete() - } - file - } - - private def getFile(blockId: BlockId): File = { - logDebug("Getting file for block " + blockId) - - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(blockId) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - - // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - newDir.mkdir() - subDirs(dirId)(subDirId) = newDir - newDir - } - } - } - - new File(subDir, blockId.name) - } - - private def createLocalDirs(): Array[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map { rootDir => - var foundLocalDir = false - var localDir: File = null - var localDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, "spark-local-" + localDirId) - if (!localDir.exists) { - foundLocalDir = localDir.mkdirs() - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) - } - } - if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) - System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) - } - logInfo("Created local directory at " + localDir) - localDir - } - } - - private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run() { - logDebug("Shutdown hook called") - localDirs.foreach { localDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case t: Throwable => - logError("Exception while deleting local spark dir: " + localDir, t) - } - } - if (shuffleSender != null) { - shuffleSender.stop() - } - } - }) - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - val pResolver = new PathResolver { - override def getAbsolutePath(blockIdString: String): String = { - val blockId = BlockId(blockIdString) - if (!blockId.isShuffle) null - else DiskStore.this.getFile(blockId).getAbsolutePath - } - } - shuffleSender = new ShuffleSender(port, pResolver) - logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) - shuffleSender.port + val file = diskManager.getBlockLocation(blockId).file + file.exists() } } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala new file mode 100644 index 0000000000..555486830a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.File + +/** + * References a particular segment of a file (potentially the entire file), + * based off an offset and a length. + */ +private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) { + override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index f39fcd87fb..066e45a12b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -17,12 +17,13 @@ package org.apache.spark.storage -import org.apache.spark.serializer.Serializer +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger +import org.apache.spark.serializer.Serializer private[spark] -class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) - +class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter]) private[spark] trait ShuffleBlocks { @@ -30,24 +31,66 @@ trait ShuffleBlocks { def releaseWriters(group: ShuffleWriterGroup) } +/** + * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer + * per reducer. + * + * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle + * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer + * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files, + * it releases them for another task. + * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: + * - shuffleId: The unique id given to the entire shuffle stage. + * - bucketId: The id of the output partition (i.e., reducer id) + * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a + * time owns a particular fileId, and this id is returned to a pool when the task finishes. + */ private[spark] class ShuffleBlockManager(blockManager: BlockManager) { + // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. + // TODO: Remove this once the shuffle file consolidation feature is stable. + val consolidateShuffleFiles = + System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean + + var nextFileId = new AtomicInteger(0) + val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]() - def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { + def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleBlocks { // Get a group of writers for a map task. override def acquireWriters(mapId: Int): ShuffleWriterGroup = { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + val fileId = getUnusedFileId() val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) + if (consolidateShuffleFiles) { + val filename = physicalFileName(shuffleId, bucketId, fileId) + blockManager.getDiskWriter(blockId, filename, serializer, bufferSize) + } else { + blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize) + } } - new ShuffleWriterGroup(mapId, writers) + new ShuffleWriterGroup(mapId, fileId, writers) } - override def releaseWriters(group: ShuffleWriterGroup) = { - // Nothing really to release here. + override def releaseWriters(group: ShuffleWriterGroup) { + recycleFileId(group.fileId) } } } + + private def getUnusedFileId(): Int = { + val fileId = unusedFileIds.poll() + if (fileId == null) nextFileId.getAndIncrement() else fileId + } + + private def recycleFileId(fileId: Int) { + if (consolidateShuffleFiles) { + unusedFileIds.add(fileId) + } + } + + private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { + "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala new file mode 100644 index 0000000000..7dcadc3805 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -0,0 +1,86 @@ +package org.apache.spark.storage + +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{CountDownLatch, Executors} + +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.SparkContext +import org.apache.spark.util.Utils + +/** + * Utility for micro-benchmarking shuffle write performance. + * + * Writes simulated shuffle output from several threads and records the observed throughput. + */ +object StoragePerfTester { + def main(args: Array[String]) = { + /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ + val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) + + /** Number of map tasks. All tasks execute concurrently. */ + val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8) + + /** Number of reduce splits for each map task. */ + val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500) + + val recordLength = 1000 // ~1KB records + val totalRecords = dataSizeMb * 1000 + val recordsPerMap = totalRecords / numMaps + + val writeData = "1" * recordLength + val executor = Executors.newFixedThreadPool(numMaps) + + System.setProperty("spark.shuffle.compress", "false") + System.setProperty("spark.shuffle.sync", "true") + + // This is only used to instantiate a BlockManager. All thread scheduling is done manually. + val sc = new SparkContext("local[4]", "Write Tester") + val blockManager = sc.env.blockManager + + def writeOutputBytes(mapId: Int, total: AtomicLong) = { + val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits, + new KryoSerializer()) + val buckets = shuffle.acquireWriters(mapId) + for (i <- 1 to recordsPerMap) { + buckets.writers(i % numOutputSplits).write(writeData) + } + buckets.writers.map {w => + w.commit() + total.addAndGet(w.fileSegment().length) + w.close() + } + + shuffle.releaseWriters(buckets) + } + + val start = System.currentTimeMillis() + val latch = new CountDownLatch(numMaps) + val totalBytes = new AtomicLong() + for (task <- 1 to numMaps) { + executor.submit(new Runnable() { + override def run() = { + try { + writeOutputBytes(task, totalBytes) + latch.countDown() + } catch { + case e: Exception => + println("Exception in child thread: " + e + " " + e.getMessage) + System.exit(1) + } + } + }) + } + latch.await() + val end = System.currentTimeMillis() + val time = (end - start) / 1000.0 + val bytesPerSecond = totalBytes.get() / time + val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + + System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) + System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) + System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + + executor.shutdown() + sc.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala index b39c0e9769..ca5a28625b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) { val now = System.currentTimeMillis() var activeTime = 0L - for (tasks <- listener.stageToTasksActive.values; t <- tasks) { + for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) { activeTime += t.timeRunning(now) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index eb3b4e8522..6b854740d6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt val DEFAULT_POOL_NAME = "default" - val stageToPool = new HashMap[Stage, String]() - val stageToDescription = new HashMap[Stage, String]() - val poolToActiveStages = new HashMap[String, HashSet[Stage]]() + val stageIdToPool = new HashMap[Int, String]() + val stageIdToDescription = new HashMap[Int, String]() + val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]() - val activeStages = HashSet[Stage]() - val completedStages = ListBuffer[Stage]() - val failedStages = ListBuffer[Stage]() + val activeStages = HashSet[StageInfo]() + val completedStages = ListBuffer[StageInfo]() + val failedStages = ListBuffer[StageInfo]() // Total metrics reflect metrics only for completed tasks var totalTime = 0L var totalShuffleRead = 0L var totalShuffleWrite = 0L - val stageToTime = HashMap[Int, Long]() - val stageToShuffleRead = HashMap[Int, Long]() - val stageToShuffleWrite = HashMap[Int, Long]() - val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]() - val stageToTasksComplete = HashMap[Int, Int]() - val stageToTasksFailed = HashMap[Int, Int]() - val stageToTaskInfos = + val stageIdToTime = HashMap[Int, Long]() + val stageIdToShuffleRead = HashMap[Int, Long]() + val stageIdToShuffleWrite = HashMap[Int, Long]() + val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]() + val stageIdToTasksComplete = HashMap[Int, Int]() + val stageIdToTasksFailed = HashMap[Int, Int]() + val stageIdToTaskInfos = HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() override def onJobStart(jobStart: SparkListenerJobStart) {} override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { - val stage = stageCompleted.stageInfo.stage - poolToActiveStages(stageToPool(stage)) -= stage + val stage = stageCompleted.stage + poolToActiveStages(stageIdToPool(stage.stageId)) -= stage activeStages -= stage completedStages += stage trimIfNecessary(completedStages) } /** If stages is too large, remove and garbage collect old stages */ - def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized { + def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > RETAINED_STAGES) { val toRemove = RETAINED_STAGES / 10 stages.takeRight(toRemove).foreach( s => { - stageToTaskInfos.remove(s.id) - stageToTime.remove(s.id) - stageToShuffleRead.remove(s.id) - stageToShuffleWrite.remove(s.id) - stageToTasksActive.remove(s.id) - stageToTasksComplete.remove(s.id) - stageToTasksFailed.remove(s.id) - stageToPool.remove(s) - if (stageToDescription.contains(s)) {stageToDescription.remove(s)} + stageIdToTaskInfos.remove(s.stageId) + stageIdToTime.remove(s.stageId) + stageIdToShuffleRead.remove(s.stageId) + stageIdToShuffleWrite.remove(s.stageId) + stageIdToTasksActive.remove(s.stageId) + stageIdToTasksComplete.remove(s.stageId) + stageIdToTasksFailed.remove(s.stageId) + stageIdToPool.remove(s.stageId) + if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)} }) stages.trimEnd(toRemove) } @@ -95,63 +95,69 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val poolName = Option(stageSubmitted.properties).map { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) - stageToPool(stage) = poolName + stageIdToPool(stage.stageId) = poolName val description = Option(stageSubmitted.properties).flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } - description.map(d => stageToDescription(stage) = d) + description.map(d => stageIdToDescription(stage.stageId) = d) - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]()) + val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]()) stages += stage } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { val sid = taskStart.task.stageId - val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) tasksActive += taskStart.taskInfo - val taskList = stageToTaskInfos.getOrElse( + val taskList = stageIdToTaskInfos.getOrElse( sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) taskList += ((taskStart.taskInfo, None, None)) - stageToTaskInfos(sid) = taskList + stageIdToTaskInfos(sid) = taskList } - + + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) + = synchronized { + // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in + // stageToTaskInfos already has the updated status. + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val sid = taskEnd.task.stageId - val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) + val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) tasksActive -= taskEnd.taskInfo val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) = taskEnd.reason match { case e: ExceptionFailure => - stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1 + stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1 (Some(e), e.metrics) case _ => - stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1 + stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1 (None, Option(taskEnd.taskMetrics)) } - stageToTime.getOrElseUpdate(sid, 0L) + stageIdToTime.getOrElseUpdate(sid, 0L) val time = metrics.map(m => m.executorRunTime).getOrElse(0) - stageToTime(sid) += time + stageIdToTime(sid) += time totalTime += time - stageToShuffleRead.getOrElseUpdate(sid, 0L) + stageIdToShuffleRead.getOrElseUpdate(sid, 0L) val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s => s.remoteBytesRead).getOrElse(0L) - stageToShuffleRead(sid) += shuffleRead + stageIdToShuffleRead(sid) += shuffleRead totalShuffleRead += shuffleRead - stageToShuffleWrite.getOrElseUpdate(sid, 0L) + stageIdToShuffleWrite.getOrElseUpdate(sid, 0L) val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s => s.shuffleBytesWritten).getOrElse(0L) - stageToShuffleWrite(sid) += shuffleWrite + stageIdToShuffleWrite(sid) += shuffleWrite totalShuffleWrite += shuffleWrite - val taskList = stageToTaskInfos.getOrElse( + val taskList = stageIdToTaskInfos.getOrElse( sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]()) taskList -= ((taskEnd.taskInfo, None, None)) taskList += ((taskEnd.taskInfo, metrics, failureInfo)) - stageToTaskInfos(sid) = taskList + stageIdToTaskInfos(sid) = taskList } override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { @@ -159,10 +165,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList case end: SparkListenerJobEnd => end.jobResult match { case JobFailed(ex, Some(stage)) => - activeStages -= stage - poolToActiveStages(stageToPool(stage)) -= stage - failedStages += stage - trimIfNecessary(failedStages) + /* If two jobs share a stage we could get this failure message twice. So we first + * check whether we've already retired this stage. */ + val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption + stageInfo.foreach {s => + activeStages -= s + poolToActiveStages(stageIdToPool(stage.id)) -= s + failedStages += s + trimIfNecessary(failedStages) + } case _ => } case _ => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 06810d8dbc..cfeeccda41 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import scala.xml.Node -import org.apache.spark.scheduler.{Schedulable, Stage} +import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) { - var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages + var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages def toNodeSeq(): Seq[Node] = { listener.synchronized { @@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis } } - private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node], + private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node], rows: Seq[Schedulable] ): Seq[Node] = { <table class="table table-bordered table-striped table-condensed sortable table-fixed"> @@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis </table> } - private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]]) + private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]]) : Seq[Node] = { val activeStages = poolToActiveStages.get(p.name) match { case Some(stages) => stages.size diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 163a3746ea..35b5d5fd59 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) { val stageId = request.getParameter("id").toInt val now = System.currentTimeMillis() - if (!listener.stageToTaskInfos.contains(stageId)) { + if (!listener.stageIdToTaskInfos.contains(stageId)) { val content = <div> <h4>Summary Metrics</h4> No tasks have started yet @@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) { return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages) } - val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime) + val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime) val numCompleted = tasks.count(_._1.finished) - val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L) + val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L) val hasShuffleRead = shuffleReadBytes > 0 - val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L) + val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L) val hasShuffleWrite = shuffleWriteBytes > 0 var activeTime = 0L - listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) val summary = <div> <ul class="unstyled"> <li> <strong>CPU time: </strong> - {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)} + {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} </li> {if (hasShuffleRead) <li> @@ -83,10 +83,10 @@ private[spark] class StagePage(parent: JobProgressUI) { </div> val taskHeaders: Seq[String] = - Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ - Seq("GC Time") ++ + Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++ + Seq("Duration", "GC Time") ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ + {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ Seq("Errors") val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) @@ -153,6 +153,7 @@ private[spark] class StagePage(parent: JobProgressUI) { val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) <tr> + <td>{info.index}</td> <td>{info.taskId}</td> <td>{info.status}</td> <td>{info.taskLocality}</td> @@ -169,6 +170,8 @@ private[spark] class StagePage(parent: JobProgressUI) { Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td> }} {if (shuffleWrite) { + <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td> <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td> }} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 07db8622da..d7d0441c38 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,13 +22,13 @@ import java.util.Date import scala.xml.Node import scala.collection.mutable.HashSet -import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo} +import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo} import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished stages */ -private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) { +private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) { val listener = parent.listener val dateFmt = parent.dateFmt @@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU } - private def stageRow(s: Stage): Seq[Node] = { + private def stageRow(s: StageInfo): Seq[Node] = { val submissionTime = s.submissionTime match { case Some(t) => dateFmt.format(new Date(t)) case None => "Unknown" } - val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match { + val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match { case 0 => "" case b => Utils.bytesToString(b) } - val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match { + val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match { case 0 => "" case b => Utils.bytesToString(b) } - val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size - val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0) - val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match { + val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size + val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0) + val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match { case f if f > 0 => "(%s failed)".format(f) case _ => "" } - val totalTasks = s.numPartitions + val totalTasks = s.numTasks - val poolName = listener.stageToPool.get(s) + val poolName = listener.stageIdToPool.get(s.stageId) val nameLink = - <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a> - val description = listener.stageToDescription.get(s) + <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.name}</a> + val description = listener.stageIdToDescription.get(s.stageId) .map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink) val finishTime = s.completionTime.getOrElse(System.currentTimeMillis()) val duration = s.submissionTime.map(t => finishTime - t) <tr> - <td>{s.id}</td> + <td>{s.stageId}</td> {if (isFairScheduler) { <td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}> {poolName.get}</a></td>} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 0ce1394c77..3f963727d9 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea } object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask", - "ShuffleMapTask", "BlockManager", "BroadcastVars") { + "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") { - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, + SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value type MetadataCleanerType = Value diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f384875cc9..0c5c12b7a8 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -37,6 +37,7 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer import org.apache.spark.{SparkEnv, SparkException, Logging} +import java.util.ConcurrentModificationException /** @@ -280,9 +281,8 @@ private[spark] object Utils extends Logging { } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val env = SparkEnv.get val uri = new URI(url) - val conf = env.hadoop.newConfiguration() + val conf = SparkHadoopUtil.get.newConfiguration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) @@ -447,14 +447,17 @@ private[spark] object Utils extends Logging { hostPortParseResults.get(hostPort) } - private[spark] val daemonThreadFactory: ThreadFactory = - new ThreadFactoryBuilder().setDaemon(true).build() + private val daemonThreadFactoryBuilder: ThreadFactoryBuilder = + new ThreadFactoryBuilder().setDaemon(true) /** - * Wrapper over newCachedThreadPool. + * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = - Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { + val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } /** * Return the string to tell how long has passed in seconds. The passing parameter should be in @@ -465,10 +468,13 @@ private[spark] object Utils extends Logging { } /** - * Wrapper over newFixedThreadPool. + * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = - Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { + val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() + Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] + } private def listFilesSafely(file: File): Seq[File] = { val files = file.listFiles() @@ -813,4 +819,10 @@ private[spark] object Utils extends Logging { // Nothing else to guard against ? hashAbs } + + /** Returns a copy of the system properties that is thread-safe to iterator over. */ + def getSystemProperties(): Map[String, String] = { + return System.getProperties().clone() + .asInstanceOf[java.util.Properties].toMap[String, String] + } } |