Diffstat (limited to 'core/src/main/scala/org')
150 files changed, 6611 insertions, 4541 deletions
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala index f87460039b..0c47afae54 100644 --- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala @@ -17,20 +17,29 @@ package org.apache.hadoop.mapred +private[apache] trait SparkHadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext"); - val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID]) + val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", + "org.apache.hadoop.mapred.JobContext") + val ctor = klass.getDeclaredConstructor(classOf[JobConf], + classOf[org.apache.hadoop.mapreduce.JobID]) ctor.newInstance(conf, jobId).asInstanceOf[JobContext] } def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext") + val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", + "org.apache.hadoop.mapred.TaskAttemptContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } - def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = { + def newTaskAttemptID( + jtIdentifier: String, + jobId: Int, + isMap: Boolean, + taskId: Int, + attemptId: Int) = { new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) } diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala index 93180307fa..32429f01ac 100644 --- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala @@ -17,9 +17,10 @@ package org.apache.hadoop.mapreduce -import org.apache.hadoop.conf.Configuration import java.lang.{Integer => JInteger, Boolean => JBoolean} +import org.apache.hadoop.conf.Configuration +private[apache] trait SparkHadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = { val klass = firstAvailableClass( @@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil { ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } - def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = { - val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID"); + def newTaskAttemptID( + jtIdentifier: String, + jobId: Int, + isMap: Boolean, + taskId: Int, + attemptId: Int) = { + val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID") try { - // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN) + // First, attempt to use the old-style constructor that takes a boolean isMap + // (not available in YARN) val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean], - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new - JInteger(attemptId)).asInstanceOf[TaskAttemptID] + classOf[Int], 
classOf[Int]) + ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), + new JInteger(attemptId)).asInstanceOf[TaskAttemptID] } catch { case exc: NoSuchMethodException => { - // failed, look for the new ctor that takes a TaskType (not available in 1.x) - val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]] - val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE") + // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) + val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType") + .asInstanceOf[Class[Enum[_]]] + val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( + taskTypeClass, if(isMap) "MAP" else "REDUCE") val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new - JInteger(attemptId)).asInstanceOf[TaskAttemptID] + ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), + new JInteger(attemptId)).asInstanceOf[TaskAttemptID] } } } diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 3ef402926e..1a2ec55876 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -17,43 +17,42 @@ package org.apache.spark -import java.util.{HashMap => JHashMap} +import org.apache.spark.util.AppendOnlyMap -import scala.collection.JavaConversions._ - -/** A set of functions used to aggregate data. - * - * @param createCombiner function to create the initial value of the aggregation. - * @param mergeValue function to merge a new value into the aggregation result. - * @param mergeCombiners function to merge outputs from multiple mergeValue function. - */ +/** + * A set of functions used to aggregate data. + * + * @param createCombiner function to create the initial value of the aggregation. + * @param mergeValue function to merge a new value into the aggregation result. + * @param mergeCombiners function to merge outputs from multiple mergeValue function. 
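A minimal usage sketch of the three functions the Aggregator scaladoc above describes, assuming the Spark classes are on the classpath; the sample data is illustrative, not part of this change:

    import org.apache.spark.Aggregator

    // Sum Int values per key: the combiner type C is also Int.
    val agg = Aggregator[String, Int, Int](
      createCombiner = (v: Int) => v,                   // first value seen for a key
      mergeValue = (c: Int, v: Int) => c + v,           // fold a new value into the combiner
      mergeCombiners = (c1: Int, c2: Int) => c1 + c2)   // merge map-side combiners

    val counts = agg.combineValuesByKey(Iterator(("a", 1), ("b", 1), ("a", 1))).toMap
    // counts == Map("a" -> 2, "b" -> 1)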
+ */ case class Aggregator[K, V, C] ( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - for (kv <- iter) { - val oldC = combiners.get(kv._1) - if (oldC == null) { - combiners.put(kv._1, createCombiner(kv._2)) - } else { - combiners.put(kv._1, mergeValue(oldC, kv._2)) - } + val combiners = new AppendOnlyMap[K, C] + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (iter.hasNext) { + kv = iter.next() + combiners.changeValue(kv._1, update) } combiners.iterator } def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - iter.foreach { case(k, c) => - val oldC = combiners.get(k) - if (oldC == null) { - combiners.put(k, c) - } else { - combiners.put(k, mergeCombiners(oldC, c)) - } + val combiners = new AppendOnlyMap[K, C] + var kc: (K, C) = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + } + while (iter.hasNext) { + kc = iter.next() + combiners.changeValue(kc._1, update) } combiners.iterator } diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala index 908ff56a6b..d9ed572da6 100644 --- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala @@ -22,13 +22,17 @@ import scala.collection.mutable.HashMap import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) + override def fetch[T]( + shuffleId: Int, + reduceId: Int, + context: TaskContext, + serializer: Serializer) : Iterator[T] = { @@ -45,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) } - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map { + val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { case (address, splits) => - (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2))) + (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { @@ -58,9 +62,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin block.asInstanceOf[Iterator[T]] } case None => { - val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r blockId match { - case regex(shufId, mapId, _) => + case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null) case _ => @@ -74,7 +77,7 @@ private[spark] 
class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) val itr = blockFetcherItr.flatMap(unpackBlock) - CompletionIterator[T, Iterator[T]](itr, { + val completionIter = CompletionIterator[T, Iterator[T]](itr, { val shuffleMetrics = new ShuffleReadMetrics shuffleMetrics.shuffleFinishTime = System.currentTimeMillis shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime @@ -83,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - metrics.shuffleReadMetrics = Some(shuffleMetrics) + context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics) }) + + new InterruptibleIterator[T](context, completionIter) } } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index 4cf7eb96da..519ecde50a 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -18,7 +18,7 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashSet} -import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId} import org.apache.spark.rdd.RDD @@ -28,17 +28,17 @@ import org.apache.spark.rdd.RDD private[spark] class CacheManager(blockManager: BlockManager) extends Logging { /** Keys of RDD splits that are being computed/loaded. */ - private val loading = new HashSet[String]() + private val loading = new HashSet[RDDBlockId]() /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */ def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel) : Iterator[T] = { - val key = "rdd_%d_%d".format(rdd.id, split.index) + val key = RDDBlockId(rdd.id, split.index) logDebug("Looking for partition " + key) blockManager.get(key) match { case Some(values) => // Partition is already materialized, so just return its values - return values.asInstanceOf[Iterator[T]] + return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => // Mark the split as loading (unless someone else marks it first) @@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { // downside of the current code is that threads wait serially if this does happen. 
blockManager.get(key) match { case Some(values) => - return values.asInstanceOf[Iterator[T]] + return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]]) case None => logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key)) loading.add(key) @@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { if (context.runningLocally) { return computedValues } val elements = new ArrayBuffer[Any] elements ++= computedValues - blockManager.put(key, elements, storageLevel, true) + blockManager.put(key, elements, storageLevel, tellMaster = true) return elements.iterator.asInstanceOf[Iterator[T]] } finally { loading.synchronized { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala new file mode 100644 index 0000000000..1ad9240cfa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.Try + +import org.apache.spark.scheduler.{JobSucceeded, JobWaiter} +import org.apache.spark.scheduler.JobFailed +import org.apache.spark.rdd.RDD + + +/** + * A future for the result of an action. This is an extension of the Scala Future interface to + * support cancellation. + */ +trait FutureAction[T] extends Future[T] { + // Note that we redefine methods of the Future trait here explicitly so we can specify a different + // documentation (with reference to the word "action"). + + /** + * Cancels the execution of this action. + */ + def cancel() + + /** + * Blocks until this action completes. + * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf + * for unbounded waiting, or a finite positive duration + * @return this FutureAction + */ + override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type + + /** + * Awaits and returns the result (of type T) of this action. + * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf + * for unbounded waiting, or a finite positive duration + * @throws Exception exception during action execution + * @return the result value if the action is completed within the specific maximum wait time + */ + @throws(classOf[Exception]) + override def result(atMost: Duration)(implicit permit: CanAwait): T + + /** + * When this action is completed, either through an exception, or a value, applies the provided + * function. + */ + def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) + + /** + * Returns whether the action has already been completed with a value or an exception. 
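A hedged sketch of how the FutureAction interface above is meant to be used; it assumes the countAsync method from AsyncRDDActions, exposed by the rddToAsyncRDDActions implicit added to SparkContext further down in this diff:

    import org.apache.spark.{FutureAction, SparkContext}
    import org.apache.spark.SparkContext._   // brings in rddToAsyncRDDActions

    val sc = new SparkContext("local", "future-action-demo")
    val f: FutureAction[Long] = sc.parallelize(1 to 100000, 4).countAsync()

    val n = f.get()   // blocks until the job finishes, rethrowing job failures
    // f.cancel()     // called from another thread, cancels the running job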
+ */ + override def isCompleted: Boolean + + /** + * The value of this Future. + * + * If the future is not completed the returned value will be None. If the future is completed + * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if + * it contains an exception. + */ + override def value: Option[Try[T]] + + /** + * Blocks and returns the result of this job. + */ + @throws(classOf[Exception]) + def get(): T = Await.result(this, Duration.Inf) +} + + +/** + * The future holding the result of an action that triggers a single job. Examples include + * count, collect, reduce. + */ +class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) + extends FutureAction[T] { + + override def cancel() { + jobWaiter.cancel() + } + + override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { + if (!atMost.isFinite()) { + awaitResult() + } else { + val finishTime = System.currentTimeMillis() + atMost.toMillis + while (!isCompleted) { + val time = System.currentTimeMillis() + if (time >= finishTime) { + throw new TimeoutException + } else { + jobWaiter.wait(finishTime - time) + } + } + } + this + } + + @throws(classOf[Exception]) + override def result(atMost: Duration)(implicit permit: CanAwait): T = { + ready(atMost)(permit) + awaitResult() match { + case scala.util.Success(res) => res + case scala.util.Failure(e) => throw e + } + } + + override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) { + executor.execute(new Runnable { + override def run() { + func(awaitResult()) + } + }) + } + + override def isCompleted: Boolean = jobWaiter.jobFinished + + override def value: Option[Try[T]] = { + if (jobWaiter.jobFinished) { + Some(awaitResult()) + } else { + None + } + } + + private def awaitResult(): Try[T] = { + jobWaiter.awaitResult() match { + case JobSucceeded => scala.util.Success(resultFunc) + case JobFailed(e: Exception, _) => scala.util.Failure(e) + } + } +} + + +/** + * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take, + * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the + * action thread if it is being blocked by a job. + */ +class ComplexFutureAction[T] extends FutureAction[T] { + + // Pointer to the thread that is executing the action. It is set when the action is run. + @volatile private var thread: Thread = _ + + // A flag indicating whether the future has been cancelled. This is used in case the future + // is cancelled before the action was even run (and thus we have no thread to interrupt). + @volatile private var _cancelled: Boolean = false + + // A promise used to signal the future. + private val p = promise[T]() + + override def cancel(): Unit = this.synchronized { + _cancelled = true + if (thread != null) { + thread.interrupt() + } + } + + /** + * Executes some action enclosed in the closure. To properly enable cancellation, the closure + * should use runJob implementation in this promise. See takeAsync for example. + */ + def run(func: => T)(implicit executor: ExecutionContext): this.type = { + scala.concurrent.future { + thread = Thread.currentThread + try { + p.success(func) + } catch { + case e: Exception => p.failure(e) + } finally { + thread = null + } + } + this + } + + /** + * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. 
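A sketch of the run/runJob pattern this class supports (takeAsync in this commit follows the same shape); the helper below is hypothetical and simplified to a single job:

    import scala.concurrent.ExecutionContext.Implicits.global
    import org.apache.spark.ComplexFutureAction
    import org.apache.spark.rdd.RDD

    // Hypothetical action: asynchronously collect only partition 0.
    def firstPartitionAsync[T](rdd: RDD[T]): ComplexFutureAction[Seq[T]] = {
      val f = new ComplexFutureAction[Seq[T]]
      f.run {
        val buf = new scala.collection.mutable.ArrayBuffer[T]
        // runJob re-checks the cancelled flag before submitting, so cancel()
        // either prevents the job or interrupts the wait on it.
        f.runJob(rdd,
          (it: Iterator[T]) => it.toList,
          Seq(0),
          (index: Int, data: List[T]) => buf ++= data,
          buf.toSeq)
        buf.toSeq
      }
    }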
+ */ + def runJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R) { + // If the action hasn't been cancelled yet, submit the job. The check and the submitJob + // command need to be in an atomic block. + val job = this.synchronized { + if (!cancelled) { + rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) + } else { + throw new SparkException("Action has been cancelled") + } + } + + // Wait for the job to complete. If the action is cancelled (with an interrupt), + // cancel the job and stop the execution. This is not in a synchronized block because + // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. + try { + Await.ready(job, Duration.Inf) + } catch { + case e: InterruptedException => + job.cancel() + throw new SparkException("Action has been cancelled") + } + } + + /** + * Returns whether the promise has been cancelled. + */ + def cancelled: Boolean = _cancelled + + @throws(classOf[InterruptedException]) + @throws(classOf[scala.concurrent.TimeoutException]) + override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = { + p.future.ready(atMost)(permit) + this + } + + @throws(classOf[Exception]) + override def result(atMost: Duration)(implicit permit: CanAwait): T = { + p.future.result(atMost)(permit) + } + + override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = { + p.future.onComplete(func)(executor) + } + + override def isCompleted: Boolean = p.isCompleted + + override def value: Option[Try[T]] = p.future.value +} diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala new file mode 100644 index 0000000000..56e0b8d2c0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * An iterator that wraps around an existing iterator to provide task killing functionality. + * It works by checking the interrupted flag in TaskContext. 
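An illustrative sketch of the wrapper's behavior; the TaskContext is built by hand here purely for demonstration (real tasks receive theirs from the runtime):

    import org.apache.spark.{InterruptibleIterator, TaskContext}

    val context = new TaskContext(stageId = 0, partitionId = 0, attemptId = 0L)
    val iter = new InterruptibleIterator(context, Iterator(1, 2, 3))

    iter.next()                  // 1: the interrupted flag is not set yet
    context.interrupted = true   // what a kill signal does to the task
    iter.hasNext                 // false, even though elements remain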
+ */ +class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T]) + extends Iterator[T] { + + def hasNext: Boolean = !context.interrupted && delegate.hasNext + + def next(): T = delegate.next() +} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1afb1870f1..035942ad39 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import akka.actor._ @@ -34,7 +33,7 @@ import scala.concurrent.duration._ import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap} private[spark] sealed trait MapOutputTrackerMessage @@ -42,11 +41,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage -private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging { +private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster) + extends Actor with Logging { def receive = { case GetMapOutputStatuses(shuffleId: Int, requester: String) => logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester) - sender ! tracker.getSerializedLocations(shuffleId) + sender ! tracker.getSerializedMapOutputStatuses(shuffleId) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") @@ -62,22 +62,19 @@ private[spark] class MapOutputTracker extends Logging { // Set to the MapOutputTrackerActor living on the driver var trackerActor: ActorRef = _ - private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. - private var epoch: Long = 0 - private val epochLock = new java.lang.Object + protected var epoch: Long = 0 + protected val epochLock = new java.lang.Object - // Cache a serialized version of the output statuses for each shuffle to send them out faster - var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] - - val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup) + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. - def askTracker(message: Any): Any = { + private def askTracker(message: Any): Any = { try { val future = trackerActor.ask(message)(timeout) return Await.result(future, timeout) @@ -88,50 +85,12 @@ private[spark] class MapOutputTracker extends Logging { } // Send a one-way message to the trackerActor, to which we expect it to reply with true. 
- def communicate(message: Any) { + private def communicate(message: Any) { if (askTracker(message) != true) { throw new SparkException("Error reply received from MapOutputTracker") } } - def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { - throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") - } - } - - def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - var array = mapStatuses(shuffleId) - array.synchronized { - array(mapId) = status - } - } - - def registerMapOutputs( - shuffleId: Int, - statuses: Array[MapStatus], - changeEpoch: Boolean = false) { - mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) - if (changeEpoch) { - incrementEpoch() - } - } - - def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var arrayOpt = mapStatuses.get(shuffleId) - if (arrayOpt.isDefined && arrayOpt.get != null) { - var array = arrayOpt.get - array.synchronized { - if (array(mapId) != null && array(mapId).location == bmAddress) { - array(mapId) = null - } - } - incrementEpoch() - } else { - throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") - } - } - // Remembers which map output locations are currently being fetched on a worker private val fetching = new HashSet[Int] @@ -170,7 +129,7 @@ private[spark] class MapOutputTracker extends Logging { try { val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]] - fetchedStatuses = deserializeStatuses(fetchedBytes) + fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) } finally { @@ -196,9 +155,8 @@ private[spark] class MapOutputTracker extends Logging { } } - private def cleanup(cleanupTime: Long) { + protected def cleanup(cleanupTime: Long) { mapStatuses.clearOldValues(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) } def stop() { @@ -208,15 +166,7 @@ private[spark] class MapOutputTracker extends Logging { trackerActor = null } - // Called on master to increment the epoch number - def incrementEpoch() { - epochLock.synchronized { - epoch += 1 - logDebug("Increasing epoch to " + epoch) - } - } - - // Called on master or workers to get current epoch number + // Called to get current epoch number def getEpoch: Long = { epochLock.synchronized { return epoch @@ -230,14 +180,62 @@ private[spark] class MapOutputTracker extends Logging { epochLock.synchronized { if (newEpoch > epoch) { logInfo("Updating epoch to " + newEpoch + " and clearing cache") - // mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] - mapStatuses.clear() epoch = newEpoch + mapStatuses.clear() + } + } + } +} + +private[spark] class MapOutputTrackerMaster extends MapOutputTracker { + + // Cache a serialized version of the output statuses for each shuffle to send them out faster + private var cacheEpoch = epoch + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + + def registerShuffle(shuffleId: Int, numMaps: Int) { + if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") + } + } + + def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { + val array = mapStatuses(shuffleId) + array.synchronized { + array(mapId) = status + } + } + + def registerMapOutputs(shuffleId: Int, 
statuses: Array[MapStatus], changeEpoch: Boolean = false) { + mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) + if (changeEpoch) { + incrementEpoch() + } + } + + def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + val arrayOpt = mapStatuses.get(shuffleId) + if (arrayOpt.isDefined && arrayOpt.get != null) { + val array = arrayOpt.get + array.synchronized { + if (array(mapId) != null && array(mapId).location == bmAddress) { + array(mapId) = null + } } + incrementEpoch() + } else { + throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } } - def getSerializedLocations(shuffleId: Int): Array[Byte] = { + def incrementEpoch() { + epochLock.synchronized { + epoch += 1 + logDebug("Increasing epoch to " + epoch) + } + } + + def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = { var statuses: Array[MapStatus] = null var epochGotten: Long = -1 epochLock.synchronized { @@ -255,7 +253,7 @@ private[spark] class MapOutputTracker extends Logging { } // If we got here, we failed to find the serialized locations in the cache, so we pulled // out a snapshot of the locations as "locs"; let's serialize and return that - val bytes = serializeStatuses(statuses) + val bytes = MapOutputTracker.serializeMapStatuses(statuses) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working epochLock.synchronized { @@ -263,13 +261,31 @@ private[spark] class MapOutputTracker extends Logging { cachedSerializedStatuses(shuffleId) = bytes } } - return bytes + bytes + } + + protected override def cleanup(cleanupTime: Long) { + super.cleanup(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } + override def stop() { + super.stop() + cachedSerializedStatuses.clear() + } + + override def updateEpoch(newEpoch: Long) { + // This might be called on the MapOutputTrackerMaster if we're running in local mode. + } +} + +private[spark] object MapOutputTracker { + private val LOG_BASE = 1.1 + // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = { + def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) // Since statuses can be modified in parallel, sync on it @@ -280,18 +296,11 @@ private[spark] class MapOutputTracker extends Logging { out.toByteArray } - // Opposite of serializeStatuses. - def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = { + // Opposite of serializeMapStatuses. + def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject(). - // // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present - // comment this out - nulls could be due to missing location ? - asInstanceOf[Array[MapStatus]] // .filter( _ != null ) + objIn.readObject().asInstanceOf[Array[MapStatus]] } -} - -private[spark] object MapOutputTracker { - private val LOG_BASE = 1.1 // Convert an array of MapStatuses to locations and sizes for a given reduce ID. 
If // any of the statuses is null (indicating a missing location due to a failed mapper), diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala index 307c383a89..a85aa50a9b 100644 --- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala @@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher { * Fetch the shuffle outputs for a given ShuffleDependency. * @return An iterator over the elements of the fetched shuffle outputs. */ - def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, + def fetch[T]( + shuffleId: Int, + reduceId: Int, + context: TaskContext, serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T] /** Stop the fetcher */ diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 1b003cc685..cc44a4c7dd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.generic.Growable -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.reflect.{ ClassTag, classTag} @@ -53,21 +53,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor import org.apache.mesos.MesosNativeLibrary -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.LocalSparkCluster +import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, - ClusterScheduler} -import org.apache.spark.scheduler.local.LocalScheduler +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, + SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} -import org.apache.spark.storage.{StorageUtils, BlockManagerSource} -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.scheduler.local.LocalScheduler import org.apache.spark.scheduler.StageInfo -import org.apache.spark.storage.RDDInfo -import org.apache.spark.storage.StorageStatus +import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, + TimeStampedHashMap, Utils} /** * Main entry point for Spark functionality. 
A SparkContext represents the connection to a Spark @@ -121,9 +119,9 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup) + private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup) - // Initalize the Spark UI + // Initialize the Spark UI private[spark] val ui = new SparkUI(this) ui.bind() @@ -149,6 +147,14 @@ class SparkContext( executorEnvs ++= environment } + // Set SPARK_USER for user who is running SparkContext. + val sparkUser = Option { + Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER")) + }.getOrElse { + SparkContext.SPARK_UNKNOWN_USER + } + executorEnvs("SPARK_USER") = sparkUser + // Create and start the scheduler private[spark] var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format @@ -158,9 +164,11 @@ class SparkContext( // Regular expression for simulating a Spark cluster of [N, cores, memory] locally val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """(spark://.*)""".r - //Regular expression for connection to Mesos cluster - val MESOS_REGEX = """(mesos://.*)""".r + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster + val MESOS_REGEX = """mesos://(.*)""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r master match { case "local" => @@ -174,7 +182,14 @@ class SparkContext( case SPARK_REGEX(sparkUrl) => val scheduler = new ClusterScheduler(this) - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) + val masterUrls = sparkUrl.split(",").map("spark://" + _) + val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) + scheduler.initialize(backend) + scheduler + + case SIMR_REGEX(simrUrl) => + val scheduler = new ClusterScheduler(this) + val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) scheduler.initialize(backend) scheduler @@ -190,8 +205,8 @@ class SparkContext( val scheduler = new ClusterScheduler(this) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) - val sparkUrl = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) + val masterUrls = localCluster.start() + val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) scheduler.initialize(backend) backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { localCluster.stop() @@ -210,25 +225,24 @@ class SparkContext( throw new SparkException("YARN mode not available ?", th) } } - val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem) + val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem) scheduler.initialize(backend) scheduler - case _ => - if (MESOS_REGEX.findFirstIn(master).isEmpty) { - logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) - } + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if 
(coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) + new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) } else { - new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName) + new MesosSchedulerBackend(scheduler, this, mesosUrl, appName) } scheduler.initialize(backend) scheduler + + case _ => + throw new SparkException("Could not parse Master URL: '" + master + "'") } } taskScheduler.start() @@ -241,7 +255,7 @@ class SparkContext( /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = { val env = SparkEnv.get - val conf = env.hadoop.newConfiguration() + val conf = SparkHadoopUtil.get.newConfiguration() // Explicitly check for S3 environment variables if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) { @@ -251,8 +265,10 @@ class SparkContext( conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) { - conf.set(key.substring("spark.hadoop.".length), System.getProperty(key)) + Utils.getSystemProperties.foreach { case (key, value) => + if (key.startsWith("spark.hadoop.")) { + conf.set(key.substring("spark.hadoop.".length), value) + } } val bufferSize = System.getProperty("spark.buffer.size", "65536") conf.set("io.file.buffer.size", bufferSize) @@ -285,15 +301,46 @@ class SparkContext( Option(localProperties.get).map(_.getProperty(key)).getOrElse(null) /** Set a human readable description of the current job. */ + @deprecated("use setJobGroup", "0.8.1") def setJobDescription(value: String) { - setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value) + setJobGroup("", value) + } + + /** + * Assigns a group id to all the jobs started by this thread until the group id is set to a + * different value or cleared. + * + * Often, a unit of execution in an application consists of multiple Spark actions or jobs. + * Application programmers can use this method to group all those jobs together and give a + * group description. Once set, the Spark web UI will associate such jobs with this group. + * + * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all + * running jobs in this group. For example, + * {{{ + * // In the main thread: + * sc.setJobGroup("some_job_to_cancel", "some job description") + * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() + * + * // In a separate thread: + * sc.cancelJobGroup("some_job_to_cancel") + * }}} + */ + def setJobGroup(groupId: String, description: String) { + setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description) + setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId) + } + + /** Clear the job group id and its description. 
*/ + def clearJobGroup() { + setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null) + setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null) } // Post init taskScheduler.postStartHook() - val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this) - val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this) + private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this) + private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this) def initDriverMetrics() { SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource) @@ -332,7 +379,7 @@ class SparkContext( } /** - * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any + * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable, * etc). */ @@ -344,7 +391,7 @@ class SparkContext( minSplits: Int = defaultMinSplits ): RDD[(K, V)] = { // Add necessary security credentials to the JobConf before broadcasting it. - SparkEnv.get.hadoop.addCredentials(conf) + SparkHadoopUtil.get.addCredentials(conf) new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits) } @@ -358,24 +405,15 @@ class SparkContext( ): RDD[(K, V)] = { // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it. val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration)) - hadoopFile(path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits) - } - - /** - * Get an RDD for a Hadoop file with an arbitray InputFormat. Accept a Hadoop Configuration - * that has already been broadcast, assuming that it's safe to use it to construct a - * HadoopFileRDD (i.e., except for file 'path', all other configuration properties can be resued). - */ - def hadoopFile[K, V]( - path: String, - confBroadcast: Broadcast[SerializableWritable[Configuration]], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int - ): RDD[(K, V)] = { - new HadoopFileRDD( - this, path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits) + val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path) + new HadoopRDD( + this, + confBroadcast, + Some(setInputPathsFunc), + inputFormatClass, + keyClass, + valueClass, + minSplits) } /** @@ -563,7 +601,8 @@ class SparkContext( val uri = new URI(path) val key = uri.getScheme match { case null | "file" => env.httpFileServer.addFile(new File(uri.getPath)) - case _ => path + case "local" => "file:" + uri.getPath + case _ => path } addedFiles(key) = System.currentTimeMillis @@ -657,12 +696,11 @@ class SparkContext( /** * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. + * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. 
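Hedged examples of the URI schemes addJar accepts after this change; the paths are placeholders:

    sc.addJar("/path/on/driver/deps.jar")           // driver-local, served over the HTTP file server
    sc.addJar("hdfs://namenode:8020/jars/deps.jar") // fetched from HDFS by executors
    sc.addJar("local:/opt/jars/deps.jar")           // already present on every worker node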
*/ def addJar(path: String) { if (path == null) { - logWarning("null specified as parameter to addJar", - new SparkException("null specified as parameter to addJar")) + logWarning("null specified as parameter to addJar") } else { var key = "" if (path.contains("\\")) { @@ -671,12 +709,27 @@ class SparkContext( } else { val uri = new URI(path) key = uri.getScheme match { + // A JAR file which exists only on the driver node case null | "file" => - if (env.hadoop.isYarnMode()) { - logWarning("local jar specified as parameter to addJar under Yarn mode") - return + if (SparkHadoopUtil.get.isYarnMode()) { + // In order for this to work on yarn the user must specify the --addjars option to + // the client to upload the file into the distributed cache to make it show up in the + // current working directory. + val fileName = new Path(uri.getPath).getName() + try { + env.httpFileServer.addJar(new File(fileName)) + } catch { + case e: Exception => { + logError("Error adding jar (" + e + "), was the --addJars option used?") + throw e + } + } + } else { + env.httpFileServer.addJar(new File(uri.getPath)) } - env.httpFileServer.addJar(new File(uri.getPath)) + // A JAR file which exists locally on every worker node + case "local" => + "file:" + uri.getPath case _ => path } @@ -750,13 +803,13 @@ class SparkContext( allowLocal: Boolean, resultHandler: (Int, U) => Unit) { val callSite = Utils.formatSparkCallSite + val cleanedFunc = clean(func) logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, - localProperties.get) + dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, + resultHandler, localProperties.get) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() - result } /** @@ -843,6 +896,42 @@ class SparkContext( } /** + * Submit a job for execution and return a FutureJob holding the result. + */ + def submitJob[T, U, R]( + rdd: RDD[T], + processPartition: Iterator[T] => U, + partitions: Seq[Int], + resultHandler: (Int, U) => Unit, + resultFunc: => R): SimpleFutureAction[R] = + { + val cleanF = clean(processPartition) + val callSite = Utils.formatSparkCallSite + val waiter = dagScheduler.submitJob( + rdd, + (context: TaskContext, iter: Iterator[T]) => cleanF(iter), + partitions, + callSite, + allowLocal = false, + resultHandler, + localProperties.get) + new SimpleFutureAction(waiter, resultFunc) + } + + /** + * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] + * for more information. + */ + def cancelJobGroup(groupId: String) { + dagScheduler.cancelJobGroup(groupId) + } + + /** Cancel all jobs that have been scheduled or are running. */ + def cancelAllJobs() { + dagScheduler.cancelAllJobs() + } + + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) */ @@ -859,9 +948,8 @@ class SparkContext( * prevent accidental overriding of checkpoint files in the existing directory. 
*/ def setCheckpointDir(dir: String, useExisting: Boolean = false) { - val env = SparkEnv.get val path = new Path(dir) - val fs = path.getFileSystem(env.hadoop.newConfiguration()) + val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) if (!useExisting) { if (fs.exists(path)) { throw new Exception("Checkpoint directory '" + path + "' already exists.") @@ -898,7 +986,12 @@ class SparkContext( * various Spark features. */ object SparkContext { - val SPARK_JOB_DESCRIPTION = "spark.job.description" + + private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" + + private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" + + private[spark] val SPARK_UNKNOWN_USER = "<unknown>" implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 @@ -925,6 +1018,8 @@ object SparkContext { implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) = new PairRDDFunctions(rdd) + implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a267407c67..84750e2e85 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -25,13 +25,13 @@ import akka.remote.RemoteActorRefProvider import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster} import org.apache.spark.network.ConnectionManager import org.apache.spark.serializer.{Serializer, SerializerManager} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.api.python.PythonWorkerFactory +import com.google.common.collect.MapMaker /** * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -58,18 +58,9 @@ class SparkEnv ( private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() - val hadoop = { - val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) - if(yarnMode) { - try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] - } catch { - case th: Throwable => throw new SparkException("Unable to load YARN support", th) - } - } else { - new SparkHadoopUtil - } - } + // A general, soft-reference map for metadata needed during HadoopRDD split computation + // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). 
+ private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() def stop() { pythonWorkers.foreach { case(key, worker) => worker.stop() } @@ -188,10 +179,14 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - val mapOutputTracker = new MapOutputTracker() + val mapOutputTracker = if (isDriver) { + new MapOutputTrackerMaster() + } else { + new MapOutputTracker() + } mapOutputTracker.trackerActor = registerOrLookup( "MapOutputTracker", - new MapOutputTrackerActor(mapOutputTracker)) + new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])) val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 2bab9d6e3d..103a1c2051 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -17,14 +17,14 @@ package org.apache.hadoop.mapred -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path - +import java.io.IOException import java.text.SimpleDateFormat import java.text.NumberFormat -import java.io.IOException import java.util.Date +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.SerializableWritable @@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable * Saves the RDD using a JobConf, which should contain an output key class, an output value class, * a filename to write to, etc, exactly like in a Hadoop MapReduce job. 
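SparkHadoopWriter is internal; user code reaches it through the saveAsHadoopFile family on pair RDDs. A hedged sketch, with a placeholder output path:

    import org.apache.spark.SparkContext._
    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.hadoop.mapred.TextOutputFormat

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2)))
    pairs.saveAsHadoopFile("hdfs:///tmp/out",
      classOf[Text], classOf[IntWritable],
      classOf[TextOutputFormat[Text, IntWritable]])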
*/ -class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable { +private[apache] +class SparkHadoopWriter(@transient jobConf: JobConf) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { private val now = new Date() private val conf = new SerializableWritable(jobConf) @@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH } getOutputCommitter().setupTask(getTaskContext()) - writer = getOutputFormat().getRecordWriter( - fs, conf.value, outputName, Reporter.NULL) + writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) } def write(key: AnyRef, value: AnyRef) { - if (writer!=null) { - //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")") + if (writer != null) { writer.write(key, value) } else { throw new IOException("Writer is null, open() has not been called") @@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH } } +private[apache] object SparkHadoopWriter { def createJobID(time: Date, id: Int): JobID = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index c2c358c7ad..cae983ed4c 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -17,21 +17,30 @@ package org.apache.spark -import executor.TaskMetrics import scala.collection.mutable.ArrayBuffer +import org.apache.spark.executor.TaskMetrics + class TaskContext( val stageId: Int, - val splitId: Int, + val partitionId: Int, val attemptId: Long, val runningLocally: Boolean = false, - val taskMetrics: TaskMetrics = TaskMetrics.empty() + @volatile var interrupted: Boolean = false, + private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty() ) extends Serializable { - @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit] + @deprecated("use partitionId", "0.8.1") + def splitId = partitionId + + // List of callback functions to execute when the task completes. + @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit] - // Add a callback function to be executed on task completion. An example use - // is for HadoopRDD to register a callback to close the input stream. + /** + * Add a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * @param f Callback function. 
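A sketch of the callback hook just described, in the spirit of HadoopRDD closing its input stream; the file path and compute signature are illustrative:

    import org.apache.spark.TaskContext

    def compute(context: TaskContext): Iterator[String] = {
      val reader = new java.io.BufferedReader(new java.io.FileReader("/tmp/part-0"))
      // Close the per-task resource whenever the task completes.
      context.addOnCompleteCallback(() => reader.close())
      Iterator.continually(reader.readLine()).takeWhile(_ != null)
    }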
+ */ def addOnCompleteCallback(f: () => Unit) { onCompleteCallbacks += f } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8466c2a004..c1e5e04b31 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure( */ private[spark] case object TaskResultLost extends TaskEndReason +private[spark] case object TaskKilled extends TaskEndReason + private[spark] case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index f0a1960a1b..e5e20dbb66 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -51,6 +51,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav */ def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. + */ + def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist()) + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + */ + def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking)) + // first() has to be overriden here in order for its return type to be Double instead of Object. override def first(): Double = srdd.first() @@ -84,6 +97,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav fromRDD(srdd.coalesce(numPartitions, shuffle)) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions)) + + /** * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 899e17d4fa..eeea0eddb1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -66,6 +66,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.persist(newLevel)) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. + */ + def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist()) + + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. 
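A short sketch of the blocking versus non-blocking unpersist variants added here, using the underlying Scala RDD:

    val cached = sc.parallelize(1 to 1000).cache()
    cached.count()                      // materialize the blocks
    cached.unpersist(blocking = false)  // returns immediately, blocks removed asynchronously
    // cached.unpersist()               // default form waits until all blocks are deleted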
+ */ + def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking)) + // Transformations (return a new RDD) /** @@ -96,6 +109,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K fromRDD(rdd.coalesce(numPartitions, shuffle)) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions)) + + /** * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] = @@ -599,4 +623,15 @@ object JavaPairRDD { new JavaPairRDD[K, V](rdd) implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd + + + /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */ + def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = { + implicit val cmk: ClassTag[K] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val cmv: ClassTag[V] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] + new JavaPairRDD[K, V](rdd.rdd) + } + } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 9968bc8e5f..c47657f512 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -43,9 +43,17 @@ JavaRDDLike[T, JavaRDD[T]] { /** * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * This method blocks until all blocks are deleted. */ def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist()) + /** + * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk. + * + * @param blocking Whether to block until all blocks are deleted. + */ + def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking)) + // Transformations (return a new RDD) /** @@ -76,6 +84,17 @@ JavaRDDLike[T, JavaRDD[T]] { rdd.coalesce(numPartitions, shuffle) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) + + /** * Return a sampled subset of this RDD. 
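The repartition/coalesce pairing documented above, together with the new JavaPairRDD.fromJavaRDD helper, in one small sketch (names hypothetical; assumes a local master):

    import java.util.Arrays
    import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext}

    object RepartitionSketch {
      def main(args: Array[String]) {
        val jsc = new JavaSparkContext("local[4]", "RepartitionSketch")
        // fromJavaRDD: view an RDD of tuples as a JavaPairRDD without copying.
        val pairs = JavaPairRDD.fromJavaRDD(
          jsc.parallelize(Arrays.asList(("a", 1), ("b", 2), ("c", 3))))
        val wider = pairs.repartition(8)    // always shuffles
        val narrower = wider.coalesce(2)    // shrinks without a shuffle
        println(wider.rdd.partitions.length + " -> " + narrower.rdd.partitions.length)
        jsc.stop()
      }
    }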
*/ def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] = diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java index 4830067f7a..3e85052cd0 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -18,8 +18,6 @@ package org.apache.spark.api.java.function; -import scala.runtime.AbstractFunction1; - import java.io.Serializable; /** @@ -27,11 +25,7 @@ import java.io.Serializable; */ // DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is // overloaded for both FlatMapFunction and DoubleFlatMapFunction. -public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>> +public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>> implements Serializable { - - public abstract Iterable<Double> call(T t); - - @Override - public final Iterable<Double> apply(T t) { return call(t); } + // Intentionally left blank } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java index ed92d31af5..5e9b8c48b8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/DoubleFunction.java @@ -27,6 +27,5 @@ import java.io.Serializable; // are overloaded for both Function and DoubleFunction. public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double> implements Serializable { - - public abstract Double call(T t) throws Exception; + // Intentionally left blank } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala index b7c0d78e33..ed8fea97fc 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala @@ -23,8 +23,5 @@ import scala.reflect.ClassTag * A function that returns zero or more output records from each input record. */ abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { - @throws(classOf[Exception]) - def call(x: T) : java.lang.Iterable[R] - def elementType() : ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala index 7a505df4be..aae1349c5e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala @@ -23,8 +23,5 @@ import scala.reflect.ClassTag * A function that takes two inputs and returns zero or more output records. 
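With call() now declared once on the WrappedFunction1 parent, the concrete function classes above are reduced to marker types: a subclass supplies only call(), and the inherited apply() forwards to it. A sketch with hypothetical names, usable from Scala as well as Java:

    import org.apache.spark.api.java.function.DoubleFunction

    object DoubleFunctionSketch {
      def main(args: Array[String]) {
        val half = new DoubleFunction[java.lang.Integer] {
          def call(x: java.lang.Integer): java.lang.Double = x.doubleValue / 2
        }
        // apply() delegates to call(), so it can be invoked directly: prints 5.0
        println(half(10))
      }
    }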
*/ abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { - @throws(classOf[Exception]) - def call(a: A, b:B) : java.lang.Iterable[C] - def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java index e97116986f..49e661a376 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java @@ -32,7 +32,7 @@ public abstract class Function<T, R> extends WrappedFunction1<T, R> implements S public abstract R call(T t) throws Exception; public ClassTag<R> returnType() { - return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class); + return ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java new file mode 100644 index 0000000000..fb1deceab5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; +import scala.runtime.AbstractFunction2; + +import java.io.Serializable; + +/** + * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R. 
+ */ +public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R> + implements Serializable { + + public ClassTag<R> returnType() { + return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class); + } +} + diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java index fbd0cdabe0..ca485b3cc2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -33,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V> extends WrappedFunction1<T, Iterable<Tuple2<K, V>>> implements Serializable { - public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception; - public ClassTag<K> keyType() { return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class); } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java index f09559627d..cbe2306026 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java @@ -28,12 +28,9 @@ import java.io.Serializable; */ // PairFunction does not extend Function because some UDF functions, like map, // are overloaded for both Function and PairFunction. -public abstract class PairFunction<T, K, V> - extends WrappedFunction1<T, Tuple2<K, V>> +public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>> implements Serializable { - public abstract Tuple2<K, V> call(T t) throws Exception; - public ClassTag<K> keyType() { return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class); } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala new file mode 100644 index 0000000000..d314dbdf1d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/function/WrappedFunction3.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function + +import scala.runtime.AbstractFunction3 + +/** + * Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the + * apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply + * isn't marked to allow that). 
+ */ +private[spark] abstract class WrappedFunction3[T1, T2, T3, R] + extends AbstractFunction3[T1, T2, T3, R] { + @throws(classOf[Exception]) + def call(t1: T1, t2: T2, t3: T3): R + + final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3) +} + diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4d887cf195..53b53df9ac 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it * collects a list of pickled strings that we pass to Python through a socket. */ -class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) +private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala deleted file mode 100644 index 93e7815ab5..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala +++ /dev/null @@ -1,1058 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
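The new Function3/WrappedFunction3 pair follows the same pattern as the existing wrappers: apply() is final and delegates to a call() that is declared to throw. A sketch of a user-defined three-argument function (hypothetical names):

    import org.apache.spark.api.java.function.Function3

    object Function3Sketch {
      def main(args: Array[String]) {
        val clamp = new Function3[Int, Int, Int, Int] {
          def call(lo: Int, hi: Int, x: Int): Int = math.max(lo, math.min(hi, x))
        }
        println(clamp(0, 10, 42))   // apply() forwards to call(); prints 10
      }
    }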
- */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - -private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) - with Logging - with Serializable { - - def value = value_ - - def blockId: String = "broadcast_" + id - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var hasBlocksBitVector: BitSet = null - @transient var numCopiesSent: Array[Int] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = new AtomicInteger(0) - - // Used ONLY by driver to track how many unique blocks have been sent out - @transient var sentBlocks = new AtomicInteger(0) - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - - // Used only in driver - @transient var guideMR: GuideMultipleRequests = null - - // Used only in Workers - @transient var ttGuide: TalkToGuide = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks.set(variableInfo.totalBlocks) - - // Guide has all the blocks - hasBlocksBitVector = new BitSet(totalBlocks) - hasBlocksBitVector.set(0, totalBlocks) - - // Guide still hasn't sent any block - numCopiesSent = new Array[Int](totalBlocks) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val driverSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - hasBlocksBitVector.synchronized { - driverSource.hasBlocksBitVector = hasBlocksBitVector - } - - // In the beginning, this is the only known source to Guide - listOfSources += driverSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { 
- SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - // Start local ServeMultipleRequests thread first - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - // Initialize variables in the worker node. Driver sends everything as 0/null - private def initializeWorkerVariables() { - arrayOfBlocks = null - hasBlocksBitVector = null - numCopiesSent = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = new AtomicInteger(0) - - listenPortLock = new Object - totalBlocksLock = new Object - - serveMR = null - ttGuide = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - listOfSources = ListBuffer[SourceInfo]() - - stopBroadcast = false - } - - private def getLocalSourceInfo: SourceInfo = { - // Wait till hostName and listenPort are OK - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Wait till totalBlocks and totalBytes are OK - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes) - - localSourceInfo.hasBlocks = hasBlocks.get - - hasBlocksBitVector.synchronized { - localSourceInfo.hasBlocksBitVector = hasBlocksBitVector - } - - return localSourceInfo - } - - // Add new SourceInfo to the listOfSources. Update if it exists already. 
- // Optimizing just by OR-ing the BitVectors was BAD for performance - private def addToListOfSources(newSourceInfo: SourceInfo) { - listOfSources.synchronized { - if (listOfSources.contains(newSourceInfo)) { - listOfSources = listOfSources - newSourceInfo - } - listOfSources += newSourceInfo - } - } - - private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) { - newSourceInfos.foreach { newSourceInfo => - addToListOfSources(newSourceInfo) - } - } - - class TalkToGuide(gInfo: SourceInfo) - extends Thread with Logging { - override def run() { - - // Keep exchaning information until all blocks have been received - while (hasBlocks.get < totalBlocks) { - talkOnce - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - } - - // Talk one more time to let the Guide know of reception completion - talkOnce - } - - // Connect to Guide and send this worker's information - private def talkOnce { - var clientSocketToGuide: Socket = null - var oosGuide: ObjectOutputStream = null - var oisGuide: ObjectInputStream = null - - clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) - oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) - oosGuide.flush() - oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) - - // Send local information - oosGuide.writeObject(getLocalSourceInfo) - oosGuide.flush() - - // Receive source information from Guide - var suitableSources = - oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Driver " + suitableSources) - - addToListOfSources(suitableSources) - - oisGuide.close() - oosGuide.close() - clientSocketToGuide.close() - } - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Setup initial states of variables - totalBlocks = gInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - hasBlocksBitVector = new BitSet(totalBlocks) - numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = gInfo.totalBytes - - // Start ttGuide to periodically talk to the Guide - var ttGuide = new TalkToGuide(gInfo) - ttGuide.setDaemon(true) - ttGuide.start() - logInfo("TalkToGuide started...") - - // Start pController to run TalkToPeer threads - var pcController = new PeerChatterController - pcController.setDaemon(true) - pcController.start() - logInfo("PeerChatterController started...") - - // FIXME: Must fix this. This might never break if broadcast fails. - // We should be able to break and send false. Also need to kill threads - while (hasBlocks.get < totalBlocks) { - Thread.sleep(MultiTracker.MaxKnockInterval) - } - - return true - } - - class PeerChatterController - extends Thread with Logging { - private var peersNowTalking = ListBuffer[SourceInfo]() - // TODO: There is a possible bug with blocksInRequestBitVector when a - // certain bit is NOT unset upon failure resulting in an infinite loop. 
- private var blocksInRequestBitVector = new BitSet(totalBlocks) - - override def run() { - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - while (hasBlocks.get < totalBlocks) { - var numThreadsToCreate = 0 - listOfSources.synchronized { - numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - - threadPool.getActiveCount - } - - while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { - var peerToTalkTo = pickPeerToTalkToRandom - - if (peerToTalkTo != null) - logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) - else - logDebug("No peer chosen...") - - if (peerToTalkTo != null) { - threadPool.execute(new TalkToPeer(peerToTalkTo)) - - // Add to peersNowTalking. Remove in the thread. We have to do this - // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before starting some more threads - Thread.sleep(MultiTracker.MinKnockInterval) - } - // Shutdown the thread pool - threadPool.shutdown() - } - - // Right now picking the one that has the most blocks this peer wants - // Also picking peer randomly if no one has anything interesting - private def pickPeerToTalkToRandom: SourceInfo = { - var curPeer: SourceInfo = null - var curMax = 0 - - logDebug("Picking peers to talk to...") - - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Select the peer that has the most blocks that this receiver does not - peersNotInUse.foreach { eachSource => - var tempHasBlocksBitVector: BitSet = null - hasBlocksBitVector.synchronized { - tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) - tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) - - if (tempHasBlocksBitVector.cardinality > curMax) { - curPeer = eachSource - curMax = tempHasBlocksBitVector.cardinality - } - } - - // Always picking randomly - if (curPeer == null && peersNotInUse.size > 0) { - // Pick uniformly the i'th required peer - var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) - - var peerIter = peersNotInUse.iterator - curPeer = peerIter.next - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - } - - return curPeer - } - - // Picking peer with the weight of rare blocks it has - private def pickPeerToTalkToRarestFirst: SourceInfo = { - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Count the number of copies of each block in the neighborhood - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // A block is considered rare if there are at most 2 copies of that block - // This CONSTANT could be a function of the neighborhood size - var rareBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { - rareBlocksIndices += i - } - } - - // Find peers 
with rare blocks - var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() - var totalRareBlocks = 0 - - peersNotInUse.foreach { eachPeer => - var hasRareBlocks = 0 - rareBlocksIndices.foreach { rareBlock => - if (eachPeer.hasBlocksBitVector.get(rareBlock)) { - hasRareBlocks += 1 - } - } - - if (hasRareBlocks > 0) { - peersWithRareBlocks += ((eachPeer, hasRareBlocks)) - } - totalRareBlocks += hasRareBlocks - } - - // Select a peer from peersWithRareBlocks based on weight calculated from - // unique rare blocks - var selectedPeerToTalkTo: SourceInfo = null - - if (peersWithRareBlocks.size > 0) { - // Sort the peers based on how many rare blocks they have - peersWithRareBlocks.sortBy(_._2) - - var randomNumber = MultiTracker.ranGen.nextDouble - var tempSum = 0.0 - - var i = 0 - do { - tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) - if (tempSum >= randomNumber) { - selectedPeerToTalkTo = peersWithRareBlocks(i)._1 - } - i += 1 - } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) - } - - if (selectedPeerToTalkTo == null) { - selectedPeerToTalkTo = pickPeerToTalkToRandom - } - - return selectedPeerToTalkTo - } - - class TalkToPeer(peerToTalkTo: SourceInfo) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - override def run() { - // TODO: There is a possible bug here regarding blocksInRequestBitVector - var blockToAskFor = -1 - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run() { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) - - logInfo("TalkToPeer started... => " + peerToTalkTo) - - try { - // Connect to the source - peerSocketToSource = - new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(peerSocketToSource.getInputStream) - - // Receive latest SourceInfo from peerToTalkTo - var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] - // Update listOfSources - addToListOfSources(newPeerToTalkTo) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - - var keepReceiving = true - - while (hasBlocks.get < totalBlocks && keepReceiving) { - blockToAskFor = - pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) - - // No block to request - if (blockToAskFor < 0) { - // Nothing to receive from newPeerToTalkTo - keepReceiving = false - } else { - // Let other threads know that blockToAskFor is being requested - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor) - } - - // Start with sending the blockID - oosSource.writeObject(blockToAskFor) - oosSource.flush() - - // CHANGED: Driver might send some other block than the one - // requested to ensure fast spreading of all blocks. 
- val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") - - if (!hasBlocksBitVector.get(bcBlock.blockID)) { - arrayOfBlocks(bcBlock.blockID) = bcBlock - - // Update the hasBlocksBitVector first - hasBlocksBitVector.synchronized { - hasBlocksBitVector.set(bcBlock.blockID) - hasBlocks.getAndIncrement - } - - // Some block(may NOT be blockToAskFor) has arrived. - // In any case, blockToAskFor is not in request any more - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - - // Reset blockToAskFor to -1. Else it will be considered missing - blockToAskFor = -1 - } - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logError("TalktoPeer had a " + e) - // FIXME: Remove 'newPeerToTalkTo' from listOfSources - // We probably should have the following in some form, but not - // really here. This exception can happen if the sender just breaks connection - // listOfSources.synchronized { - // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) - // listOfSources = listOfSources - peerToTalkTo - // } - } - } finally { - // blockToAskFor != -1 => there was an exception - if (blockToAskFor != -1) { - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - } - - cleanUpConnections() - } - } - - // Right now it picks a block uniformly that this peer does not have - private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Pick uniformly the i'th required block - var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) - var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) - - while (i > 0) { - pickedBlockIndex = - needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) - i -= 1 - } - - return pickedBlockIndex - } - } - - // Pick the block that seems to be the rarest across sources - private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - 
needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Count the number of copies for each block across all sources - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // Find the minimum - var minVal = Integer.MAX_VALUE - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) { - minVal = numCopiesPerBlock(i) - } - } - - // Find the blocks with the least copies that this peer does not have - var minBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { - minBlocksIndices += i - } - } - - // Now select a random index from minBlocksIndices - if (minBlocksIndices.size == 0) { - return -1 - } else { - // Pick uniformly the i'th index - var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) - return minBlocksIndices(i) - } - } - } - - private def cleanUpConnections() { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - // Delete from peersNowTalking - peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } - } - } - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. 
stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - - // Shutdown the thread pool - threadPool.shutdown() - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - listOfSources.foreach { sourceInfo => - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Throw away whatever comes in - gisSource.readObject.asInstanceOf[SourceInfo] - - // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast - gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sourceInfo: SourceInfo = null - private var selectedSources: ListBuffer[SourceInfo] = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its information - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSources = selectSuitableSources(sourceInfo) - logDebug("Sending selectedSources:" + selectedSources) - oos.writeObject(selectedSources) - oos.flush() - - // Add this source to the listOfSources - addToListOfSources(sourceInfo) - } catch { - case e: Exception => { - // Assuming exception caused by receiver failure: remove - if (listOfSources != null) { - listOfSources.synchronized { listOfSources -= sourceInfo } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Randomly select some sources to send back - private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { - var selectedSources = ListBuffer[SourceInfo]() - - // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' - // then add skipSourceInfo to setOfCompletedSources. Return blank. 
- if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } - return selectedSources - } - - listOfSources.synchronized { - if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { - selectedSources = listOfSources.clone - } else { - var picksLeft = MultiTracker.MaxPeersInGuideResponse - var alreadyPicked = new BitSet(listOfSources.size) - - while (picksLeft > 0) { - var i = -1 - - do { - i = MultiTracker.ranGen.nextInt(listOfSources.size) - } while (alreadyPicked.get(i)) - - var peerIter = listOfSources.iterator - var curPeer = peerIter.next - - // Set the BitSet before i is decremented - alreadyPicked.set(i) - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - - selectedSources += curPeer - - picksLeft = picksLeft - 1 - } - } - } - - // Remove the receiving source (if present) - selectedSources = selectedSources - skipSourceInfo - - return selectedSources - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - // Server at most MultiTracker.MaxChatSlots peers - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ServeSingleRequest is running") - - override def run() { - try { - // Send latest local SourceInfo to the receiver - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(getLocalSourceInfo) - oos.flush() - - // Receive latest SourceInfo from the receiver - var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - addToListOfSources(rxSourceInfo) - } - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = MultiTracker.MaxChatBlocks - - while (!stopBroadcast && keepSending && numBlocksToSend > 0) { - // Receive which block to send - var blockToSend = ois.readObject.asInstanceOf[Int] - - // If it is driver AND at least one copy of each block has not been - // sent out already, MODIFY blockToSend - if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { - blockToSend = sentBlocks.getAndIncrement - } - - // Send the block - sendBlock(blockToSend) - rxSourceInfo.hasBlocksBitVector.set(blockToSend) - - numBlocksToSend -= 
1 - - // Receive latest SourceInfo from the receiver - rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) - addToListOfSources(rxSourceInfo) - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= MultiTracker.MaxChatTime && - threadPool.getQueue.size > 0) { - keepSending = false - } - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendBlock(blockToSend: Int) { - try { - oos.writeObject(arrayOfBlocks(blockToSend)) - oos.flush() - } catch { - case e: Exception => logError("sendBlock had a " + e) - } - logDebug("Sent block: " + blockToSend + " to " + clientSocket) - } - } - } -} - -private[spark] class BitTorrentBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new BitTorrentBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 9db26ae6de..609464e38d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import org.apache.spark.{HttpServer, Logging, SparkEnv} import org.apache.spark.io.CompressionCodec -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet} - +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def blockId: String = "broadcast_" + id + def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -82,7 +81,7 @@ private object HttpBroadcast extends Logging { private var server: HttpServer = null private val files = new TimeStampedHashSet[String] - private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup) + private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup) private lazy val compressionCodec = CompressionCodec.createCodec() @@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging { } def write(id: Long, value: Any) { - val file = new File(broadcastDir, "broadcast-" + id) + val file = new File(broadcastDir, BroadcastBlockId(id).name) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { - val url = serverUri + "/broadcast-" + id + val url = serverUri + "/" + BroadcastBlockId(id).name val in = { if (compress) { compressionCodec.compressedInputStream(new URL(url).openStream()) diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala 
b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala deleted file mode 100644 index 21ec94659e..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala +++ /dev/null @@ -1,410 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.Random - -import scala.collection.mutable.Map - -import org.apache.spark._ -import org.apache.spark.util.Utils - -private object MultiTracker -extends Logging { - - // Tracker Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - - // Map to keep track of guides of ongoing broadcasts - var valueToGuideMap = Map[Long, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var _isDriver = false - - private var stopBroadcast = false - - private var trackMV: TrackMultipleValues = null - - def initialize(__isDriver: Boolean) { - synchronized { - if (!initialized) { - _isDriver = __isDriver - - if (isDriver) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - - // Set DriverHostAddress to the driver's IP address for the slaves to read - System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) - } - - initialized = true - } - } - } - - def stop() { - stopBroadcast = true - } - - // Load common parameters - private var DriverHostAddress_ = System.getProperty( - "spark.MultiTracker.DriverHostAddress", "") - private var DriverTrackerPort_ = System.getProperty( - "spark.broadcast.driverTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty( - "spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxChatSlots_ = System.getProperty( - "spark.broadcast.maxChatSlots", "4").toInt - private var MaxChatTime_ = System.getProperty( - "spark.broadcast.maxChatTime", 
"500").toInt - private var MaxChatBlocks_ = System.getProperty( - "spark.broadcast.maxChatBlocks", "1024").toInt - - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble - - def isDriver = _isDriver - - // Common config params - def DriverHostAddress = DriverHostAddress_ - def DriverTrackerPort = DriverTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxChatSlots = MaxChatSlots_ - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(DriverTrackerPort) - logInfo("TrackMultipleValues started at " + serverSocket) - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - if (stopBroadcast) { - logInfo("Stopping TrackMultipleValues...") - } - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == REGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (id -> gInfo) - } - - logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) - } - - logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. 
Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == FIND_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - var gInfo = - if (valueToGuideMap.contains(id)) valueToGuideMap(id) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logError("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - def getGuideInfo(variableLong: Long): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) - - var retriesLeft = MultiTracker.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send Long and receive GuideInfo - oosTracker.writeObject(variableLong) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => logError("getGuideInfo had a " + e) - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - - def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send 
Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - // Helper method to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / BlockSize) - if (byteArray.length % BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, BlockSize)) { - val thisBlockSize = math.min(BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper method to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) -extends Serializable - -private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) -extends Serializable { - @transient var hasBlocks = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala deleted file mode 100644 index baa1fd6da4..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.broadcast - -import java.util.BitSet - -import org.apache.spark._ - -/** - * Used to keep and pass around information of peers involved in a broadcast - */ -private[spark] case class SourceInfo (hostAddress: String, - listenPort: Int, - totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam) -extends Comparable[SourceInfo] with Logging { - - var currentLeechers = 0 - var receptionFailed = false - - var hasBlocks = 0 - var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) - - // Ascending sort based on leecher count - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) -} - -/** - * Helper Object of SourceInfo for its constants - */ -private[spark] object SourceInfo { - // Broadcast has not started yet! Should never happen. - val TxNotStartedRetry = -1 - // Broadcast has already finished. Try default mechanism. - val TxOverGoToDefault = -3 - // Other constants - val StopBroadcast = -2 - val UnusedParam = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala new file mode 100644 index 0000000000..073a0a5029 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.broadcast + +import java.io._ + +import scala.math +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.util.Utils + + +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +extends Broadcast[T](id) with Logging with Serializable { + + def value = value_ + + def broadcastId = BroadcastBlockId(id) + + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + } + + @transient var arrayOfBlocks: Array[TorrentBlock] = null + @transient var totalBlocks = -1 + @transient var totalBytes = -1 + @transient var hasBlocks = 0 + + if (!isLocal) { + sendBroadcast() + } + + def sendBroadcast() { + var tInfo = TorrentBroadcast.blockifyObject(value_) + + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + hasBlocks = tInfo.totalBlocks + + // Store meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + } + + // Store individual pieces + for (i <- 0 until totalBlocks) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + } + } + } + + // Called by JVM when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(broadcastId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + val start = System.nanoTime + logInfo("Started reading broadcast variable " + id) + + // Initialize @transient variables that will receive garbage values from the master. + resetWorkerVariables() + + if (receiveBroadcast(id)) { + value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + + // Store the merged copy in cache so that the next worker doesn't need to rebuild it. + // This creates a tradeoff between memory usage and latency. + // Storing copy doubles the memory footprint; not storing doubles deserialization cost. 
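+            // Illustrative arithmetic (hypothetical sizes): with the default 4 MB piece size
+            // set by spark.broadcast.blockSize below, a 100 MB value splits into 25 pieces.
+            // Caching the merged copy costs ~100 MB of extra BlockManager space on this node,
+            // while skipping the cache would force every later task here to re-fetch and
+            // re-merge all 25 pieces.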
+ SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + + // Remove arrayOfBlocks from memory once value_ is on local cache + resetWorkerVariables() + } else { + logError("Reading broadcast variable " + id + " failed") + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } + } + } + + private def resetWorkerVariables() { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + } + + def receiveBroadcast(variableID: Long): Boolean = { + // Receive meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + var attemptId = 10 + while (attemptId > 0 && totalBlocks == -1) { + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(metaId) match { + case Some(x) => + val tInfo = x.asInstanceOf[TorrentInfo] + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + hasBlocks = 0 + + case None => + Thread.sleep(500) + } + } + attemptId -= 1 + } + if (totalBlocks == -1) { + return false + } + + // Receive actual blocks + val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) + for (pid <- recvOrder) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + hasBlocks += 1 + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + } + + (hasBlocks == totalBlocks) + } + +} + +private object TorrentBroadcast +extends Logging { + + private var initialized = false + + def initialize(_isDriver: Boolean) { + synchronized { + if (!initialized) { + initialized = true + } + } + } + + def stop() { + initialized = false + } + + val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024 + + def blockifyObject[T](obj: T): TorrentInfo = { + val byteArray = Utils.serialize[T](obj) + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / BLOCK_SIZE) + if (byteArray.length % BLOCK_SIZE != 0) + blockNum += 1 + + var retVal = new Array[TorrentBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { + val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new TorrentBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var tInfo = TorrentInfo(retVal, blockNum, byteArray.length) + tInfo.hasBlocks = blockNum + + return tInfo + } + + def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) + } + Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) + } + +} + +private[spark] case class TorrentBlock( + blockID: Int, + byteArray: Array[Byte]) + extends Serializable + +private[spark] case class TorrentInfo( + @transient arrayOfBlocks : Array[TorrentBlock], + totalBlocks: Int, + totalBytes: Int) + extends Serializable { + + @transient var 
hasBlocks = 0 +} + +private[spark] class TorrentBroadcastFactory + extends BroadcastFactory { + + def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala deleted file mode 100644 index 80c97ca073..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala +++ /dev/null @@ -1,603 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.{Comparator, Random, UUID} - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import org.apache.spark._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils - -private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId = "broadcast_" + id - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() 
} - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - listOfSources += masterSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because Driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def initializeWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - var clientSocketToDriver: Socket = null - var oosDriver: ObjectOutputStream = null - var oisDriver: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = MultiTracker.MaxRetryCount - do { - // Connect to Driver and send this worker's Information - clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) - oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) - oosDriver.flush() - oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - - logDebug("Connected to Driver's guiding object") - - // Send local source information - oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) - oosDriver.flush() - - // Receive source information from Driver - var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = sourceInfo.totalBytes - - logDebug("Received SourceInfo from Driver:" + 
sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time = (System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Driver will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Driver - oosDriver.writeObject(sourceInfo) - - if (oisDriver != null) { - oisDriver.close() - } - if (oosDriver != null) { - oosDriver.close() - } - if (clientSocketToDriver != null) { - clientSocketToDriver.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return (hasBlocks == totalBlocks) - } - - /** - * Tries to receive broadcast from the source and returns Boolean status. - * This might be called multiple times to retry a defined number of times. - */ - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) - - logDebug("Inside receiveSingleTransmission") - logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } - } - } catch { - case e: Exception => logError("receiveSingleTransmission had a " + e) - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. 
Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close() the socket here; else, the thread will close() it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - var listIter = listOfSources.iterator - while (listIter.hasNext) { - var sourceInfo = listIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal - gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid (SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new (if it can finish) source to the list of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes) - logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) - listOfSources += thisWorkerInfo - } - - // Wait till the whole transfer is done. 
Then receive and update source - // statistics in listOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(listOfSources.contains(selectedSourceInfo)) - - // Remove first - // (Currently removing a source based on just one failure notification!) - listOfSources = listOfSources - selectedSourceInfo - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - } - } catch { - case e: Exception => { - // Remove failed worker from listOfSources and update leecherCount of - // corresponding source worker - listOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - listOfSources = listOfSources - selectedSourceInfo - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - - // Remove thisWorkerInfo - if (listOfSources != null) { - listOfSources = listOfSources - thisWorkerInfo - } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Assuming the caller to have a synchronized block on listOfSources - // Select one with the most leechers. This will level-wise fill the tree - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - var maxLeechers = -1 - var selectedSource: SourceInfo = null - - listOfSources.foreach { source => - if ((source.hostAddress != skipSourceInfo.hostAddress || - source.listenPort != skipSourceInfo.listenPort) && - source.currentLeechers < MultiTracker.MaxDegree && - source.currentLeechers > maxLeechers) { - selectedSource = source - maxLeechers = source.currentLeechers - } - } - - // Update leecher count - selectedSource.currentLeechers += 1 - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - - var threadPool = Utils.newDaemonCachedThreadPool() - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { } - } - - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - 
- override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - // If not a valid range, stop broadcast - if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - sendObject - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Driver - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { hasBlocksLock.wait() } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => logError("sendObject had a " + e) - } - logDebug("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -private[spark] class TreeBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TreeBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 1cfff5e565..275331724a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -21,12 +21,14 @@ import scala.collection.immutable.List import org.apache.spark.deploy.ExecutorState.ExecutorState import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo} +import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.ExecutorRunner import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable +/** Contains messages sent between Scheduler actor nodes. 
*/ private[deploy] object DeployMessages { // Worker to Master @@ -52,17 +54,20 @@ private[deploy] object DeployMessages { exitStatus: Option[Int]) extends DeployMessage + case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription]) + case class Heartbeat(workerId: String) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage - case class KillExecutor(appId: String, execId: Int) extends DeployMessage + case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage case class LaunchExecutor( + masterUrl: String, appId: String, execId: Int, appDesc: ApplicationDescription, @@ -76,9 +81,11 @@ private[deploy] object DeployMessages { case class RegisterApplication(appDescription: ApplicationDescription) extends DeployMessage + case class MasterChangeAcknowledged(appId: String) + // Master to Client - case class RegisteredApplication(appId: String) extends DeployMessage + case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -94,6 +101,10 @@ private[deploy] object DeployMessages { case object StopClient + // Master to Worker & Client + + case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + // MasterWebUI To Master case object RequestMasterState @@ -101,7 +112,8 @@ private[deploy] object DeployMessages { // Master to MasterWebUI case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) { + activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], + status: MasterState) { Utils.checkHost(host, "Required hostname") assert (port > 0) @@ -123,12 +135,7 @@ private[deploy] object DeployMessages { assert (port > 0) } - // Actor System to Master - - case object CheckForWorkerTimeOut - - case object RequestWebUIPort - - case class WebUIPortResponse(webUIBoundPort: Int) + // Actor System to Worker + case object SendHeartbeat } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala new file mode 100644 index 0000000000..2abf0b69dd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorDescription.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.deploy
+
+/**
+ * Used to send state on-the-wire about Executors from Worker to Master.
+ * This state is sufficient for the Master to reconstruct its internal data structures during
+ * failover.
+ */
+private[spark] class ExecutorDescription(
+    val appId: String,
+    val execId: Int,
+    val cores: Int,
+    val state: ExecutorState.Value)
+  extends Serializable {
+
+  override def toString: String =
+    "ExecutorState(appId=%s, execId=%d, cores=%d, state=%s)".format(appId, execId, cores, state)
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
new file mode 100644
index 0000000000..668032a3a2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -0,0 +1,420 @@
+/*
+ *
+ * * Licensed to the Apache Software Foundation (ASF) under one or more
+ * * contributor license agreements. See the NOTICE file distributed with
+ * * this work for additional information regarding copyright ownership.
+ * * The ASF licenses this file to You under the Apache License, Version 2.0
+ * * (the "License"); you may not use this file except in compliance with
+ * * the License. You may obtain a copy of the License at
+ * *
+ * * http://www.apache.org/licenses/LICENSE-2.0
+ * *
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS,
+ * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * * See the License for the specific language governing permissions and
+ * * limitations under the License.
+ *
+ */
+
+package org.apache.spark.deploy
+
+import java.io._
+import java.net.URL
+import java.util.concurrent.TimeoutException
+
+import scala.concurrent.{Await, future, promise}
+import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.collection.mutable.ListBuffer
+import scala.sys.process._
+
+import net.liftweb.json.JsonParser
+
+import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.deploy.master.RecoveryState
+
+/**
+ * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
+ * In order to mimic a real distributed cluster more closely, Docker is used.
+ * Execute using
+ *   ./spark-class org.apache.spark.deploy.FaultToleranceTest
+ *
+ * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
+ *   - spark.deploy.recoveryMode=ZOOKEEPER
+ *   - spark.deploy.zookeeper.url=172.17.42.1:2181
+ * Note that 172.17.42.1 is the default Docker IP for the host and 2181 is the default ZK port.
+ *
+ * Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
+ * working installation of Docker. In addition to having Docker, the following are assumed:
+ *   - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
+ *   - The docker images tagged spark-test-master and spark-test-worker are built from the
+ *     docker/ directory. Run 'docker/spark-test/build' to generate these.
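+ *
+ * A typical invocation under these assumptions (illustrative; adjust paths to your checkout):
+ *   SPARK_DAEMON_JAVA_OPTS="-Dspark.deploy.recoveryMode=ZOOKEEPER \
+ *     -Dspark.deploy.zookeeper.url=172.17.42.1:2181" \
+ *     ./spark-class org.apache.spark.deploy.FaultToleranceTest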
+ */ +private[spark] object FaultToleranceTest extends App with Logging { + val masters = ListBuffer[TestMasterInfo]() + val workers = ListBuffer[TestWorkerInfo]() + var sc: SparkContext = _ + + var numPassed = 0 + var numFailed = 0 + + val sparkHome = System.getenv("SPARK_HOME") + assertTrue(sparkHome != null, "Run with a valid SPARK_HOME") + + val containerSparkHome = "/opt/spark" + val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome) + + System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip + + def afterEach() { + if (sc != null) { + sc.stop() + sc = null + } + terminateCluster() + } + + test("sanity-basic") { + addMasters(1) + addWorkers(1) + createClient() + assertValidClusterState() + } + + test("sanity-many-masters") { + addMasters(3) + addWorkers(3) + createClient() + assertValidClusterState() + } + + test("single-master-halt") { + addMasters(3) + addWorkers(2) + createClient() + assertValidClusterState() + + killLeader() + delay(30 seconds) + assertValidClusterState() + createClient() + assertValidClusterState() + } + + test("single-master-restart") { + addMasters(1) + addWorkers(2) + createClient() + assertValidClusterState() + + killLeader() + addMasters(1) + delay(30 seconds) + assertValidClusterState() + + killLeader() + addMasters(1) + delay(30 seconds) + assertValidClusterState() + } + + test("cluster-failure") { + addMasters(2) + addWorkers(2) + createClient() + assertValidClusterState() + + terminateCluster() + addMasters(2) + addWorkers(2) + assertValidClusterState() + } + + test("all-but-standby-failure") { + addMasters(2) + addWorkers(2) + createClient() + assertValidClusterState() + + killLeader() + workers.foreach(_.kill()) + workers.clear() + delay(30 seconds) + addWorkers(2) + assertValidClusterState() + } + + test("rolling-outage") { + addMasters(1) + delay() + addMasters(1) + delay() + addMasters(1) + addWorkers(2) + createClient() + assertValidClusterState() + assertTrue(getLeader == masters.head) + + (1 to 3).foreach { _ => + killLeader() + delay(30 seconds) + assertValidClusterState() + assertTrue(getLeader == masters.head) + addMasters(1) + } + } + + def test(name: String)(fn: => Unit) { + try { + fn + numPassed += 1 + logInfo("Passed: " + name) + } catch { + case e: Exception => + numFailed += 1 + logError("FAILED: " + name, e) + } + afterEach() + } + + def addMasters(num: Int) { + (1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) } + } + + def addWorkers(num: Int) { + val masterUrls = getMasterUrls(masters) + (1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) } + } + + /** Creates a SparkContext, which constructs a Client to interact with our cluster. */ + def createClient() = { + if (sc != null) { sc.stop() } + // Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this + // property, we need to reset it. 
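+    // Port 0 lets the driver bind to an arbitrary free port, so repeated createClient()
+    // calls cannot collide on the port a previous SparkContext was using.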
+ System.setProperty("spark.driver.port", "0") + sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome) + } + + def getMasterUrls(masters: Seq[TestMasterInfo]): String = { + "spark://" + masters.map(master => master.ip + ":7077").mkString(",") + } + + def getLeader: TestMasterInfo = { + val leaders = masters.filter(_.state == RecoveryState.ALIVE) + assertTrue(leaders.size == 1) + leaders(0) + } + + def killLeader(): Unit = { + masters.foreach(_.readState()) + val leader = getLeader + masters -= leader + leader.kill() + } + + def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis) + + def terminateCluster() { + masters.foreach(_.kill()) + workers.foreach(_.kill()) + masters.clear() + workers.clear() + } + + /** This includes Client retry logic, so it may take a while if the cluster is recovering. */ + def assertUsable() = { + val f = future { + try { + val res = sc.parallelize(0 until 10).collect() + assertTrue(res.toList == (0 until 10)) + true + } catch { + case e: Exception => + logError("assertUsable() had exception", e) + e.printStackTrace() + false + } + } + + // Avoid waiting indefinitely (e.g., we could register but get no executors). + assertTrue(Await.result(f, 120 seconds)) + } + + /** + * Asserts that the cluster is usable and that the expected masters and workers + * are all alive in a proper configuration (e.g., only one leader). + */ + def assertValidClusterState() = { + assertUsable() + var numAlive = 0 + var numStandby = 0 + var numLiveApps = 0 + var liveWorkerIPs: Seq[String] = List() + + def stateValid(): Boolean = { + (workers.map(_.ip) -- liveWorkerIPs).isEmpty && + numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1 + } + + val f = future { + try { + while (!stateValid()) { + Thread.sleep(1000) + + numAlive = 0 + numStandby = 0 + numLiveApps = 0 + + masters.foreach(_.readState()) + + for (master <- masters) { + master.state match { + case RecoveryState.ALIVE => + numAlive += 1 + liveWorkerIPs = master.liveWorkerIPs + case RecoveryState.STANDBY => + numStandby += 1 + case _ => // ignore + } + + numLiveApps += master.numLiveApps + } + } + true + } catch { + case e: Exception => + logError("assertValidClusterState() had exception", e) + false + } + } + + try { + assertTrue(Await.result(f, 120 seconds)) + } catch { + case e: TimeoutException => + logError("Master states: " + masters.map(_.state)) + logError("Num apps: " + numLiveApps) + logError("IPs expected: " + workers.map(_.ip) + " / found: " + liveWorkerIPs) + throw new RuntimeException("Failed to get into acceptable cluster state after 2 min.", e) + } + } + + def assertTrue(bool: Boolean, message: String = "") { + if (!bool) { + throw new IllegalStateException("Assertion failed: " + message) + } + } + + logInfo("Ran %s tests, %s passed and %s failed".format(numPassed+numFailed, numPassed, numFailed)) +} + +private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File) + extends Logging { + + implicit val formats = net.liftweb.json.DefaultFormats + var state: RecoveryState.Value = _ + var liveWorkerIPs: List[String] = _ + var numLiveApps = 0 + + logDebug("Created master: " + this) + + def readState() { + try { + val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream) + val json = JsonParser.parse(masterStream, closeAutomatically = true) + + val workers = json \ "workers" + val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE") + liveWorkerIPs = 
liveWorkers.map(w => (w \ "host").extract[String]) + + numLiveApps = (json \ "activeapps").children.size + + val status = json \\ "status" + val stateString = status.extract[String] + state = RecoveryState.values.filter(state => state.toString == stateString).head + } catch { + case e: Exception => + // ignore, no state update + logWarning("Exception", e) + } + } + + def kill() { Docker.kill(dockerId) } + + override def toString: String = + "[ip=%s, id=%s, logFile=%s, state=%s]". + format(ip, dockerId.id, logFile.getAbsolutePath, state) +} + +private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File) + extends Logging { + + implicit val formats = net.liftweb.json.DefaultFormats + + logDebug("Created worker: " + this) + + def kill() { Docker.kill(dockerId) } + + override def toString: String = + "[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath) +} + +private[spark] object SparkDocker { + def startMaster(mountDir: String): TestMasterInfo = { + val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir) + val (ip, id, outFile) = startNode(cmd) + new TestMasterInfo(ip, id, outFile) + } + + def startWorker(mountDir: String, masters: String): TestWorkerInfo = { + val cmd = Docker.makeRunCmd("spark-test-worker", args = masters, mountDir = mountDir) + val (ip, id, outFile) = startNode(cmd) + new TestWorkerInfo(ip, id, outFile) + } + + private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = { + val ipPromise = promise[String]() + val outFile = File.createTempFile("fault-tolerance-test", "") + outFile.deleteOnExit() + val outStream: FileWriter = new FileWriter(outFile) + def findIpAndLog(line: String): Unit = { + if (line.startsWith("CONTAINER_IP=")) { + val ip = line.split("=")(1) + ipPromise.success(ip) + } + + outStream.write(line + "\n") + outStream.flush() + } + + dockerCmd.run(ProcessLogger(findIpAndLog _)) + val ip = Await.result(ipPromise.future, 30 seconds) + val dockerId = Docker.getLastProcessId + (ip, dockerId, outFile) + } +} + +private[spark] class DockerId(val id: String) { + override def toString = id +} + +private[spark] object Docker extends Logging { + def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = { + val mountCmd = if (mountDir != "") { " -v " + mountDir } else "" + + val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args) + logDebug("Run command: " + cmd) + cmd + } + + def kill(dockerId: DockerId) : Unit = { + "docker kill %s".format(dockerId.id).! + } + + def getLastProcessId: DockerId = { + var id: String = null + "docker ps -l -q".!(ProcessLogger(line => id = line)) + new DockerId(id) + } +}
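+// Usage sketch for the helpers above (hypothetical mount path; mirrors how the test
+// harness composes them):
+//   val master = SparkDocker.startMaster("/path/to/spark:/opt/spark")
+//   val worker = SparkDocker.startWorker("/path/to/spark:/opt/spark",
+//     "spark://" + master.ip + ":7077")
+//   master.readState(); Docker.kill(worker.dockerId); Docker.kill(master.dockerId)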
\ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala index 04d01c169d..e607b8c6f4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/deploy/JsonProtocol.scala @@ -72,7 +72,8 @@ private[spark] object JsonProtocol { ("memory" -> obj.workers.map(_.memory).sum) ~ ("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~ ("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~ - ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) + ("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~ + ("status" -> obj.status.toString) } def writeWorkerState(obj: WorkerStateResponse) = { diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 6a7d5a85ba..94cf4ff88b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -39,22 +39,23 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() - def start(): String = { + def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0) masterActorSystems += masterSystem val masterUrl = "spark://" + localHostname + ":" + masterPort + val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, - memoryPerWorker, masterUrl, null, Some(workerNum)) + memoryPerWorker, masters, null, Some(workerNum)) workerActorSystems += workerSystem } - return masterUrl + return masters } def stop() { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 993ba6bd3d..c29a30184a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,28 +17,59 @@ package org.apache.spark.deploy -import com.google.common.collect.MapMaker +import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.UserGroupInformation +import org.apache.spark.SparkException /** - * Contains util methods to interact with Hadoop from spark. + * Contains util methods to interact with Hadoop from Spark. */ +private[spark] class SparkHadoopUtil { - // A general, soft-reference map for metadata needed during HadoopRDD split computation - // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). - private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() + val conf = newConfiguration() + UserGroupInformation.setConfiguration(conf) - // Return an appropriate (subclass) of Configuration. 
Creating config can initializes some hadoop - // subsystems + def runAsUser(user: String)(func: () => Unit) { + val ugi = UserGroupInformation.createRemoteUser(user) + ugi.doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) + } + + /** + * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * subsystems. + */ def newConfiguration(): Configuration = new Configuration() - // Add any user credentials to the job conf which are necessary for running on a secure Hadoop - // cluster + /** + * Add any user credentials to the job conf which are necessary for running on a secure Hadoop + * cluster. + */ def addCredentials(conf: JobConf) {} def isYarnMode(): Boolean = { false } +} + +object SparkHadoopUtil { + private val hadoop = { + val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + if (yarnMode) { + try { + Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] + } catch { + case th: Throwable => throw new SparkException("Unable to load YARN support", th) + } + } else { + new SparkHadoopUtil + } + } + def get: SparkHadoopUtil = { + hadoop + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala index 164386782c..be8693ec54 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeoutException import scala.concurrent.duration._ import scala.concurrent.Await +import scala.concurrent.ExecutionContext.Implicits.global import akka.actor._ import akka.actor.Terminated @@ -37,41 +38,81 @@ import org.apache.spark.deploy.master.Master /** * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description, * and a listener for cluster events, and calls back the listener when various events occur. + * + * @param masterUrls Each url should look like spark://host:port. */ private[spark] class Client( actorSystem: ActorSystem, - masterUrl: String, + masterUrls: Array[String], appDescription: ApplicationDescription, listener: ClientListener) extends Logging { + val REGISTRATION_TIMEOUT = 20.seconds + val REGISTRATION_RETRIES = 3 + var actor: ActorRef = null var appId: String = null + var registered = false + var activeMasterUrl: String = null class ClientActor extends Actor with Logging { var master: ActorRef = null var masterAddress: Address = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times + var alreadyDead = false // To avoid calling listener.dead() multiple times override def preStart() { - logInfo("Connecting to master " + masterUrl) try { - master = context.actorFor(Master.toAkkaUrl(masterUrl)) - masterAddress = master.path.address - master ! RegisterApplication(appDescription) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + registerWithMaster() } catch { case e: Exception => - logError("Failed to connect to master", e) + logWarning("Failed to connect to master", e) markDisconnected() context.stop(self) } } + def tryRegisterAllMasters() { + for (masterUrl <- masterUrls) { + logInfo("Connecting to master " + masterUrl + "...") + val actor = context.actorFor(Master.toAkkaUrl(masterUrl)) + actor ! 
RegisterApplication(appDescription) + } + } + + def registerWithMaster() { + tryRegisterAllMasters() + + var retries = 0 + lazy val retryTimer: Cancellable = + context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + retries += 1 + if (registered) { + retryTimer.cancel() + } else if (retries >= REGISTRATION_RETRIES) { + logError("All masters are unresponsive! Giving up.") + markDead() + } else { + tryRegisterAllMasters() + } + } + retryTimer // start timer + } + + def changeMaster(url: String) { + activeMasterUrl = url + master = context.actorFor(Master.toAkkaUrl(url)) + masterAddress = master.path.address + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + context.watch(master) // Doesn't work with remote actors, but useful for testing + } + override def receive = { - case RegisteredApplication(appId_) => + case RegisteredApplication(appId_, masterUrl) => appId = appId_ + registered = true + changeMaster(masterUrl) listener.connected(appId) case ApplicationRemoved(message) => @@ -92,23 +133,27 @@ private[spark] class Client( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } + case MasterChanged(masterUrl, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterUrl) + context.unwatch(master) + changeMaster(masterUrl) + alreadyDisconnected = false + sender ! MasterChangeAcknowledged(appId) + case Terminated(actor_) if actor_ == master => - logError("Connection to master failed; stopping client") + logWarning("Connection to master failed; waiting for master to reconnect...") markDisconnected() - context.stop(self) case DisassociatedEvent(_, address, _) if address == masterAddress => logError("Connection to master failed; stopping client") markDisconnected() - context.stop(self) case AssociationErrorEvent(_, _, address, _) if address == masterAddress => logError("Connection to master failed; stopping client") markDisconnected() - context.stop(self) case StopClient => - markDisconnected() + markDead() sender ! true context.stop(self) } @@ -122,6 +167,13 @@ private[spark] class Client( alreadyDisconnected = true } } + + def markDead() { + if (!alreadyDead) { + listener.dead() + alreadyDead = true + } + } } def start() { diff --git a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala index 4605368c11..be7a11bd15 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/ClientListener.scala @@ -27,8 +27,12 @@ package org.apache.spark.deploy.client private[spark] trait ClientListener { def connected(appId: String): Unit + /** Disconnection may be a temporary state, as we fail over to a new Master. */ def disconnected(): Unit + /** Dead means that we couldn't find any Masters to connect to, and have given up. 
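+   * For example, TestClient reacts to dead() by logging the failure and exiting the JVM.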
*/ + def dead(): Unit + def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index d5e9a0e095..5b62d3ba6c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -33,6 +33,11 @@ private[spark] object TestClient { System.exit(0) } + def dead() { + logInfo("Could not connect to master") + System.exit(0) + } + def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {} def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {} @@ -44,7 +49,7 @@ private[spark] object TestClient { val desc = new ApplicationDescription( "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") val listener = new TestListener - val client = new Client(actorSystem, url, desc, listener) + val client = new Client(actorSystem, Array(url), desc, listener) client.start() actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index bd5327627a..5150b7c7de 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -29,23 +29,46 @@ private[spark] class ApplicationInfo( val submitDate: Date, val driver: ActorRef, val appUiUrl: String) -{ - var state = ApplicationState.WAITING - var executors = new mutable.HashMap[Int, ExecutorInfo] - var coresGranted = 0 - var endTime = -1L - val appSource = new ApplicationSource(this) - - private var nextExecutorId = 0 - - def newExecutorId(): Int = { - val id = nextExecutorId - nextExecutorId += 1 - id + extends Serializable { + + @transient var state: ApplicationState.Value = _ + @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _ + @transient var coresGranted: Int = _ + @transient var endTime: Long = _ + @transient var appSource: ApplicationSource = _ + + @transient private var nextExecutorId: Int = _ + + init() + + private def readObject(in: java.io.ObjectInputStream) : Unit = { + in.defaultReadObject() + init() + } + + private def init() { + state = ApplicationState.WAITING + executors = new mutable.HashMap[Int, ExecutorInfo] + coresGranted = 0 + endTime = -1L + appSource = new ApplicationSource(this) + nextExecutorId = 0 + } + + private def newExecutorId(useID: Option[Int] = None): Int = { + useID match { + case Some(id) => + nextExecutorId = math.max(nextExecutorId, id + 1) + id + case None => + val id = nextExecutorId + nextExecutorId += 1 + id + } } - def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = { - val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave) + def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = { + val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave) executors(exec.id) = exec coresGranted += cores exec diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala index 39ef090ddf..a74d7be4c9 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -22,7 +22,7 @@ private[spark] object ApplicationState type ApplicationState = Value - val WAITING, RUNNING, FINISHED, FAILED = Value + val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value val MAX_NUM_RETRY = 10 } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala index cf384a985e..76db61dd61 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master -import org.apache.spark.deploy.ExecutorState +import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} private[spark] class ExecutorInfo( val id: Int, @@ -28,5 +28,10 @@ private[spark] class ExecutorInfo( var state = ExecutorState.LAUNCHING + /** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */ + def copyState(execDesc: ExecutorDescription) { + state = execDesc.state + } + def fullId: String = application.id + "/" + id } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala new file mode 100644 index 0000000000..043945a211 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import java.io._ + +import scala.Serializable + +import akka.serialization.Serialization +import org.apache.spark.Logging + +/** + * Stores data in a single on-disk directory with one file per application and worker. + * Files are deleted when applications and workers are removed. + * + * @param dir Directory to store files. Created if non-existent (but not recursively). + * @param serialization Used to serialize our objects. 
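+ *
+ * The resulting layout is one flat file per entity (names shown schematically):
+ *   <dir>/app_<appId>       -- serialized ApplicationInfo
+ *   <dir>/worker_<workerId> -- serialized WorkerInfo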
+ */ +private[spark] class FileSystemPersistenceEngine( + val dir: String, + val serialization: Serialization) + extends PersistenceEngine with Logging { + + new File(dir).mkdir() + + override def addApplication(app: ApplicationInfo) { + val appFile = new File(dir + File.separator + "app_" + app.id) + serializeIntoFile(appFile, app) + } + + override def removeApplication(app: ApplicationInfo) { + new File(dir + File.separator + "app_" + app.id).delete() + } + + override def addWorker(worker: WorkerInfo) { + val workerFile = new File(dir + File.separator + "worker_" + worker.id) + serializeIntoFile(workerFile, worker) + } + + override def removeWorker(worker: WorkerInfo) { + new File(dir + File.separator + "worker_" + worker.id).delete() + } + + override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = { + val sortedFiles = new File(dir).listFiles().sortBy(_.getName) + val appFiles = sortedFiles.filter(_.getName.startsWith("app_")) + val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) + val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_")) + val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) + (apps, workers) + } + + private def serializeIntoFile(file: File, value: AnyRef) { + val created = file.createNewFile() + if (!created) { throw new IllegalStateException("Could not create file: " + file) } + + val serializer = serialization.findSerializerFor(value) + val serialized = serializer.toBinary(value) + + val out = new FileOutputStream(file) + out.write(serialized) + out.close() + } + + def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { + val fileData = new Array[Byte](file.length().asInstanceOf[Int]) + val dis = new DataInputStream(new FileInputStream(file)) + dis.readFully(fileData) + dis.close() + + val clazz = m.runtimeClass.asInstanceOf[Class[T]] + val serializer = serialization.serializerFor(clazz) + serializer.fromBinary(fileData).asInstanceOf[T] + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala new file mode 100644 index 0000000000..f25a1ad3bf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import akka.actor.{Actor, ActorRef} + +import org.apache.spark.deploy.master.MasterMessages.ElectedLeader + +/** + * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it + * is the only Master serving requests. 
+ * In addition to the API provided, the LeaderElectionAgent will use the following messages + * to inform the Master of leader changes: + * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]] + * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]] + */ +private[spark] trait LeaderElectionAgent extends Actor { + val masterActor: ActorRef +} + +/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ +private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent { + override def preStart() { + masterActor ! ElectedLeader + } + + override def receive = { + case _ => + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index cb0fe6a850..26f980760d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -23,23 +23,25 @@ import java.text.SimpleDateFormat import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Await import scala.concurrent.duration._ +import scala.concurrent.duration.{ Duration, FiniteDuration } +import scala.concurrent.ExecutionContext.Implicits.global import akka.actor._ import akka.pattern.ask import akka.remote._ +import akka.util.Timeout import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{Utils, AkkaUtils} -import akka.util.Timeout import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed import org.apache.spark.deploy.DeployMessages.KillExecutor import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged import scala.Some -import org.apache.spark.deploy.DeployMessages.WebUIPortResponse import org.apache.spark.deploy.DeployMessages.LaunchExecutor import org.apache.spark.deploy.DeployMessages.RegisteredApplication import org.apache.spark.deploy.DeployMessages.RegisterWorker @@ -51,6 +53,8 @@ import org.apache.spark.deploy.DeployMessages.ApplicationRemoved import org.apache.spark.deploy.DeployMessages.Heartbeat import org.apache.spark.deploy.DeployMessages.RegisteredWorker import akka.actor.Terminated +import akka.serialization.SerializationExtension +import java.util.concurrent.TimeUnit private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { @@ -58,7 +62,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt - + val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "") + val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE") + var nextAppNumber = 0 val workers = new HashSet[WorkerInfo] val idToWorker = new HashMap[String, WorkerInfo] @@ -88,52 +94,115 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act if (envVar != null) envVar else host } + val masterUrl = "spark://" + host + ":" + port + var 
masterWebUiUrl: String = _ + + var state = RecoveryState.STANDBY + + var persistenceEngine: PersistenceEngine = _ + + var leaderElectionAgent: ActorRef = _ + // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean override def preStart() { - logInfo("Starting Spark master at spark://" + host + ":" + port) + logInfo("Starting Spark master at " + masterUrl) // Listen for remote client disconnection events, since they don't go through Akka's watch() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.start() - import context.dispatcher + masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() applicationMetricsSystem.start() + + persistenceEngine = RECOVERY_MODE match { + case "ZOOKEEPER" => + logInfo("Persisting recovery state to ZooKeeper") + new ZooKeeperPersistenceEngine(SerializationExtension(context.system)) + case "FILESYSTEM" => + logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) + new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system)) + case _ => + new BlackHolePersistenceEngine() + } + + leaderElectionAgent = RECOVERY_MODE match { + case "ZOOKEEPER" => + context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl)) + case _ => + context.actorOf(Props(classOf[MonarchyLeaderAgent], self)) + } + } + + override def preRestart(reason: Throwable, message: Option[Any]) { + super.preRestart(reason, message) // calls postStop()! + logError("Master actor restarted due to exception", reason) } override def postStop() { webUi.stop() masterMetricsSystem.stop() applicationMetricsSystem.stop() + persistenceEngine.close() + context.stop(leaderElectionAgent) } override def receive = { - case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => { + case ElectedLeader => { + val (storedApps, storedWorkers) = persistenceEngine.readPersistedData() + state = if (storedApps.isEmpty && storedWorkers.isEmpty) + RecoveryState.ALIVE + else + RecoveryState.RECOVERING + + logInfo("I have been elected leader! New state: " + state) + + if (state == RecoveryState.RECOVERING) { + beginRecovery(storedApps, storedWorkers) + context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() } + } + } + + case RevokedLeadership => { + logError("Leadership has been revoked -- master shutting down.") + System.exit(0) + } + + case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( host, workerPort, cores, Utils.megabytesToString(memory))) - if (idToWorker.contains(id)) { + if (state == RecoveryState.STANDBY) { + // ignore, don't send response + } else if (idToWorker.contains(id)) { sender ! 
RegisterWorkerFailed("Duplicate worker ID") } else { - addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) + val worker = new WorkerInfo(id, host, workerPort, cores, memory, sender, webUiPort, publicAddress) + registerWorker(worker) context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get) + persistenceEngine.addWorker(worker) + sender ! RegisteredWorker(masterUrl, masterWebUiUrl) schedule() } } case RegisterApplication(description) => { - logInfo("Registering app " + description.name) - val app = addApplication(description, sender) - logInfo("Registered app " + description.name + " with ID " + app.id) - waitingApps += app - context.watch(sender) // This doesn't work with remote actors but helps for testing - sender ! RegisteredApplication(app.id) - schedule() + if (state == RecoveryState.STANDBY) { + // ignore, don't send response + } else { + logInfo("Registering app " + description.name) + val app = createApplication(description, sender) + registerApplication(app) + logInfo("Registered app " + description.name + " with ID " + app.id) + context.watch(sender) // This doesn't work with remote actors but helps for testing + persistenceEngine.addApplication(app) + sender ! RegisteredApplication(app.id, masterUrl) + schedule() + } } case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { @@ -173,27 +242,63 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + case MasterChangeAcknowledged(appId) => { + idToApp.get(appId) match { + case Some(app) => + logInfo("Application has been re-registered: " + appId) + app.state = ApplicationState.WAITING + case None => + logWarning("Master change ack from unknown app: " + appId) + } + + if (canCompleteRecovery) { completeRecovery() } + } + + case WorkerSchedulerStateResponse(workerId, executors) => { + idToWorker.get(workerId) match { + case Some(worker) => + logInfo("Worker has been re-registered: " + workerId) + worker.state = WorkerState.ALIVE + + val validExecutors = executors.filter(exec => idToApp.get(exec.appId).isDefined) + for (exec <- validExecutors) { + val app = idToApp.get(exec.appId).get + val execInfo = app.addExecutor(worker, exec.cores, Some(exec.execId)) + worker.addExecutor(execInfo) + execInfo.copyState(exec) + } + case None => + logWarning("Scheduler state from unknown worker: " + workerId) + } + + if (canCompleteRecovery) { completeRecovery() } + } + case Terminated(actor) => { // The disconnected actor could've been either a worker or an app; remove whichever of // those we have an entry for in the corresponding actor hashmap actorToWorker.get(actor).foreach(removeWorker) actorToApp.get(actor).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } } case DisassociatedEvent(_, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was addressToWorker.get(address).foreach(removeWorker) addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } } case AssociationErrorEvent(_, _, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was addressToWorker.get(address).foreach(removeWorker) addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { 
completeRecovery() } } case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray) + sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, + state) } case CheckForWorkerTimeOut => { @@ -205,6 +310,50 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } } + def canCompleteRecovery = + workers.count(_.state == WorkerState.UNKNOWN) == 0 && + apps.count(_.state == ApplicationState.UNKNOWN) == 0 + + def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) { + for (app <- storedApps) { + logInfo("Trying to recover app: " + app.id) + try { + registerApplication(app) + app.state = ApplicationState.UNKNOWN + app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + } catch { + case e: Exception => logInfo("App " + app.id + " had exception on reconnect") + } + } + + for (worker <- storedWorkers) { + logInfo("Trying to recover worker: " + worker.id) + try { + registerWorker(worker) + worker.state = WorkerState.UNKNOWN + worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + } catch { + case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") + } + } + } + + def completeRecovery() { + // Ensure "only-once" recovery semantics using a short synchronization period. + synchronized { + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY + } + + // Kill off any workers and apps that didn't respond to us. + workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) + apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication) + + state = RecoveryState.ALIVE + schedule() + logInfo("Recovery complete - resuming operations!") + } + /** * Can an app use the given worker? True if the worker has enough memory and we haven't already * launched an executor for the app on it (right now the standalone backend doesn't like having @@ -219,6 +368,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act * every time a new app joins or resource availability changes. */ def schedule() { + if (state != RecoveryState.ALIVE) { return } // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app // in the queue, then the second app, etc. if (spreadOutApps) { @@ -266,14 +416,13 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor( + worker.actor ! LaunchExecutor(masterUrl, exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome) exec.application.driver ! ExecutorAdded( exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) } - def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int, - publicAddress: String): WorkerInfo = { + def registerWorker(worker: WorkerInfo): Unit = { // There may be one or more refs to dead workers on this same node (w/ different ID's), // remove them. 
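completeRecovery() above can be triggered from several places (the worker-timeout timer, re-registration acks, disconnect events), so it guards its state transition with a short synchronized block. Below is a stripped-down sketch of that at-most-once pattern; the state names mirror RecoveryState, everything else is illustrative.

object OnceOnlySketch {
  object State extends Enumeration { val RECOVERING, COMPLETING_RECOVERY, ALIVE = Value }
  @volatile private var state = State.RECOVERING

  def completeRecovery() {
    // Several events may race to this point; only the first caller proceeds.
    synchronized {
      if (state != State.RECOVERING) { return }
      state = State.COMPLETING_RECOVERY
    }
    // ... drop unresponsive workers/apps here, then resume scheduling ...
    state = State.ALIVE
  }

  def main(args: Array[String]) {
    val threads = (1 to 4).map(_ => new Thread() { override def run() { completeRecovery() } })
    threads.foreach(_.start())
    threads.foreach(_.join())
    println(state) // ALIVE, with the recovery body having run exactly once
  }
}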
workers.filter { w => @@ -281,12 +430,17 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act }.foreach { w => workers -= w } - val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) + + val workerAddress = worker.actor.path.address + if (addressToWorker.contains(workerAddress)) { + logInfo("Attempted to re-register worker at same address: " + workerAddress) + return + } + workers += worker idToWorker(worker.id) = worker - actorToWorker(sender) = worker - addressToWorker(sender.path.address) = worker - worker + actorToWorker(worker.actor) = worker + addressToWorker(workerAddress) = worker } def removeWorker(worker: WorkerInfo) { @@ -301,25 +455,36 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act exec.id, ExecutorState.LOST, Some("worker lost"), None) exec.application.removeExecutor(exec) } + persistenceEngine.removeWorker(worker) } - def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) - val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl) + new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl) + } + + def registerApplication(app: ApplicationInfo): Unit = { + val appAddress = app.driver.path.address + if (addressToApp.contains(appAddress)) { + logInfo("Attempted to re-register application at same address: " + appAddress) + return + } + + applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(driver) = app - addressToApp(driver.path.address) = app + actorToApp(app.driver) = app + addressToApp(appAddress) = app if (firstApp == None) { firstApp = Some(app) } + // TODO: What is firstApp?? Can we remove it? val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray - if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) { + if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) { logWarning("Could not find any workers with enough memory for " + firstApp.get.id) } - app + waitingApps += app } def finishApplication(app: ApplicationInfo) { @@ -344,13 +509,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act waitingApps -= app for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(exec.application.id, exec.id) + exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { app.driver ! ApplicationRemoved(state.toString) } + persistenceEngine.removeApplication(app) schedule() } } @@ -404,8 +570,8 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName) - val timeoutDuration = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + val timeoutDuration : FiniteDuration = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS) implicit val timeout = Timeout(timeoutDuration) val respFuture = actor ? 
RequestWebUIPort // ask pattern val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala new file mode 100644 index 0000000000..74a9f8cd82 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +sealed trait MasterMessages extends Serializable + +/** Contains messages seen only by the Master and its associated entities. */ +private[master] object MasterMessages { + + // LeaderElectionAgent to Master + + case object ElectedLeader + + case object RevokedLeadership + + // Actor System to LeaderElectionAgent + + case object CheckLeader + + // Actor System to Master + + case object CheckForWorkerTimeOut + + case class BeginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) + + case object CompleteRecovery + + case object RequestWebUIPort + + case class WebUIPortResponse(webUIBoundPort: Int) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala new file mode 100644 index 0000000000..94b986caf2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +/** + * Allows Master to persist any state that is necessary in order to recover from a failure. + * The following semantics are required: + * - addApplication and addWorker are called before completing registration of a new app/worker. + * - removeApplication and removeWorker are called at any time. 
+ * Given these two requirements, we will have all apps and workers persisted, but + * we might not have yet deleted apps or workers that finished (so their liveness must be verified + * during recovery). + */ +private[spark] trait PersistenceEngine { + def addApplication(app: ApplicationInfo) + + def removeApplication(app: ApplicationInfo) + + def addWorker(worker: WorkerInfo) + + def removeWorker(worker: WorkerInfo) + + /** + * Returns the persisted data sorted by their respective ids (which implies that they're + * sorted by time of creation). + */ + def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) + + def close() {} +} + +private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { + override def addApplication(app: ApplicationInfo) {} + override def removeApplication(app: ApplicationInfo) {} + override def addWorker(worker: WorkerInfo) {} + override def removeWorker(worker: WorkerInfo) {} + override def readPersistedData() = (Nil, Nil) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala new file mode 100644 index 0000000000..b91be821f0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +private[spark] object RecoveryState + extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") { + + type MasterState = Value + + val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala new file mode 100644 index 0000000000..81e15c534f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master + +import scala.collection.JavaConversions._ +import scala.concurrent.ops._ + +import org.apache.spark.Logging +import org.apache.zookeeper._ +import org.apache.zookeeper.data.Stat +import org.apache.zookeeper.Watcher.Event.KeeperState + +/** + * Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry + * logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be + * created. If ZooKeeper remains down after several retries, the given + * [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be + * informed via zkDown(). + * + * Additionally, all commands sent to ZooKeeper will be retried until they either fail too many + * times or a semantic exception is thrown (e.g., "node already exists"). + */ +private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging { + val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "") + + val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE + val ZK_TIMEOUT_MILLIS = 30000 + val RETRY_WAIT_MILLIS = 5000 + val ZK_CHECK_PERIOD_MILLIS = 10000 + val MAX_RECONNECT_ATTEMPTS = 3 + + private var zk: ZooKeeper = _ + + private val watcher = new ZooKeeperWatcher() + private var reconnectAttempts = 0 + private var closed = false + + /** Connect to ZooKeeper to start the session. Must be called before anything else. */ + def connect() { + connectToZooKeeper() + + new Thread() { + override def run() = sessionMonitorThread() + }.start() + } + + def sessionMonitorThread(): Unit = { + while (!closed) { + Thread.sleep(ZK_CHECK_PERIOD_MILLIS) + if (zk.getState != ZooKeeper.States.CONNECTED) { + reconnectAttempts += 1 + val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts + if (attemptsLeft <= 0) { + logError("Could not connect to ZooKeeper: system failure") + zkWatcher.zkDown() + close() + } else { + logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...") + connectToZooKeeper() + } + } + } + } + + def close() { + if (!closed && zk != null) { zk.close() } + closed = true + } + + private def connectToZooKeeper() { + if (zk != null) zk.close() + zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher) + } + + /** + * Attempts to maintain a live ZooKeeper session despite (very) transient failures. + * Mainly useful for handling the natural ZooKeeper session expiration. + */ + private class ZooKeeperWatcher extends Watcher { + def process(event: WatchedEvent) { + if (closed) { return } + + event.getState match { + case KeeperState.SyncConnected => + reconnectAttempts = 0 + zkWatcher.zkSessionCreated() + case KeeperState.Expired => + connectToZooKeeper() + case KeeperState.Disconnected => + logWarning("ZooKeeper disconnected, will retry...") + } + } + } + + def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = { + retry { + zk.create(path, bytes, ZK_ACL, createMode) + } + } + + def exists(path: String, watcher: Watcher = null): Stat = { + retry { + zk.exists(path, watcher) + } + } + + def getChildren(path: String, watcher: Watcher = null): List[String] = { + retry { + zk.getChildren(path, watcher).toList + } + } + + def getData(path: String): Array[Byte] = { + retry { + zk.getData(path, false, null) + } + } + + def delete(path: String, version: Int = -1): Unit = { + retry { + zk.delete(path, version) + } + } + + /** + * Creates the given directory (non-recursively) if it doesn't exist. + * All znodes are created in PERSISTENT mode with no data. 
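As a usage sketch of the wrapper above (not part of the patch): a trivial watcher plus a create/list round trip. It assumes a ZooKeeper ensemble is actually reachable at spark.deploy.zookeeper.url; the address and paths here are placeholders.

import org.apache.zookeeper.CreateMode

object ZkSessionSketch extends SparkZooKeeperWatcher {
  def zkSessionCreated() { println("ZK session (re)created") }
  def zkDown() { println("ZK considered permanently down") }

  def main(args: Array[String]) {
    System.setProperty("spark.deploy.zookeeper.url", "localhost:2181") // placeholder
    val zk = new SparkZooKeeperSession(this)
    zk.connect()
    zk.mkdirRecursive("/sketch/leader_election")
    // EPHEMERAL_SEQUENTIAL nodes vanish with the session -- the basis of the leader election below.
    zk.create("/sketch/leader_election/master_", "spark://host:7077".getBytes,
      CreateMode.EPHEMERAL_SEQUENTIAL)
    println(zk.getChildren("/sketch/leader_election")) // e.g. List(master_0000000000)
    zk.close()
  }
}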
+ */ + def mkdir(path: String) { + if (exists(path) == null) { + try { + create(path, "".getBytes, CreateMode.PERSISTENT) + } catch { + case e: Exception => + // If the exception caused the directory not to be created, bubble it up, + // otherwise ignore it. + if (exists(path) == null) { throw e } + } + } + } + + /** + * Recursively creates all directories up to the given one. + * All znodes are created in PERSISTENT mode with no data. + */ + def mkdirRecursive(path: String) { + var fullDir = "" + for (dentry <- path.split("/").tail) { + fullDir += "/" + dentry + mkdir(fullDir) + } + } + + /** + * Retries the given function up to 3 times. The assumption is that failure is transient, + * UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist), + * in which case the exception will be thrown without retries. + * + * @param fn Block to execute, possibly multiple times. + */ + def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = { + try { + fn + } catch { + case e: KeeperException.NoNodeException => throw e + case e: KeeperException.NodeExistsException => throw e + case e if n > 0 => + logError("ZooKeeper exception, " + n + " more retries...", e) + Thread.sleep(RETRY_WAIT_MILLIS) + retry(fn, n-1) + } + } +} + +trait SparkZooKeeperWatcher { + /** + * Called whenever a ZK session is created -- + * this will occur when we create our first session as well as each time + * the session expires or errors out. + */ + def zkSessionCreated() + + /** + * Called if ZK appears to be completely down (i.e., not just a transient error). + * We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead. + */ + def zkDown() +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 6219f11f2a..e05f587b58 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -22,28 +22,44 @@ import scala.collection.mutable import org.apache.spark.util.Utils private[spark] class WorkerInfo( - val id: String, - val host: String, - val port: Int, - val cores: Int, - val memory: Int, - val actor: ActorRef, - val webUiPort: Int, - val publicAddress: String) { + val id: String, + val host: String, + val port: Int, + val cores: Int, + val memory: Int, + val actor: ActorRef, + val webUiPort: Int, + val publicAddress: String) + extends Serializable { Utils.checkHost(host, "Expected hostname") assert (port > 0) - var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info - var state: WorkerState.Value = WorkerState.ALIVE - var coresUsed = 0 - var memoryUsed = 0 + @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info + @transient var state: WorkerState.Value = _ + @transient var coresUsed: Int = _ + @transient var memoryUsed: Int = _ - var lastHeartbeat = System.currentTimeMillis() + @transient var lastHeartbeat: Long = _ + + init() def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed + private def readObject(in: java.io.ObjectInputStream) : Unit = { + in.defaultReadObject() + init() + } + + private def init() { + executors = new mutable.HashMap + state = WorkerState.ALIVE + coresUsed = 0 + memoryUsed = 0 + lastHeartbeat = System.currentTimeMillis() + } + def hostPort: String = { assert (port > 0) host + ":" + port diff --git 
a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala index fb3fe88d92..0b36ef6005 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala @@ -20,5 +20,5 @@ package org.apache.spark.deploy.master private[spark] object WorkerState extends Enumeration { type WorkerState = Value - val ALIVE, DEAD, DECOMMISSIONED = Value + val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala new file mode 100644 index 0000000000..7809013e83 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import akka.actor.ActorRef +import org.apache.zookeeper._ +import org.apache.zookeeper.Watcher.Event.EventType + +import org.apache.spark.deploy.master.MasterMessages._ +import org.apache.spark.Logging + +private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String) + extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging { + + val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" + + private val watcher = new ZooKeeperWatcher() + private val zk = new SparkZooKeeperSession(this) + private var status = LeadershipStatus.NOT_LEADER + private var myLeaderFile: String = _ + private var leaderUrl: String = _ + + override def preStart() { + logInfo("Starting ZooKeeper LeaderElection agent") + zk.connect() + } + + override def zkSessionCreated() { + synchronized { + zk.mkdirRecursive(WORKING_DIR) + myLeaderFile = + zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL) + self ! CheckLeader + } + } + + override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) { + logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason) + Thread.sleep(zk.ZK_TIMEOUT_MILLIS) + super.preRestart(reason, message) + } + + override def zkDown() { + logError("ZooKeeper down! LeaderElectionAgent shutting down Master.") + System.exit(1) + } + + override def postStop() { + zk.close() + } + + override def receive = { + case CheckLeader => checkLeader() + } + + private class ZooKeeperWatcher extends Watcher { + def process(event: WatchedEvent) { + if (event.getType == EventType.NodeDeleted) { + logInfo("Leader file disappeared, a master is down!") + self ! 
CheckLeader + } + } + } + + /** Uses ZK leader election. Navigates several ZK potholes along the way. */ + def checkLeader() { + val masters = zk.getChildren(WORKING_DIR).toList + val leader = masters.sorted.head + val leaderFile = WORKING_DIR + "/" + leader + + // Setup a watch for the current leader. + zk.exists(leaderFile, watcher) + + try { + leaderUrl = new String(zk.getData(leaderFile)) + } catch { + // A NoNodeException may be thrown if old leader died since the start of this method call. + // This is fine -- just check again, since we're guaranteed to see the new values. + case e: KeeperException.NoNodeException => + logInfo("Leader disappeared while reading it -- finding next leader") + checkLeader() + return + } + + // Synchronization used to ensure no interleaving between the creation of a new session and the + // checking of a leader, which could cause us to delete our real leader file erroneously. + synchronized { + val isLeader = myLeaderFile == leaderFile + if (!isLeader && leaderUrl == masterUrl) { + // We found a different master file pointing to this process. + // This can happen in the following two cases: + // (1) The master process was restarted on the same node. + // (2) The ZK server died between creating the node and returning the name of the node. + // For this case, we will end up creating a second file, and MUST explicitly delete the + // first one, since our ZK session is still open. + // Note that this deletion will cause a NodeDeleted event to be fired so we check again for + // leader changes. + assert(leaderFile < myLeaderFile) + logWarning("Cleaning up old ZK master election file that points to this master.") + zk.delete(leaderFile) + } else { + updateLeadershipStatus(isLeader) + } + } + } + + def updateLeadershipStatus(isLeader: Boolean) { + if (isLeader && status == LeadershipStatus.NOT_LEADER) { + status = LeadershipStatus.LEADER + masterActor ! ElectedLeader + } else if (!isLeader && status == LeadershipStatus.LEADER) { + status = LeadershipStatus.NOT_LEADER + masterActor ! RevokedLeadership + } + } + + private object LeadershipStatus extends Enumeration { + type LeadershipStatus = Value + val LEADER, NOT_LEADER = Value + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala new file mode 100644 index 0000000000..825344b3bb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.master + +import org.apache.spark.Logging +import org.apache.zookeeper._ + +import akka.serialization.Serialization + +class ZooKeeperPersistenceEngine(serialization: Serialization) + extends PersistenceEngine + with SparkZooKeeperWatcher + with Logging +{ + val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status" + + val zk = new SparkZooKeeperSession(this) + + zk.connect() + + override def zkSessionCreated() { + zk.mkdirRecursive(WORKING_DIR) + } + + override def zkDown() { + logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.") + } + + override def addApplication(app: ApplicationInfo) { + serializeIntoFile(WORKING_DIR + "/app_" + app.id, app) + } + + override def removeApplication(app: ApplicationInfo) { + zk.delete(WORKING_DIR + "/app_" + app.id) + } + + override def addWorker(worker: WorkerInfo) { + serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker) + } + + override def removeWorker(worker: WorkerInfo) { + zk.delete(WORKING_DIR + "/worker_" + worker.id) + } + + override def close() { + zk.close() + } + + override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = { + val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted + val appFiles = sortedFiles.filter(_.startsWith("app_")) + val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) + val workerFiles = sortedFiles.filter(_.startsWith("worker_")) + val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) + (apps, workers) + } + + private def serializeIntoFile(path: String, value: AnyRef) { + val serializer = serialization.findSerializerFor(value) + val serialized = serializer.toBinary(value) + zk.create(path, serialized, CreateMode.PERSISTENT) + } + + def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = { + val fileData = zk.getData(WORKING_DIR + "/" + filename) + val clazz = m.runtimeClass.asInstanceOf[Class[T]] + val serializer = serialization.serializerFor(clazz) + serializer.fromBinary(fileData).asInstanceOf[T] + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index e3dc30eefc..fff9cb60c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -43,7 +43,8 @@ private[spark] class ExecutorRunner( val workerId: String, val host: String, val sparkHome: File, - val workDir: File) + val workDir: File, + var state: ExecutorState.Value) extends Logging { val fullId = appId + "/" + execId @@ -83,7 +84,8 @@ private[spark] class ExecutorRunner( process.destroy() process.waitFor() } - worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None) + state = ExecutorState.KILLED + worker ! ExecutorStateChanged(appId, execId, state, None, None) Runtime.getRuntime.removeShutdownHook(shutdownHook) } } @@ -102,7 +104,7 @@ private[spark] class ExecutorRunner( // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++ - command.arguments.map(substituteVariables) + (command.arguments ++ Seq(appId)).map(substituteVariables) } /** @@ -180,9 +182,9 @@ private[spark] class ExecutorRunner( // long-lived processes only. However, in the future, we might restart the executor a few // times on the same machine. 
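The readPersistedData() implementation above mirrors the filesystem engine: list children, sort by name (ids embed creation order, per the PersistenceEngine contract), and partition on the app_/worker_ prefix. A tiny self-contained sketch of just that step, with invented znode names:

object StatusPartitionSketch {
  def main(args: Array[String]) {
    // Stand-in for zk.getChildren(WORKING_DIR); the names are made up.
    val children = List("worker_w-2", "app_a-1", "app_a-2", "worker_w-1")
    val sorted = children.sorted
    val (apps, rest) = sorted.partition(_.startsWith("app_"))
    val workers = rest.filter(_.startsWith("worker_"))
    println("apps=" + apps + " workers=" + workers)
  }
}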
val exitCode = process.waitFor() + state = ExecutorState.FAILED val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), - Some(exitCode)) + worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) } catch { case interrupted: InterruptedException => logInfo("Runner thread for executor " + fullId + " interrupted") @@ -192,8 +194,9 @@ private[spark] class ExecutorRunner( if (process != null) { process.destroy() } + state = ExecutorState.FAILED val message = e.getClass + ": " + e.getMessage - worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None) + worker ! ExecutorStateChanged(appId, execId, state, Some(message), None) } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3904b701b2..991b22d9f8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -23,26 +23,42 @@ import java.io.File import scala.collection.mutable.HashMap import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} +import akka.actor._ import akka.remote.{RemotingLifecycleEvent, AssociationErrorEvent, DisassociatedEvent} import org.apache.spark.Logging -import org.apache.spark.deploy.ExecutorState +import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{Utils, AkkaUtils} - - +import org.apache.spark.deploy.DeployMessages.WorkerStateResponse +import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed +import org.apache.spark.deploy.DeployMessages.KillExecutor +import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import scala.Some +import akka.remote.DisassociatedEvent +import org.apache.spark.deploy.DeployMessages.LaunchExecutor +import org.apache.spark.deploy.DeployMessages.RegisterWorker +import org.apache.spark.deploy.DeployMessages.WorkerSchedulerStateResponse +import org.apache.spark.deploy.DeployMessages.MasterChanged +import org.apache.spark.deploy.DeployMessages.Heartbeat +import org.apache.spark.deploy.DeployMessages.RegisteredWorker +import akka.actor.Terminated + +/** + * @param masterUrls Each url should look like spark://host:port. + */ private[spark] class Worker( host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, - masterUrl: String, + masterUrls: Array[String], workDirPath: String = null) extends Actor with Logging { @@ -54,8 +70,18 @@ private[spark] class Worker( // Send a heartbeat every (heartbeat timeout) / 4 milliseconds val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4 + val REGISTRATION_TIMEOUT = 20.seconds + val REGISTRATION_RETRIES = 3 + + // Index into masterUrls that we're currently trying to register with. 
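The masterUrls array documented above comes from a single comma-separated positional argument; WorkerArguments (further down in this diff) normalizes it as sketched here. Only parseMasters mirrors the patch; the rest is a throwaway harness with placeholder hosts.

object MasterUrlSketch {
  // Strip one leading spark://, split on commas, re-prefix each host:port.
  def parseMasters(value: String): Array[String] =
    value.stripPrefix("spark://").split(",").map("spark://" + _)

  def main(args: Array[String]) {
    parseMasters("spark://host1:7077,host2:7077").foreach(println)
    // spark://host1:7077
    // spark://host2:7077
  }
}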
+ var masterIndex = 0 + + val masterLock: Object = new Object() var master: ActorRef = null - var masterWebUiUrl : String = "" + var activeMasterUrl: String = "" + var activeMasterWebUiUrl : String = "" + @volatile var registered = false + @volatile var connected = false val workerId = generateWorkerId() var sparkHome: File = null var workDir: File = null @@ -95,6 +121,7 @@ private[spark] class Worker( } override def preStart() { + assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse(".")) @@ -103,46 +130,100 @@ private[spark] class Worker( webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) webUi.start() - connectToMaster() + registerWithMaster() metricsSystem.registerSource(workerSource) metricsSystem.start() } - def connectToMaster() { - logInfo("Connecting to master " + masterUrl) - master = context.actorFor(Master.toAkkaUrl(masterUrl)) - master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + def changeMaster(url: String, uiUrl: String) { + masterLock.synchronized { + activeMasterUrl = url + activeMasterWebUiUrl = uiUrl + master = context.actorFor(Master.toAkkaUrl(activeMasterUrl)) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + context.watch(master) // Doesn't work with remote actors, but useful for testing + connected = true + } + } + + def tryRegisterAllMasters() { + for (masterUrl <- masterUrls) { + logInfo("Connecting to master " + masterUrl + "...") + val actor = context.actorFor(Master.toAkkaUrl(masterUrl)) + actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, + publicAddress) + } + } + + def registerWithMaster() { + tryRegisterAllMasters() + + var retries = 0 + lazy val retryTimer: Cancellable = + context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + retries += 1 + if (registered) { + retryTimer.cancel() + } else if (retries >= REGISTRATION_RETRIES) { + logError("All masters are unresponsive! Giving up.") + System.exit(1) + } else { + tryRegisterAllMasters() + } + } + retryTimer // start timer } import context.dispatcher override def receive = { - case RegisteredWorker(url) => - masterWebUiUrl = url - logInfo("Successfully registered with master") - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) { - master ! Heartbeat(workerId) + case RegisteredWorker(masterUrl, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterUrl) + registered = true + changeMaster(masterUrl, masterWebUiUrl) + context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + + case SendHeartbeat => + masterLock.synchronized { + if (connected) { master ! Heartbeat(workerId) } } + case MasterChanged(masterUrl, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterUrl) + context.unwatch(master) + changeMaster(masterUrl, masterWebUiUrl) + + val execs = executors.values. + map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) + sender ! 
WorkerSchedulerStateResponse(workerId, execs.toList) + case RegisterWorkerFailed(message) => - logError("Worker registration failed: " + message) - System.exit(1) - - case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) => - logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) - val manager = new ExecutorRunner( - appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir) - executors(appId + "/" + execId) = manager - manager.start() - coresUsed += cores_ - memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None) + if (!registered) { + logError("Worker registration failed: " + message) + System.exit(1) + } + + case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_, execSparkHome_) => + if (masterUrl != activeMasterUrl) { + logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.") + } else { + logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) + val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_, + self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING) + executors(appId + "/" + execId) = manager + manager.start() + coresUsed += cores_ + memoryUsed += memory_ + masterLock.synchronized { + master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + } + } case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + masterLock.synchronized { + master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + } val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { val executor = executors(fullId) @@ -155,14 +236,18 @@ private[spark] class Worker( memoryUsed -= executor.memory } - case KillExecutor(appId, execId) => - val fullId = appId + "/" + execId - executors.get(fullId) match { - case Some(executor) => - logInfo("Asked to kill executor " + fullId) - executor.kill() - case None => - logInfo("Asked to kill unknown executor " + fullId) + case KillExecutor(masterUrl, appId, execId) => + if (masterUrl != activeMasterUrl) { + logWarning("Invalid Master (" + masterUrl + ") attempted to kill executor " + execId) + } else { + val fullId = appId + "/" + execId + executors.get(fullId) match { + case Some(executor) => + logInfo("Asked to kill executor " + fullId) + executor.kill() + case None => + logInfo("Asked to kill unknown executor " + fullId) + } } case DisassociatedEvent(_, _, _) => @@ -170,17 +255,14 @@ private[spark] class Worker( case RequestWorkerState => { sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, masterUrl, cores, memory, - coresUsed, memoryUsed, masterWebUiUrl) + finishedExecutors.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl) } } def masterDisconnected() { - // TODO: It would be nice to try to reconnect to the master, but just shut down for now. - // (Note that if reconnecting we would also need to assign IDs differently.) - logError("Connection to master failed! Shutting down.") - executors.values.foreach(_.kill()) - System.exit(1) + logError("Connection to master failed! 
Waiting for master to reconnect...") + connected = false } def generateWorkerId(): String = { @@ -198,17 +280,18 @@ private[spark] object Worker { def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, - args.memory, args.master, args.workDir) + args.memory, args.masters, args.workDir) actorSystem.awaitTermination() } def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, - masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None) + : (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrl, workDir), name = "Worker") + masterUrls, workDir), name = "Worker") (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 0ae89a864f..3ed528e6b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -29,7 +29,7 @@ private[spark] class WorkerArguments(args: Array[String]) { var webUiPort = 8081 var cores = inferDefaultCores() var memory = inferDefaultMemory() - var master: String = null + var masters: Array[String] = null var workDir: String = null // Check for settings in environment variables @@ -86,14 +86,14 @@ private[spark] class WorkerArguments(args: Array[String]) { printUsageAndExit(0) case value :: tail => - if (master != null) { // Two positional arguments were given + if (masters != null) { // Two positional arguments were given printUsageAndExit(1) } - master = value + masters = value.stripPrefix("spark://").split(",").map("spark://" + _) parse(tail) case Nil => - if (master == null) { // No positional argument was given + if (masters == null) { // No positional argument was given printUsageAndExit(1) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 07bc479c83..a38e32b339 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -108,7 +108,7 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node> - val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p> + val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p> val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span> diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index f705a5631a..73fa7d6b6a 100644 --- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,23 +24,15 @@ import akka.remote._ import 
org.apache.spark.{Logging, SparkEnv} import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask import akka.remote.DisassociatedEvent -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask import akka.remote.AssociationErrorEvent import akka.remote.DisassociatedEvent import akka.actor.Terminated -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed -private[spark] class StandaloneExecutorBackend( +private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, @@ -75,15 +67,28 @@ private[spark] class StandaloneExecutorBackend( case LaunchTask(taskDesc) => logInfo("Got assigned task " + taskDesc.taskId) if (executor == null) { - logError("Received launchTask but executor was null") + logError("Received LaunchTask command but executor was null") System.exit(1) } else { executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) } + case KillTask(taskId, _) => + if (executor == null) { + logError("Received KillTask command but executor was null") + System.exit(1) + } else { + executor.killTask(taskId) + } + case DisassociatedEvent(_, _, _) => logError("Driver terminated or disconnected! 
Shutting down.") System.exit(1) + + case StopExecutor => + logInfo("Driver commanded a shutdown") + context.stop(self) + context.system.shutdown() } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { @@ -91,7 +96,7 @@ private[spark] class StandaloneExecutorBackend( } } -private[spark] object StandaloneExecutorBackend { +private[spark] object CoarseGrainedExecutorBackend { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { // Debug code Utils.checkHost(hostname) @@ -102,16 +107,19 @@ private[spark] object StandaloneExecutorBackend { // set it val sparkHostPort = hostname + ":" + boundPort System.setProperty("spark.hostPort", sparkHostPort) + actorSystem.actorOf( - Props(classOf[StandaloneExecutorBackend], driverUrl, executorId, sparkHostPort, cores), + Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores), name = "Executor") actorSystem.awaitTermination() } def main(args: Array[String]) { if (args.length < 4) { - //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors - System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]") + //the reason we allow the last appid argument is to make it easy to kill rogue executors + System.err.println( + "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " + + "[<appid>]") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3800063234..de4540493a 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -25,9 +25,10 @@ import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap -import org.apache.spark.scheduler._ import org.apache.spark._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.scheduler._ +import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util.Utils /** @@ -36,7 +37,8 @@ import org.apache.spark.util.Utils private[spark] class Executor( executorId: String, slaveHostname: String, - properties: Seq[(String, String)]) + properties: Seq[(String, String)], + isLocal: Boolean = false) extends Logging { // Application dependencies (added through SparkContext) that we've fetched so far on this node. @@ -73,46 +75,75 @@ private[spark] class Executor( private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) Thread.currentThread.setContextClassLoader(replClassLoader) - // Make any thread terminations due to uncaught exceptions kill the entire - // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler( - new Thread.UncaughtExceptionHandler { - override def uncaughtException(thread: Thread, exception: Throwable) { - try { - logError("Uncaught exception in thread " + thread, exception) - - // We may have been called from a shutdown hook. If so, we must not call System.exit(). - // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(ExecutorExitCode.OOM) - } else { - System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + if (!isLocal) { + // Setup an uncaught exception handler for non-local mode. 
+ // Make any thread terminations due to uncaught exceptions kill the entire + // executor process to avoid surprising stalls. + Thread.setDefaultUncaughtExceptionHandler( + new Thread.UncaughtExceptionHandler { + override def uncaughtException(thread: Thread, exception: Throwable) { + try { + logError("Uncaught exception in thread " + thread, exception) + + // We may have been called from a shutdown hook. If so, we must not call System.exit(). + // (If we do, we will deadlock.) + if (!Utils.inShutdown()) { + if (exception.isInstanceOf[OutOfMemoryError]) { + System.exit(ExecutorExitCode.OOM) + } else { + System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + } } + } catch { + case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) + case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) } - } catch { - case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) - case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) } } - } - ) + ) + } val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) - SparkEnv.set(env) - env.metricsSystem.registerSource(executorSource) + private val env = { + if (!isLocal) { + val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, + isDriver = false, isLocal = false) + SparkEnv.set(_env) + _env.metricsSystem.registerSource(executorSource) + _env + } else { + SparkEnv.get + } + } private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size") // Start worker thread pool - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + + // Maintains the list of running tasks. + private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] + + val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER) def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - threadPool.execute(new TaskRunner(context, taskId, serializedTask)) + val tr = new TaskRunner(context, taskId, serializedTask) + runningTasks.put(taskId, tr) + threadPool.execute(tr) + } + + def killTask(taskId: Long) { + val tr = runningTasks.get(taskId) + if (tr != null) { + tr.kill() + // We also remove the task in the finally block in TaskRunner.run. + // We need to remove it here as well because killTask might be called before the task + // is even launched, in which case it would never reach that finally block. ConcurrentHashMap's + // remove is idempotent. + runningTasks.remove(taskId) + } } /** Get the Yarn approved local directories. 
*/ @@ -124,56 +155,87 @@ private[spark] class Executor( .getOrElse(Option(System.getenv("LOCAL_DIRS")) .getOrElse("")) - if (localDirs.isEmpty()) { + if (localDirs.isEmpty) { throw new Exception("Yarn Local dirs can't be empty") } - return localDirs + localDirs } - class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) + class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) extends Runnable { - override def run() { + @volatile private var killed = false + @volatile private var task: Task[Any] = _ + + def kill() { + logInfo("Executor is trying to kill task " + taskId) + killed = true + if (task != null) { + task.kill() + } + } + + override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () => val startTime = System.currentTimeMillis() SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) - context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime + def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + val startGCTime = gcTime try { SparkEnv.set(env) Accumulators.clear() val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) - val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + + // If this task has been killed before we deserialized it, let's quit now. Otherwise, + // continue executing the task. + if (killed) { + logInfo("Executor killed task " + taskId) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + return + } + attemptedTask = Some(task) - logInfo("Its epoch is " + task.epoch) + logDebug("Task " + taskId + "'s epoch is " + task.epoch) env.mapOutputTracker.updateEpoch(task.epoch) + + // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() + + // If the task has been killed, let's fail it. + if (task.killed) { + logInfo("Executor killed task " + taskId) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + return + } + for (m <- task.metrics) { - m.hostname = Utils.localHostName + m.hostname = Utils.localHostName() m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt - m.jvmGCTime = getTotalGCTime - startGCTime + m.jvmGCTime = gcTime - startGCTime } - //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c - // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could - // just change the relevants bytes in the byte buffer + // TODO I'd also like to track the time it takes to serialize the task results, but that is + // a huge headache, b/c we need to serialize the task metrics first. 
If TaskMetrics had a + // custom serialized format, we could just change the relevant bytes in the byte buffer val accumUpdates = Accumulators.values + val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null)) val serializedDirectResult = ser.serialize(directResult) logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit) val serializedResult = { if (serializedDirectResult.limit >= akkaFrameSize - 1024) { logInfo("Storing result for " + taskId + " in local BlockManager") - val blockId = "taskresult_" + taskId + val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) ser.serialize(new IndirectTaskResult[Any](blockId)) @@ -182,12 +244,13 @@ private[spark] class Executor( serializedDirectResult } } - context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) + + execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) logInfo("Finished task ID " + taskId) } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) } case t: Throwable => { @@ -195,10 +258,10 @@ private[spark] class Executor( val metrics = attemptedTask.flatMap(t => t.metrics) for (m <- metrics) { m.executorRunTime = serviceTime - m.jvmGCTime = getTotalGCTime - startGCTime + m.jvmGCTime = gcTime - startGCTime } val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // TODO: Should we exit the whole executor here? On the one hand, the failed task may // have left some weird state around depending on when the exception was thrown, but on @@ -206,6 +269,8 @@ private[spark] class Executor( logError("Exception in task ID " + taskId, t) //System.exit(1) } + } finally { + runningTasks.remove(taskId) } } } @@ -215,7 +280,7 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): ExecutorURLClassLoader = { - var loader = this.getClass.getClassLoader + val loader = this.getClass.getClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. 
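[Note on the result-handling hunk above, not part of the patch: the new TaskRunner code returns a serialized result inline through Akka only when it fits under the remote frame size; otherwise it stores the bytes in the BlockManager under a TaskResultBlockId and ships back an IndirectTaskResult instead. A minimal, self-contained Scala sketch of that size check follows; the object and method names are hypothetical, and only the "- 1024" headroom mirrors the patch.]

    import java.nio.ByteBuffer

    object ResultRoutingSketch {
      // Headroom reserved below the Akka frame size, mirroring the "- 1024" above.
      private val FrameOverheadBytes = 1024

      // True when the serialized result is too large to send back inline and must
      // instead be stored in the BlockManager, with only the block id returned.
      def sendViaBlockManager(serializedResult: ByteBuffer, akkaFrameSize: Long): Boolean =
        serializedResult.limit() >= akkaFrameSize - FrameOverheadBytes
    }
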
@@ -237,7 +302,7 @@ private[spark] class Executor( val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - return constructor.newInstance(classUri, parent) + constructor.newInstance(classUri, parent) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -245,7 +310,7 @@ private[spark] class Executor( null } } else { - return parent + parent } } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index da62091980..b56d8c9912 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -18,14 +18,18 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} -import org.apache.spark.TaskState.TaskState + import com.google.protobuf.ByteString -import org.apache.spark.{Logging} + +import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} +import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} + +import org.apache.spark.Logging import org.apache.spark.TaskState +import org.apache.spark.TaskState.TaskState import org.apache.spark.util.Utils + private[spark] class MesosExecutorBackend extends MesosExecutor with ExecutorBackend @@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend } override def killTask(d: ExecutorDriver, t: TaskID) { - logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") + if (executor == null) { + logError("Received KillTask but executor was null") + } else { + executor.killTask(t.getValue.toLong) + } } override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {} diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index f311141148..0b4892f98f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for a shuffle */ var shuffleBytesWritten: Long = _ + + /** + * Time spent blocking on writes to disk or buffer cache, in nanoseconds. 
+ */ + var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index c24fd48c04..703bc6a9ca 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] - implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + implicit val futureExecContext = ExecutionContext.fromExecutor( + Utils.newDaemonCachedThreadPool("Connection manager future execution context")) private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 3c29700920..1b9fa1e53a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -20,17 +20,18 @@ package org.apache.spark.network.netty import io.netty.buffer._ import org.apache.spark.Logging +import org.apache.spark.storage.{TestBlockId, BlockId} private[spark] class FileHeader ( val fileLen: Int, - val blockId: String) extends Logging { + val blockId: BlockId) extends Logging { lazy val buffer = { val buf = Unpooled.buffer() buf.capacity(FileHeader.HEADER_SIZE) buf.writeInt(fileLen) - buf.writeInt(blockId.length) - blockId.foreach((x: Char) => buf.writeByte(x)) + buf.writeInt(blockId.name.length) + blockId.name.foreach((x: Char) => buf.writeByte(x)) //padding the rest of header if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) @@ -57,18 +58,15 @@ private[spark] object FileHeader { for (i <- 1 to idLength) { idBuilder += buf.readByte().asInstanceOf[Char] } - val blockId = idBuilder.toString() + val blockId = BlockId(idBuilder.toString()) new FileHeader(length, blockId) } - - def main (args:Array[String]){ - - val header = new FileHeader(25,"block_0"); - val buf = header.buffer; - val newheader = FileHeader.create(buf); - System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) - + def main (args:Array[String]) { + val header = new FileHeader(25, TestBlockId("my_block")) + val buf = header.buffer + val newHeader = FileHeader.create(buf) + System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 9493ccffd9..481ff8c3e0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -27,12 +27,13 @@ import org.apache.spark.Logging import org.apache.spark.network.ConnectionManagerId import scala.collection.JavaConverters._ +import org.apache.spark.storage.BlockId private[spark] class ShuffleCopier extends Logging { - def getBlock(host: String, port: Int, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + def getBlock(host: String, port: Int, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) => 
Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt @@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging { try { fc.init() fc.connect(host, port) - fc.sendRequest(blockId) + fc.sendRequest(blockId.name) fc.waitForClose() fc.close() } catch { @@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging { } } - def getBlock(cmId: ConnectionManagerId, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + def getBlock(cmId: ConnectionManagerId, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) } def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + blocks: Seq[(BlockId, Long)], + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { for ((blockId, size) <- blocks) { getBlock(cmId, blockId, resultCollectCallback) @@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging { private[spark] object ShuffleCopier extends Logging { - private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit) extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { @@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging { resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } - override def handleError(blockId: String) { + override def handleError(blockId: BlockId) { if (!isComplete) { resultCollectCallBack(blockId, -1, null) } } } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) { if (size != -1) { logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") } @@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging { } val host = args(0) val port = args(1).toInt - val file = args(2) + val blockId = BlockId(args(2)) val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) @@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging { Executors.callable(new Runnable() { def run() { val copier = new ShuffleCopier() - copier.getBlock(host, port, file, echoResultCollectCallBack) + copier.getBlock(host, port, blockId, echoResultCollectCallBack) } }) }).asJava copiers.invokeAll(tasks) - copiers.shutdown + copiers.shutdown() System.exit(0) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 8afcbe190a..546d921067 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -21,6 +21,7 @@ import java.io.File import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.storage.{BlockId, FileSegment} private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { @@ -53,8 +54,8 @@ private[spark] object ShuffleSender { val localDirs = args.drop(2).map(new File(_)) val pResovler = new PathResolver { - override def getAbsolutePath(blockId: 
String): String = { - if (!blockId.startsWith("shuffle_")) { + override def getBlockLocation(blockId: BlockId): FileSegment = { + if (!blockId.isShuffle) { throw new Exception("Block " + blockId + " is not a shuffle block") } // Figure out which local directory it hashes to, and which subdirectory in that @@ -62,8 +63,8 @@ private[spark] object ShuffleSender { val dirId = hash % localDirs.length val subDirId = (hash / localDirs.length) % subDirsPerLocalDir val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId) - return file.getAbsolutePath + val file = new File(subDir, blockId.name) + return new FileSegment(file, 0, file.length()) } } val sender = new ShuffleSender(port, pResovler) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index f132e2b735..70a5a8caff 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -15,6 +15,8 @@ * limitations under the License. */ +package org.apache + /** * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala new file mode 100644 index 0000000000..44c5078621 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import scala.reflect.ClassTag + +/** + * A set of asynchronous RDD actions available through an implicit conversion. + * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. + */ +class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { + + /** + * Returns a future for counting the number of elements in the RDD. + */ + def countAsync(): FutureAction[Long] = { + val totalCount = new AtomicLong + self.context.submitJob( + self, + (iter: Iterator[T]) => { + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + }, + Range(0, self.partitions.size), + (index: Int, data: Long) => totalCount.addAndGet(data), + totalCount.get()) + } + + /** + * Returns a future for retrieving all elements of this RDD. 
+ */ + def collectAsync(): FutureAction[Seq[T]] = { + val results = new Array[Array[T]](self.partitions.size) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + (index, data) => results(index) = data, results.flatten.toSeq) + } + + /** + * Returns a future for retrieving the first num elements of the RDD. + */ + def takeAsync(num: Int): FutureAction[Seq[T]] = { + val f = new ComplexFutureAction[Seq[T]] + + f.run { + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + var partsScanned = 0 + while (results.size < num && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. + if (results.size == 0) { + numPartsToTry = totalParts - 1 + } else { + numPartsToTry = (1.5 * num * partsScanned / results.size).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + val left = num - results.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + + val buf = new Array[Array[T]](p.size) + f.runJob(self, + (it: Iterator[T]) => it.take(left).toArray, + p, + (index: Int, data: Array[T]) => buf(index) = data, + Unit) + + buf.foreach(results ++= _.take(num - results.size)) + partsScanned += numPartsToTry + } + results.toSeq + } + + f + } + + /** + * Applies a function f to all elements of this RDD. + */ + def foreachAsync(f: T => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size), + (index, data) => Unit, Unit) + } + + /** + * Applies a function f to each partition of this RDD. 
+ */ + def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + (index, data) => Unit, Unit) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fe2946bcbe..63b9fe1478 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -18,15 +18,15 @@ package org.apache.spark.rdd import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.{BlockId, BlockManager} import scala.reflect.ClassTag -private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { +private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition { val index = idx } private[spark] -class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[String]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 3f4d4ad46a..99ea6e8ee8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.{NullWritable, BytesWritable} @@ -84,9 +85,9 @@ private[spark] object CheckpointRDD extends Logging { def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) { val env = SparkEnv.get val outputDir = new Path(path) - val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) + val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration()) - val finalOutputName = splitIdToFile(ctx.splitId) + val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) val tempOutputPath = new Path(outputDir, "." 
+ finalOutputName + "-attempt-" + ctx.attemptId) @@ -123,7 +124,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { val env = SparkEnv.get - val fs = path.getFileSystem(env.hadoop.newConfiguration()) + val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = env.serializer.newInstance() @@ -146,7 +147,7 @@ private[spark] object CheckpointRDD extends Logging { val sc = new SparkContext(cluster, "CheckpointRDD Test") val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000) val path = new Path(hdfsPath, "temp") - val fs = path.getFileSystem(env.hadoop.newConfiguration()) + val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _) val cpRDD = new CheckpointRDD[Int](sc, path.toString) assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same") diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 0187256a8e..911a002884 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -18,13 +18,12 @@ package org.apache.spark.rdd import java.io.{ObjectOutputStream, IOException} -import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark.util.AppendOnlyMap private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -105,17 +104,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size // e.g. 
for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] + val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] - def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - val seq = map.get(k) - if (seq != null) { - seq - } else { - val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) - map.put(k, seq) - seq - } + val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any]) + } + + val getSeq = (k: K) => { + map.changeValue(k, update) } val ser = SparkEnv.get.serializerManager.get(serializerClass) @@ -129,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { + fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach { kv => getSeq(kv._1)(depNum) += kv._2 } } } - JavaConversions.mapAsScalaMap(map).iterator + new InterruptibleIterator(context, map.iterator) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index d3b3fffd40..32901a508f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -27,54 +27,19 @@ import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils -import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv, - TaskContext} +import org.apache.spark._ import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.NextIterator import org.apache.hadoop.conf.{Configuration, Configurable} -/** - * An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file - * system, or S3). - * This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same - * across multiple reads; the 'path' is the only variable that is different across new JobConfs - * created from the Configuration. - */ -class HadoopFileRDD[K, V]( - sc: SparkContext, - path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int) - extends HadoopRDD[K, V](sc, broadcastedConf, inputFormatClass, keyClass, valueClass, minSplits) { - - override def getJobConf(): JobConf = { - if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - // getJobConf() has been called previously, so there is already a local cache of the JobConf - // needed by this RDD. - return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] - } else { - // Create a new JobConf, set the input file/directory paths to read from, and cache the - // JobConf (i.e., in a shared hash map in the slave's JVM process that's accessible through - // HadoopRDD.putCachedMetadata()), so that we only create one copy across multiple - // getJobConf() calls for this RDD in the local process. - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. 
- val newJobConf = new JobConf(broadcastedConf.value.value) - FileInputFormat.setInputPaths(newJobConf, path) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - return newJobConf - } - } -} /** * A Spark split class that wraps around a Hadoop InputSplit. */ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) extends Partition { - + val inputSplit = new SerializableWritable[InputSplit](s) override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt @@ -83,11 +48,24 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp } /** - * A base class that provides core functionality for reading data partitions stored in Hadoop. + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3). + * + * @param sc The SparkContext to associate the RDD with. + * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed + * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. + * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD + * creates. + * @param inputFormatClass Storage format of the data to be read. + * @param keyClass Class of the key associated with the inputFormatClass. + * @param valueClass Class of the value associated with the inputFormatClass. + * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate. */ class HadoopRDD[K, V]( sc: SparkContext, broadcastedConf: Broadcast[SerializableWritable[Configuration]], + initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], @@ -105,6 +83,7 @@ class HadoopRDD[K, V]( sc, sc.broadcast(new SerializableWritable(conf)) .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, valueClass, @@ -130,6 +109,7 @@ class HadoopRDD[K, V]( // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. val newJobConf = new JobConf(broadcastedConf.value.value) + initLocalJobConfFuncOpt.map(f => f(newJobConf)) HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) return newJobConf } @@ -164,38 +144,41 @@ class HadoopRDD[K, V]( array } - override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] - logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null - - val jobConf = getJobConf() - val inputFormat = getInputFormat(jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream. 
- context.addOnCompleteCallback{ () => closeIfNeeded() } - - val key: K = reader.createKey() - val value: V = reader.createValue() - - override def getNext() = { - try { - finished = !reader.next(key, value) - } catch { - case eof: EOFException => - finished = true + override def compute(theSplit: Partition, context: TaskContext) = { + val iter = new NextIterator[(K, V)] { + val split = theSplit.asInstanceOf[HadoopPartition] + logInfo("Input split: " + split.inputSplit) + var reader: RecordReader[K, V] = null + + val jobConf = getJobConf() + val inputFormat = getInputFormat(jobConf) + reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. + context.addOnCompleteCallback{ () => closeIfNeeded() } + + val key: K = reader.createKey() + val value: V = reader.createValue() + + override def getNext() = { + try { + finished = !reader.next(key, value) + } catch { + case eof: EOFException => + finished = true + } + (key, value) } - (key, value) - } - override def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + override def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } } } + new InterruptibleIterator[(K, V)](context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { @@ -216,10 +199,10 @@ private[spark] object HadoopRDD { * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. */ - def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key) + def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key) + def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key) def putCachedMetadata(key: String, value: Any) = - SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value) + SparkEnv.get.hadoopJobMetadata.put(key, value) } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala index 3cf22851dd..67636751bb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala @@ -22,14 +22,14 @@ import scala.reflect.ClassTag /** - * A variant of the MapPartitionsRDD that passes the partition index into the - * closure. This can be used to generate or collect partition specific - * information such as the number of tuples in a partition. + * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the + * TaskContext, the closure can either get access to the interruptible flag or get the index + * of the partition in the RDD. 
*/ private[spark] -class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag]( +class MapPartitionsWithContextRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], - f: (Int, Iterator[T]) => Iterator[U], + f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean ) extends RDD[U](prev) { @@ -38,5 +38,5 @@ class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag]( override val partitioner = if (preservesPartitioning) prev.partitioner else None override def compute(split: Partition, context: TaskContext) = - f(split.index, firstParent[T].iterator(split, context)) + f(context, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b3a89f7e0..2662d48c84 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext} private[spark] @@ -71,49 +71,52 @@ class NewHadoopRDD[K, V]( result } - override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - if (format.isInstanceOf[Configurable]) { - format.asInstanceOf[Configurable].setConf(conf) - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) - - var havePair = false - var finished = false - - override def hasNext: Boolean = { - if (!finished && !havePair) { - finished = !reader.nextKeyValue - havePair = !finished + override def compute(theSplit: Partition, context: TaskContext) = { + val iter = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[NewHadoopPartition] + logInfo("Input split: " + split.serializableHadoopSplit) + val conf = confBroadcast.value.value + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + if (format.isInstanceOf[Configurable]) { + format.asInstanceOf[Configurable].setConf(conf) + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. 
+ context.addOnCompleteCallback(() => close()) + + var havePair = false + var finished = false + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !reader.nextKeyValue + havePair = !finished + } + !finished } - !finished - } - override def next: (K, V) = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") + override def next(): (K, V) = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + (reader.getCurrentKey, reader.getCurrentValue) } - havePair = false - return (reader.getCurrentKey, reader.getCurrentValue) - } - private def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + private def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } } } + new InterruptibleIterator(context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index c8e623081a..0c2a051a42 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -85,18 +85,24 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + self.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } else if (mapSideCombine) { val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) .setSerializer(serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) + partitioned.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter)) + }, preservesPartitioning = true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) - values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + values.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } } @@ -565,7 +571,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" <split #> <attempt # = spark task #> */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -664,7 +670,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.stageId, context.splitId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() var count = 0 diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 78fe0cdcdb..09d0a8189d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -96,8 +96,9 @@ private[spark] class ParallelCollectionRDD[T: ClassTag]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = - s.asInstanceOf[ParallelCollectionPartition[T]].iterator + override def compute(s: Partition, context: TaskContext) = { + new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) + } override def getPreferredLocations(s: Partition): Seq[String] = { locationPrefs.getOrElse(s.index, Nil) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 731ef90c90..3c237ca20a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -269,6 +269,19 @@ abstract class RDD[T: ClassTag]( def distinct(): RDD[T] = distinct(partitions.size) /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Can increase or decrease the level of parallelism in this RDD. Internally, this uses + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): RDD[T] = { + coalesce(numPartitions, true) + } + + /** * Return a new RDD that is reduced into `numPartitions` partitions. * * This results in a narrow dependency, e.g. if you go from 1000 partitions @@ -421,26 +434,39 @@ abstract class RDD[T: ClassTag]( command: Seq[String], env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, if (printRDDElement ne null) sc.clean(printRDDElement) else null) + } /** * Return a new RDD by applying a function to each partition of this RDD. 
*/ def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = + preservesPartitioning: Boolean = false): RDD[U] = { new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ def mapPartitionsWithIndex[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) + new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + } + + /** + * Return a new RDD by applying a function to each partition of this RDD. This is a variant of + * mapPartitions that also passes the TaskContext into the closure. + */ + def mapPartitionsWithContext[U: ClassTag]( + f: (TaskContext, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index @@ -448,22 +474,23 @@ abstract class RDD[T: ClassTag]( */ @deprecated("use mapPartitionsWithIndex", "0.7.0") def mapPartitionsWithSplit[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + mapPartitionsWithIndex(f, preservesPartitioning) + } /** * Maps f over this RDD, where f takes an additional parameter of type A. This * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def mapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => U): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.map(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def mapWith[A: ClassTag, U: ClassTag] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => U): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.map(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -471,13 +498,14 @@ abstract class RDD[T: ClassTag]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. 
*/ - def flatMapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => Seq[U]): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def flatMapWith[A: ClassTag, U: ClassTag] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => Seq[U]): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.flatMap(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -485,13 +513,12 @@ abstract class RDD[T: ClassTag]( * This additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def foreachWith[A: ClassTag](constructA: Int => A) - (f:(T, A) => Unit) { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.map(t => {f(t, a); t}) - } - (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) + def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.map(t => {f(t, a); t}) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) } /** @@ -499,13 +526,12 @@ abstract class RDD[T: ClassTag]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def filterWith[A: ClassTag](constructA: Int => A) - (p:(T, A) => Boolean): RDD[T] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.filter(t => p(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) + def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.filter(t => p(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) } /** @@ -544,16 +570,14 @@ abstract class RDD[T: ClassTag]( * Applies a function f to all elements of this RDD. */ def foreach(f: T => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f)) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: Iterator[T] => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) + sc.runJob(this, (iter: Iterator[T]) => f(iter)) } /** @@ -678,6 +702,8 @@ abstract class RDD[T: ClassTag]( */ def count(): Long = { sc.runJob(this, (iter: Iterator[T]) => { + // Use a while loop to count the number of elements rather than iter.size because + // iter.size uses a for loop, which is slightly slower in the current version of Scala. 
var result = 0L while (iter.hasNext) { result += 1L diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index b7205865cf..1d109a2496 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -57,7 +57,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[P] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics, + SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, SparkEnv.get.serializerManager.get(serializerClass)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 85c512f3de..aab30b1bb4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -111,7 +111,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( } case ShuffleCoGroupSplitDep(shuffleId) => { val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context.taskMetrics, serializer) + context, serializer) iter.foreach(op) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 03fe0e00f9..ab7b3a2e24 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -29,8 +29,8 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.storage.{BlockManager, BlockManagerMaster} -import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -42,34 +42,40 @@ import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are - * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task + * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * * THREADING: This class runs all its logic in a single thread executing the run() method, to which - * events are submitted using a synchonized queue (eventQueue). The public API methods, such as + * events are submitted using a synchronized queue (eventQueue). The public API methods, such as * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods * should be private. 
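The DAGScheduler comment above describes a single consumer thread draining a synchronized event queue. A self-contained sketch of that pattern (illustrative names only, not the scheduler's actual internals):

    import java.util.concurrent.LinkedBlockingQueue

    object EventLoopSketch {
      sealed trait Event
      case class JobSubmitted(jobId: Int) extends Event
      case object Stop extends Event

      private val eventQueue = new LinkedBlockingQueue[Event]

      // Producers (the public API) post asynchronously ...
      def post(event: Event) { eventQueue.put(event) }

      // ... while a single loop consumes events, so shared state needs no locks.
      def run() {
        while (true) {
          eventQueue.take() match {
            case JobSubmitted(id) => println("handling job " + id)
            case Stop => return
          }
        }
      }
    }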
*/ private[spark] class DAGScheduler( taskSched: TaskScheduler, - mapOutputTracker: MapOutputTracker, + mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv) - extends TaskSchedulerListener with Logging { + extends Logging { def this(taskSched: TaskScheduler) { - this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) + this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + SparkEnv.get.blockManager.master, SparkEnv.get) } - taskSched.setListener(this) + taskSched.setDAGScheduler(this) // Called by TaskScheduler to report task's starting. - override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventQueue.put(BeginEvent(task, taskInfo)) } + // Called to report that a task has completed and results are being fetched remotely. + def taskGettingResult(task: Task[_], taskInfo: TaskInfo) { + eventQueue.put(GettingResultEvent(task, taskInfo)) + } + // Called by TaskScheduler to report task completions or failures. - override def taskEnded( + def taskEnded( task: Task[_], reason: TaskEndReason, result: Any, @@ -80,17 +86,18 @@ class DAGScheduler( } // Called by TaskScheduler when an executor fails. - override def executorLost(execId: String) { + def executorLost(execId: String) { eventQueue.put(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - override def executorGained(execId: String, host: String) { + def executorGained(execId: String, host: String) { eventQueue.put(ExecutorGained(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. - override def taskSetFailed(taskSet: TaskSet, reason: String) { + // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or + // cancellation of the job itself. + def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } @@ -105,13 +112,15 @@ class DAGScheduler( private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] - val nextJobId = new AtomicInteger(0) + private[scheduler] val nextJobId = new AtomicInteger(0) + + def numTotalJobs: Int = nextJobId.get() - val nextStageId = new AtomicInteger(0) + private val nextStageId = new AtomicInteger(0) - val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private val stageIdToStage = new TimeStampedHashMap[Int, Stage] - val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -128,6 +137,7 @@ class DAGScheduler( // stray messages to detect. 
val failedEpoch = new HashMap[String, Long] + // stage id to the active job val idToActiveJob = new HashMap[Int, ActiveJob] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done @@ -139,7 +149,7 @@ class DAGScheduler( val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup) + val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) // Start a thread to run the DAGScheduler event loop def start() { @@ -157,7 +167,7 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId] val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) cacheLocs(rdd.id) = blockIds.map { id => locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) @@ -179,7 +189,7 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId) + val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -192,6 +202,7 @@ class DAGScheduler( */ private def newStage( rdd: RDD[_], + numTasks: Int, shuffleDep: Option[ShuffleDependency[_,_]], jobId: Int, callSite: Option[String] = None) @@ -204,9 +215,10 @@ class DAGScheduler( mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) } val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) + val stage = + new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage - stageToInfos(stage) = StageInfo(stage) + stageToInfos(stage) = new StageInfo(stage) stage } @@ -262,32 +274,41 @@ class DAGScheduler( } /** - * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a - * JobWaiter whose getResult() method will return the result of the job when it is complete. - * - * The job is assumed to have at least one partition; zero partition jobs should be handled - * without a JobSubmitted event. + * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object + * can be used to block until the job finishes executing or can be used to cancel the job. */ - private[scheduler] def prepareJob[T, U: ClassTag]( - finalRdd: RDD[T], + def submitJob[T, U]( + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - : (JobSubmitted, JobWaiter[U]) = + properties: Properties = null): JobWaiter[U] = { + // Check to make sure we are not launching a task on a partition that does not exist. + val maxPartitions = rdd.partitions.length + partitions.find(p => p >= maxPartitions).foreach { p => + throw new IllegalArgumentException( + "Attempting to access a non-existent partition: " + p + ". 
" + + "Total number of partitions: " + maxPartitions) + } + + val jobId = nextJobId.getAndIncrement() + if (partitions.size == 0) { + return new JobWaiter[U](this, jobId, 0, resultHandler) + } + assert(partitions.size > 0) - val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, - properties) - (toSubmit, waiter) + val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, + waiter, properties)) + waiter } def runJob[T, U: ClassTag]( - finalRdd: RDD[T], + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, @@ -295,21 +316,7 @@ class DAGScheduler( resultHandler: (Int, U) => Unit, properties: Properties = null) { - if (partitions.size == 0) { - return - } - - // Check to make sure we are not launching a task on a partition that does not exist. - val maxPartitions = finalRdd.partitions.length - partitions.find(p => p >= maxPartitions).foreach { p => - throw new IllegalArgumentException( - "Attempting to access a non-existent partition: " + p + ". " + - "Total number of partitions: " + maxPartitions) - } - - val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) - eventQueue.put(toSubmit) + val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception, _) => @@ -330,19 +337,40 @@ class DAGScheduler( val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + val jobId = nextJobId.getAndIncrement() + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, + listener, properties)) listener.awaitResult() // Will throw an exception if the job fails } /** + * Cancel a job that is running or waiting in the queue. + */ + def cancelJob(jobId: Int) { + logInfo("Asked to cancel job " + jobId) + eventQueue.put(JobCancelled(jobId)) + } + + def cancelJobGroup(groupId: String) { + logInfo("Asked to cancel job group " + groupId) + eventQueue.put(JobGroupCancelled(groupId)) + } + + /** + * Cancel all jobs that are running or waiting in the queue. + */ + def cancelAllJobs() { + eventQueue.put(AllJobsCancelled) + } + + /** * Process one event retrieved from the event queue. * Returns true if we should stop the event loop. 
*/ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => - val jobId = nextJobId.getAndIncrement() - val finalStage = newStage(finalRDD, None, jobId, Some(callSite)) + case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => + val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + @@ -361,18 +389,43 @@ class DAGScheduler( submitStage(finalStage) } + case JobCancelled(jobId) => + // Cancel a job: find all the running stages that are linked to this job, and cancel them. + running.filter(_.jobId == jobId).foreach { stage => + taskSched.cancelTasks(stage.id) + } + + case JobGroupCancelled(groupId) => + // Cancel all jobs belonging to this job group. + // First find all active jobs with this group id, and then kill stages for them. + val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + .map(_.jobId) + if (!jobIds.isEmpty) { + running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage => + taskSched.cancelTasks(stage.id) + } + } + + case AllJobsCancelled => + // Cancel all running jobs. + running.foreach { stage => + taskSched.cancelTasks(stage.id) + } + case ExecutorGained(execId, host) => handleExecutorGained(execId, host) case ExecutorLost(execId) => handleExecutorLost(execId) - case begin: BeginEvent => - listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo)) + case BeginEvent(task, taskInfo) => + listenerBus.post(SparkListenerTaskStart(task, taskInfo)) + + case GettingResultEvent(task, taskInfo) => + listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo)) - case completion: CompletionEvent => - listenerBus.post(SparkListenerTaskEnd( - completion.task, completion.reason, completion.taskInfo, completion.taskMetrics)) + case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => + listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics)) handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => @@ -542,7 +595,7 @@ class DAGScheduler( // the listener must be run before a possible NotSerializableException // should be "StageSubmitted" first and then "JobEnded" - listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties)) + listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties)) if (tasks.size > 0) { // Preemptively serialize a task to make sure it can be serialized. We are catching this @@ -563,9 +616,7 @@ class DAGScheduler( logDebug("New pending tasks: " + myPending) taskSched.submitTasks( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) - if (!stage.submissionTime.isDefined) { - stage.submissionTime = Some(System.currentTimeMillis()) - } + stageToInfos(stage).submissionTime = Some(System.currentTimeMillis()) } else { logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -579,15 +630,20 @@ class DAGScheduler( */ private def handleTaskCompletion(event: CompletionEvent) { val task = event.task + + if (!stageIdToStage.contains(task.stageId)) { + // Skip all the actions if the stage has been cancelled. 
+ return + } val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage) = { - val serviceTime = stage.submissionTime match { + val serviceTime = stageToInfos(stage).submissionTime match { case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0) - case _ => "Unkown" + case _ => "Unknown" } logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime)) - stage.completionTime = Some(System.currentTimeMillis) + stageToInfos(stage).completionTime = Some(System.currentTimeMillis()) listenerBus.post(StageCompleted(stageToInfos(stage))) running -= stage } @@ -627,7 +683,7 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partition, status) + stage.addOutputLoc(smt.partitionId, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { markStageAsFinished(stage) @@ -753,14 +809,14 @@ class DAGScheduler( /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set - * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ private def abortStage(failedStage: Stage, reason: String) { val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq - failedStage.completionTime = Some(System.currentTimeMillis()) + stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - val error = new SparkException("Job failed: " + reason) + val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) idToActiveJob -= resultStage.jobId diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 10ff1b4376..708d221d60 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics * submitted) but there is a single "logic" thread that reads these events and takes decisions. * This greatly simplifies synchronization. 
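A hypothetical driver-side view of the cancellation entry points handled above (a sketch only; it assumes the job-group API is exposed on SparkContext and wired through to DAGScheduler.cancelJobGroup as part of this change):

    import org.apache.spark.SparkContext

    def runCancellable(sc: SparkContext) {
      sc.setJobGroup("nightly-etl", "cancellable ETL job")
      val runner = new Thread(new Runnable {
        def run() {
          // A deliberately slow job so there is something to cancel.
          sc.parallelize(1 to 1000000, 8).map { x => Thread.sleep(1); x }.count()
        }
      })
      runner.start()
      Thread.sleep(2000)
      sc.cancelJobGroup("nightly-etl")  // cancels the running stages for this group
    }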
*/ -private[spark] sealed trait DAGSchedulerEvent +private[scheduler] sealed trait DAGSchedulerEvent -private[spark] case class JobSubmitted( +private[scheduler] case class JobSubmitted( + jobId: Int, finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], @@ -43,9 +44,19 @@ private[spark] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent -private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent -private[spark] case class CompletionEvent( +private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent + +private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent + +private[scheduler] +case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + +private[scheduler] +case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + +private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, @@ -54,10 +65,12 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] +case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent -private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent -private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] +case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent -private[spark] case object StopDAGScheduler extends DAGSchedulerEvent +private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 151514896f..7b5c0e29ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar }) metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] { - override def getValue: Int = dagScheduler.nextJobId.get() + override def getValue: Int = dagScheduler.numTotalJobs }) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 370ccd183c..1791ee660d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil import scala.collection.immutable.Set import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.security.UserGroupInformation @@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // This method does not expect failures, since validate has already passed ... 
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get val conf = new JobConf(configuration) - env.hadoop.addCredentials(conf) + SparkHadoopUtil.get.addCredentials(conf) FileInputFormat.setInputPaths(conf, path) val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = @@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // This method does not expect failures, since validate has already passed ... private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = { - val env = SparkEnv.get val jobConf = new JobConf(configuration) - env.hadoop.addCredentials(jobConf) + SparkHadoopUtil.get.addCredentials(jobConf) FileInputFormat.setInputPaths(jobConf, path) val instance: org.apache.hadoop.mapred.InputFormat[_, _] = diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3628b1b078..60927831a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -1,292 +1,384 @@ -/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import java.io.PrintWriter
-import java.io.File
-import java.io.FileNotFoundException
-import java.text.SimpleDateFormat
-import java.util.{Date, Properties}
-import java.util.concurrent.LinkedBlockingQueue
-
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
-
-import org.apache.spark._
-import org.apache.spark.rdd.RDD
-import org.apache.spark.executor.TaskMetrics
-
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
-class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
- private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
- private val stageIDToJobID = new HashMap[Int, Int]
- private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
- private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
- createLogDir()
- def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
- // Create a folder for log files, the folder's name is the creation time of the jobLogger
- protected def createLogDir() {
- val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
- }
- }
-
- // Create a log file for one job, the file name is the jobID
- protected def createLogWriter(jobID: Int) {
- try{
- val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
- jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
- }
-
- // Close log file, and clean the stage relationship in stageIDToJobID
- protected def closeLogWriter(jobID: Int) =
- jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
- fileWriter.close()
- jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
- stageIDToJobID -= stage.id
- })
- jobIDToPrintWriter -= jobID
- jobIDToStages -= jobID
- }
-
- // Write log information to log file, withTime parameter controls whether to record
- // time stamp for the information
- protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
- var writeInfo = info
- if (withTime) {
- val date = new Date(System.currentTimeMillis())
- writeInfo = DATE_FORMAT.format(date) + ": " +info
- }
- jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
- }
-
- protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
- stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
-
- protected def buildJobDep(jobID: Int, stage: Stage) {
- if (stage.jobId == jobID) {
- jobIDToStages.get(jobID) match {
- case Some(stageList) => stageList += stage
- case None => val stageList = new ListBuffer[Stage]
- stageList += stage
- jobIDToStages += (jobID -> stageList)
- }
- stageIDToJobID += (stage.id -> jobID)
- stage.parents.foreach(buildJobDep(jobID, _))
- }
- }
-
- protected def recordStageDep(jobID: Int) {
- def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
- var rddList = new ListBuffer[RDD[_]]
- rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
- }
- rddList
- }
- jobIDToStages.get(jobID).foreach {_.foreach { stage =>
- var depRddDesc: String = ""
- getRddsInStage(stage.rdd).foreach { rdd =>
- depRddDesc += rdd.id + ","
- }
- var depStageDesc: String = ""
- stage.parents.foreach { stage =>
- depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
- }
- jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
- depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
- " STAGE_DEP=" + depStageDesc, false)
- }
- }
- }
-
- // Generate indents and convert to String
- protected def indentString(indent: Int) = {
- val sb = new StringBuilder()
- for (i <- 1 to indent) {
- sb.append(" ")
- }
- sb.toString()
- }
-
- protected def getRddName(rdd: RDD[_]) = {
- var rddName = rdd.getClass.getName
- if (rdd.name != null) {
- rddName = rdd.name
- }
- rddName
- }
-
- protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
- val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
- jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
- }
- }
-
- protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
- }
- if (stage.jobId == jobID) {
- jobLogInfo(jobID, indentString(indent) + stageInfo, false)
- recordRddInStageGraph(jobID, stage.rdd, indent)
- stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
- jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
- }
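The rewritten recordStageDepGraph later in this file threads a HashSet of visited stage ids through the recursion, so a parent stage shared by several children is logged only once. The guard pattern in miniature (a sketch, not the patch's code):

    import scala.collection.mutable.HashSet

    def visit(id: Int, parentsOf: Int => Seq[Int], seen: HashSet[Int]) {
      if (!seen.contains(id)) {
        seen += id
        println("stage " + id)
        parentsOf(id).foreach(visit(_, parentsOf, seen))
      }
    }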
-
- // Record task metrics into job log files
- protected def recordTaskMetrics(stageID: Int, status: String,
- taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
- val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
- " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
- " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
- val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
- stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
- }
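For reference, a line assembled by recordTaskMetrics above looks roughly like the following (a hypothetical example; the values are invented, jobLogInfo prepends the timestamp, and the line is wrapped here for readability):

    2013/10/20 14:02:11: TASK_TYPE=RESULT_TASK STATUS=SUCCESS TID=42 STAGE_ID=3
      START_TIME=1382277731000 FINISH_TIME=1382277731430 EXECUTOR_ID=0 HOST=worker-1
      EXECUTOR_RUN_TIME=430 SHUFFLE_BYTES_WRITTEN=1048576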
-
- override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
- }
-
- override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
- "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
- }
-
- override def onTaskStart(taskStart: SparkListenerTaskStart) { }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- val task = taskEnd.task
- val taskInfo = taskEnd.taskInfo
- var taskStatus = ""
- task match {
- case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
- case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
- }
- taskEnd.reason match {
- case Success => taskStatus += " STATUS=SUCCESS"
- recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
- case Resubmitted =>
- taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId
- stageLogInfo(task.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
- taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
- task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
- mapId + " REDUCE_ID=" + reduceId
- stageLogInfo(task.stageId, taskStatus)
- case OtherFailure(message) =>
- taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
- " STAGE_ID=" + task.stageId + " INFO=" + message
- stageLogInfo(task.stageId, taskStatus)
- case _ =>
- }
- }
-
- override def onJobEnd(jobEnd: SparkListenerJobEnd) {
- val job = jobEnd.job
- var info = "JOB_ID=" + job.jobId
- jobEnd.jobResult match {
- case JobSucceeded => info += " STATUS=SUCCESS"
- case JobFailed(exception, _) =>
- info += " STATUS=FAILED REASON="
- exception.getMessage.split("\\s+").foreach(info += _ + "_")
- case _ =>
- }
- jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
- closeLogWriter(job.jobId)
- }
-
- protected def recordJobProperties(jobID: Int, properties: Properties) {
- if(properties != null) {
- val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
- jobLogInfo(jobID, description, false)
- }
- }
-
- override def onJobStart(jobStart: SparkListenerJobStart) {
- val job = jobStart.job
- val properties = jobStart.properties
- createLogWriter(job.jobId)
- recordJobProperties(job.jobId, properties)
- buildJobDep(job.jobId, job.finalStage)
- recordStageDep(job.jobId)
- recordStageDepGraph(job.jobId, job.finalStage)
- jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
- }
-}
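The replacement below is registered like any other listener; per its new class doc, a minimal hookup sketch (assumes a live SparkContext sc):

    // One log file per job is written under SPARK_LOG_DIR (or /tmp/spark-<user>).
    sc.addSparkListener(new org.apache.spark.scheduler.JobLogger())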
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{IOException, File, FileNotFoundException, PrintWriter} +import java.text.SimpleDateFormat +import java.util.{Date, Properties} +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.StorageLevel + +/** + * A logger class to record runtime information for jobs in Spark. This class outputs one log file + * for each Spark job, containing the RDD graph, task start/stop events, and shuffle information. + * JobLogger is a subclass of SparkListener; use addSparkListener to add a JobLogger to a SparkContext + * after the SparkContext is created. + * Note that each JobLogger only works for one SparkContext. + * @param logDirName The base directory for the log files. + */ + +class JobLogger(val user: String, val logDirName: String) + extends SparkListener with Logging { + + def this() = this(System.getProperty("user.name", "<unknown>"), + String.valueOf(System.currentTimeMillis())) + + private val logDir = + if (System.getenv("SPARK_LOG_DIR") != null) + System.getenv("SPARK_LOG_DIR") + else + "/tmp/spark-%s".format(user) + + private val jobIDToPrintWriter = new HashMap[Int, PrintWriter] + private val stageIDToJobID = new HashMap[Int, Int] + private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]] + private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents] + + createLogDir() + + // The following 5 functions are used only in testing. + private[scheduler] def getLogDir = logDir + private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter + private[scheduler] def getStageIDToJobID = stageIDToJobID + private[scheduler] def getJobIDToStages = jobIDToStages + private[scheduler] def getEventQueue = eventQueue + + /** Create a folder for log files; the folder's name is the creation time of the JobLogger. */ + protected def createLogDir() { + val dir = new File(logDir + "/" + logDirName + "/") + if (dir.exists()) { + return + } + if (dir.mkdirs() == false) { + // JobLogger should throw an exception rather than continue to construct this object. 
+ throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/") + } + } + + /** + * Create a log file for one job + * @param jobID ID of the job + * @exception FileNotFoundException Fail to create log file + */ + protected def createLogWriter(jobID: Int) { + try { + val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID) + jobIDToPrintWriter += (jobID -> fileWriter) + } catch { + case e: FileNotFoundException => e.printStackTrace() + } + } + + /** + * Close log file, and clean the stage relationship in stageIDToJobID + * @param jobID ID of the job + */ + protected def closeLogWriter(jobID: Int) { + jobIDToPrintWriter.get(jobID).foreach { fileWriter => + fileWriter.close() + jobIDToStages.get(jobID).foreach(_.foreach{ stage => + stageIDToJobID -= stage.id + }) + jobIDToPrintWriter -= jobID + jobIDToStages -= jobID + } + } + + /** + * Write info into log file + * @param jobID ID of the job + * @param info Info to be recorded + * @param withTime Controls whether to record time stamp before the info, default is true + */ + protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) { + var writeInfo = info + if (withTime) { + val date = new Date(System.currentTimeMillis()) + writeInfo = DATE_FORMAT.format(date) + ": " +info + } + jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo)) + } + + /** + * Write info into log file + * @param stageID ID of the stage + * @param info Info to be recorded + * @param withTime Controls whether to record time stamp before the info, default is true + */ + protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) { + stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime)) + } + + /** + * Build stage dependency for a job + * @param jobID ID of the job + * @param stage Root stage of the job + */ + protected def buildJobDep(jobID: Int, stage: Stage) { + if (stage.jobId == jobID) { + jobIDToStages.get(jobID) match { + case Some(stageList) => stageList += stage + case None => val stageList = new ListBuffer[Stage] + stageList += stage + jobIDToStages += (jobID -> stageList) + } + stageIDToJobID += (stage.id -> jobID) + stage.parents.foreach(buildJobDep(jobID, _)) + } + } + + /** + * Record stage dependency and RDD dependency for a stage + * @param jobID Job ID of the stage + */ + protected def recordStageDep(jobID: Int) { + def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = { + var rddList = new ListBuffer[RDD[_]] + rddList += rdd + rdd.dependencies.foreach { + case shufDep: ShuffleDependency[_, _] => + case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd) + } + rddList + } + jobIDToStages.get(jobID).foreach {_.foreach { stage => + var depRddDesc: String = "" + getRddsInStage(stage.rdd).foreach { rdd => + depRddDesc += rdd.id + "," + } + var depStageDesc: String = "" + stage.parents.foreach { stage => + depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")" + } + jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" + + depRddDesc.substring(0, depRddDesc.length - 1) + ")" + + " STAGE_DEP=" + depStageDesc, false) + } + } + } + + /** + * Generate indents and convert to String + * @param indent Number of indents + * @return string of indents + */ + protected def indentString(indent: Int): String = { + val sb = new StringBuilder() + for (i <- 1 to indent) { + sb.append(" ") + } + sb.toString() + } + + /** + * Get RDD's name + * @param rdd Input RDD + * @return String of RDD's name + */ + protected def getRddName(rdd: RDD[_]): 
String = { + var rddName = rdd.getClass.getSimpleName + if (rdd.name != null) { + rddName = rdd.name + } + rddName + } + + /** + * Record RDD dependency graph in a stage + * @param jobID Job ID of the stage + * @param rdd Root RDD of the stage + * @param indent Indent number before info + */ + protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) { + val rddInfo = + if (rdd.getStorageLevel != StorageLevel.NONE) { + "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " + + rdd.origin + " " + rdd.generator + } else { + "RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " + + rdd.origin + " " + rdd.generator + } + jobLogInfo(jobID, indentString(indent) + rddInfo, false) + rdd.dependencies.foreach { + case shufDep: ShuffleDependency[_, _] => + val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId + jobLogInfo(jobID, indentString(indent + 1) + depInfo, false) + case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1) + } + } + + /** + * Record stage dependency graph of a job + * @param jobID Job ID of the stage + * @param stage Root stage of the job + * @param indent Indent number before info, default is 0 + */ + protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) { + val stageInfo = if (stage.isShuffleMap) { + "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId + } else { + "STAGE_ID=" + stage.id + " RESULT_STAGE" + } + if (stage.jobId == jobID) { + jobLogInfo(jobID, indentString(indent) + stageInfo, false) + if (!idSet.contains(stage.id)) { + idSet += stage.id + recordRddInStageGraph(jobID, stage.rdd, indent) + stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2)) + } + } else { + jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false) + } + } + + /** + * Record task metrics into job log files, including execution info and shuffle metrics + * @param stageID Stage ID of the task + * @param status Status info of the task + * @param taskInfo Task description info + * @param taskMetrics Task running metrics + */ + protected def recordTaskMetrics(stageID: Int, status: String, + taskInfo: TaskInfo, taskMetrics: TaskMetrics) { + val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID + + " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + + " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname + val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime + val readMetrics = taskMetrics.shuffleReadMetrics match { + case Some(metrics) => + " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime + + " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + + " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + + " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + + " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + + " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime + + " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + case None => "" + } + val writeMetrics = taskMetrics.shuffleWriteMetrics match { + case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + case None => "" + } + stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics) + } + + /** + * When stage is submitted, record stage submit info + * @param stageSubmitted Stage submitted event + */ + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { + stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED 
TASK_SIZE=%d".format( + stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks)) + } + + /** + * When stage is completed, record stage completion status + * @param stageCompleted Stage completed event + */ + override def onStageCompleted(stageCompleted: StageCompleted) { + stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format( + stageCompleted.stage.stageId)) + } + + override def onTaskStart(taskStart: SparkListenerTaskStart) { } + + /** + * When task ends, record task completion status and metrics + * @param taskEnd Task end event + */ + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val task = taskEnd.task + val taskInfo = taskEnd.taskInfo + var taskStatus = "" + task match { + case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK" + case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK" + } + taskEnd.reason match { + case Success => taskStatus += " STATUS=SUCCESS" + recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics) + case Resubmitted => + taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + stageLogInfo(task.stageId, taskStatus) + case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + + task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + + mapId + " REDUCE_ID=" + reduceId + stageLogInfo(task.stageId, taskStatus) + case OtherFailure(message) => + taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId + + " STAGE_ID=" + task.stageId + " INFO=" + message + stageLogInfo(task.stageId, taskStatus) + case _ => + } + } + + /** + * When job ends, record job completion status and close log file + * @param jobEnd Job end event + */ + override def onJobEnd(jobEnd: SparkListenerJobEnd) { + val job = jobEnd.job + var info = "JOB_ID=" + job.jobId + jobEnd.jobResult match { + case JobSucceeded => info += " STATUS=SUCCESS" + case JobFailed(exception, _) => + info += " STATUS=FAILED REASON=" + exception.getMessage.split("\\s+").foreach(info += _ + "_") + case _ => + } + jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase) + closeLogWriter(job.jobId) + } + + /** + * Record job properties into job log file + * @param jobID ID of the job + * @param properties Properties of the job + */ + protected def recordJobProperties(jobID: Int, properties: Properties) { + if(properties != null) { + val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") + jobLogInfo(jobID, description, false) + } + } + + /** + * When job starts, record job properties and stage graph + * @param jobStart Job start event + */ + override def onJobStart(jobStart: SparkListenerJobStart) { + val job = jobStart.job + val properties = jobStart.properties + createLogWriter(job.jobId) + recordJobProperties(job.jobId, properties) + buildJobDep(job.jobId, job.finalStage) + recordStageDep(job.jobId) + recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int]) + jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED") + } +} + diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 200d881799..58f238d8cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,48 +17,58 @@ package org.apache.spark.scheduler -import scala.collection.mutable.ArrayBuffer - /** * An object that 
waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. */ -private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) +private[spark] class JobWaiter[T]( + dagScheduler: DAGScheduler, + jobId: Int, + totalTasks: Int, + resultHandler: (Int, T) => Unit) extends JobListener { private var finishedTasks = 0 - private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? - private var jobResult: JobResult = null // If the job is finished, this will be its result + // Is the job as a whole finished (succeeded or failed)? + private var _jobFinished = totalTasks == 0 - override def taskSucceeded(index: Int, result: Any) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") - } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - jobFinished = true - jobResult = JobSucceeded - this.notifyAll() - } - } + def jobFinished = _jobFinished + + // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero + // partition RDDs), we set the jobResult directly to JobSucceeded. + private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + + /** + * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled + * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it + * will fail this job with a SparkException. + */ + def cancel() { + dagScheduler.cancelJob(jobId) } - override def jobFailed(exception: Exception) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") - } - jobFinished = true - jobResult = JobFailed(exception, None) + override def taskSucceeded(index: Int, result: Any): Unit = synchronized { + if (_jobFinished) { + throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + } + resultHandler(index, result.asInstanceOf[T]) + finishedTasks += 1 + if (finishedTasks == totalTasks) { + _jobFinished = true + jobResult = JobSucceeded this.notifyAll() } } + override def jobFailed(exception: Exception): Unit = synchronized { + _jobFinished = true + jobResult = JobFailed(exception, None) + this.notifyAll() + } + def awaitResult(): JobResult = synchronized { - while (!jobFinished) { + while (!_jobFinished) { this.wait() } return jobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 9eb8d48501..596f9adde9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -43,7 +43,10 @@ private[spark] class Pool( var runningTasks = 0 var priority = 0 - var stageId = 0 + + // A pool's stage id is used to break the tie in scheduling. 
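The JobWaiter above is a small synchronized latch. Its core wait/notify structure, reduced to a self-contained sketch (including the new zero-task fast path):

    class SimpleWaiter(totalTasks: Int) {
      private var finished = 0
      private var done = totalTasks == 0  // zero-task jobs are finished on arrival

      def taskSucceeded(): Unit = synchronized {
        finished += 1
        if (finished == totalTasks) { done = true; notifyAll() }
      }

      def awaitResult(): Unit = synchronized {
        while (!done) wait()
      }
    }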
+ var stageId = -1 + var name = poolName var parent: Pool = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 07e8317e3a..310ec62ca8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -23,7 +23,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDDCheckpointData -import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap} private[spark] object ResultTask { @@ -32,23 +32,23 @@ private[spark] object ResultTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues) + val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { - return old + old } else { val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(func) objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -56,11 +56,11 @@ private[spark] object ResultTask { def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - return (rdd, func) + (rdd, func) } def clearCache() { @@ -71,29 +71,37 @@ private[spark] object ResultTask { } +/** + * A task that sends back the output to the driver application. + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd input to func + * @param func a function to apply on a partition of the RDD + * @param _partitionId index of the partition in the RDD + * @param locs preferred task execution locations for locality scheduling + * @param outputId index of the task in this job (a job can launch tasks on only a subset of the + * input RDD's partitions). 
+ */ private[spark] class ResultTask[T, U]( stageId: Int, var rdd: RDD[T], var func: (TaskContext, Iterator[T]) => U, - var partition: Int, + _partitionId: Int, @transient locs: Seq[TaskLocation], var outputId: Int) - extends Task[U](stageId) with Externalizable { + extends Task[U](stageId, _partitionId) with Externalizable { def this() = this(0, null, null, 0, null, 0) - var split = if (rdd == null) { - null - } else { - rdd.partitions(partition) - } + var split = if (rdd == null) null else rdd.partitions(partitionId) @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - override def run(attemptId: Long): U = { - val context = new TaskContext(stageId, partition, attemptId, runningLocally = false) + override def runTask(context: TaskContext): U = { metrics = Some(context.taskMetrics) try { func(context, rdd.iterator(split, context)) @@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partition + ")" + override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ResultTask.serializeInfo( stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeInt(outputId) out.writeLong(epoch) out.writeObject(split) @@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U]( val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) rdd = rdd_.asInstanceOf[RDD[T]] func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partition = in.readInt() + partitionId = in.readInt() outputId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 4e25086ec9..356fe56bf3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -30,7 +30,10 @@ import scala.xml.XML * addTaskSetManager: build the leaf nodes(TaskSetManagers) */ private[spark] trait SchedulableBuilder { + def rootPool: Pool + def buildPools() + def addTaskSetManager(manager: Schedulable, properties: Properties) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index d23df0dd2b..1dc71a0428 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage._ -import org.apache.spark.util.{TimeStampedHashMap, MetadataCleaner} +import org.apache.spark.util.{MetadataCleanerType, TimeStampedHashMap, MetadataCleaner} import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDDCheckpointData @@ -37,7 +37,7 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. 
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner("ShuffleMapTask", serializedInfoCache.clearOldValues) + val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { @@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask { objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask { val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - return (rdd, dep) + (rdd, dep) } } @@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val objIn = new ObjectInputStream(in) val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - return (HashMap(set.toSeq: _*)) + HashMap(set.toSeq: _*) } def clearCache() { @@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask { } } +/** + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd the final RDD in this stage + * @param dep the ShuffleDependency + * @param _partitionId index of the partition in the RDD + * @param locs preferred task execution locations for locality scheduling + */ private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], var dep: ShuffleDependency[_,_], - var partition: Int, + _partitionId: Int, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId) + extends Task[MapStatus](stageId, _partitionId) with Externalizable with Logging { @@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask( if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partition) + var split = if (rdd == null) null else rdd.partitions(partitionId) override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeLong(epoch) out.writeObject(split) } @@ -124,68 +136,70 @@ private[spark] class ShuffleMapTask( val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) rdd = rdd_ dep = dep_ - partition = in.readInt() + partitionId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] } - override def run(attemptId: Long): MapStatus = { + override def runTask(context: TaskContext): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - - val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false) - metrics = Some(taskContext.taskMetrics) + metrics = Some(context.taskMetrics) val blockManager = SparkEnv.get.blockManager - var shuffle: ShuffleBlocks = null - var buckets: ShuffleWriterGroup = null + val shuffleBlockManager = blockManager.shuffleBlockManager + var shuffle: ShuffleWriterGroup = null + var success = false try { // Obtain all the block writers for shuffle blocks. 
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) - shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partition) + shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) // Write the map output to its associated buckets. - for (elem <- rdd.iterator(split, taskContext)) { + for (elem <- rdd.iterator(split, context)) { val pair = elem.asInstanceOf[Product2[Any, Any]] val bucketId = dep.partitioner.getPartition(pair._1) - buckets.writers(bucketId).write(pair) + shuffle.writers(bucketId).write(pair) } // Commit the writes. Get the size of each bucket block (total block size). var totalBytes = 0L - val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => + var totalTime = 0L + val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => writer.commit() - writer.close() - val size = writer.size() + val size = writer.fileSegment().length totalBytes += size + totalTime += writer.timeWriting() MapOutputTracker.compressSize(size) } // Update shuffle metrics. val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - return new MapStatus(blockManager.blockManagerId, compressedSizes) + success = true + new MapStatus(blockManager.blockManagerId, compressedSizes) } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. - if (buckets != null) { - buckets.writers.foreach(_.revertPartialWrites()) + if (shuffle != null) { + shuffle.writers.foreach(_.revertPartialWrites()) } throw e } finally { // Release the writers back to the shuffle block manager. - if (shuffle != null && buckets != null) { - shuffle.releaseWriters(buckets) + if (shuffle != null && shuffle.writers != null) { + shuffle.writers.foreach(_.close()) + shuffle.releaseWriters(success) } // Execute the callbacks on task completion. 
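The write loop in runTask above is the heart of the shuffle: each record is routed to one writer per reduce partition, chosen by the stage's partitioner. A condensed, self-contained sketch of that routing (the BucketWriter trait is illustrative; the real code uses BlockObjectWriter):

```scala
// Illustrative only: route each record to a per-bucket writer, mirroring the
// loop in ShuffleMapTask.runTask. One writer exists per reduce partition.
trait BucketWriter { def write(pair: Product2[Any, Any]): Unit }

def writePartitioned(
    records: Iterator[Product2[Any, Any]],
    getPartition: Any => Int,            // stands in for dep.partitioner.getPartition
    writers: Array[BucketWriter]): Unit = {
  for (pair <- records) {
    val bucketId = getPartition(pair._1) // 0 <= bucketId < numOutputSplits
    writers(bucketId).write(pair)
  }
}
```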
- taskContext.executeOnCompleteCallbacks() + context.executeOnCompleteCallbacks() } } override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) + override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 62b521ad45..a35081f7b1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -24,13 +24,16 @@ import org.apache.spark.executor.TaskMetrics sealed trait SparkListenerEvents -case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties) +case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties) extends SparkListenerEvents -case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents +case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents +case class SparkListenerTaskGettingResult( + task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents + case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents @@ -54,7 +57,13 @@ trait SparkListener { /** * Called when a task starts */ - def onTaskStart(taskEnd: SparkListenerTaskStart) { } + def onTaskStart(taskStart: SparkListenerTaskStart) { } + + /** + * Called when a task begins remotely fetching its result (will not be called for tasks that do + * not need to fetch the result remotely). 
+ */ + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } /** * Called when a task ends @@ -80,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: StageCompleted) { import org.apache.spark.scheduler.StatsReportListener._ implicit val sc = stageCompleted - this.logInfo("Finished stage: " + stageCompleted.stageInfo) + this.logInfo("Finished stage: " + stageCompleted.stage) showMillisDistribution("task runtime:", (info, _) => Some(info.duration)) //shuffle write @@ -93,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging { //runtime breakdown - val runtimePcts = stageCompleted.stageInfo.taskInfos.map{ + val runtimePcts = stageCompleted.stage.taskInfos.map{ case (info, metrics) => RuntimePercentage(info.duration, metrics) } showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%") @@ -111,7 +120,7 @@ object StatsReportListener extends Logging { val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = { - Distribution(stage.stageInfo.taskInfos.flatMap{ + Distribution(stage.stage.taskInfos.flatMap { case ((info,metric)) => getMetric(info, metric)}) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 4d3e4a17ba..d5824e7954 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging { sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
+ case taskGettingResult: SparkListenerTaskGettingResult =>
+ sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>
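Downstream of the bus, any SparkListener can now observe the getting-result phase. A hedged sketch of a listener that reports how long remote result fetches take, assuming the gettingResultTime and finishTime fields that the TaskInfo diff later in this change introduces:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd,
  SparkListenerTaskGettingResult}

// Sketch: measure the result-fetch phase using the new event plus the
// gettingResultTime/finishTime fields this change adds to TaskInfo.
class ResultFetchTimingListener extends SparkListener {
  override def onTaskGettingResult(e: SparkListenerTaskGettingResult) {
    println("Task " + e.taskInfo.taskId + " started fetching its result")
  }

  override def onTaskEnd(e: SparkListenerTaskEnd) {
    if (e.taskInfo.gettingResultTime != 0) {
      val fetchMs = e.taskInfo.finishTime - e.taskInfo.gettingResultTime
      println("Task " + e.taskInfo.taskId + " fetch took " + fetchMs + " ms")
    }
  }
}
```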
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index aa293dc6b3..7cb3fe46e5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId private[spark] class Stage( val id: Int, val rdd: RDD[_], + val numTasks: Int, val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage val parents: List[Stage], val jobId: Int, @@ -49,11 +50,6 @@ private[spark] class Stage( val numPartitions = rdd.partitions.size val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) var numAvailableOutputs = 0 - - /** When first task was submitted to scheduler. */ - var submissionTime: Option[Long] = None - var completionTime: Option[Long] = None - private var nextAttemptId = 0 def isAvailable: Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index b6f11969e5..93599dfdc8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -21,9 +21,16 @@ import scala.collection._ import org.apache.spark.executor.TaskMetrics -case class StageInfo( - val stage: Stage, +class StageInfo( + stage: Stage, val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]() ) { - override def toString = stage.rdd.toString + val stageId = stage.id + /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ + var submissionTime: Option[Long] = None + var completionTime: Option[Long] = None + val rddName = stage.rdd.name + val name = stage.name + val numPartitions = stage.numPartitions + val numTasks = stage.numTasks } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 598d91752a..69b42e86ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,25 +17,74 @@ package org.apache.spark.scheduler -import org.apache.spark.serializer.SerializerInstance import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import org.apache.spark.util.ByteBufferInputStream + import scala.collection.mutable.HashMap + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + +import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.ByteBufferInputStream + /** - * A task to execute on a worker node. + * A unit of execution. We have two kinds of Tasks in Spark: + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] + * + * A Spark job consists of one or more stages. The very last stage in a job consists of multiple + * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task + * and sends the task output back to the driver application. A ShuffleMapTask executes the task + * and divides the task output into multiple buckets (based on the task's partitioner).
+ * + * @param stageId id of the stage this task belongs to + * @param partitionId index of the partition in the RDD */ -private[spark] abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Long): T +private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { + + final def run(attemptId: Long): T = { + context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + if (_killed) { + kill() + } + runTask(context) + } + + def runTask(context: TaskContext): T + def preferredLocations: Seq[TaskLocation] = Nil - var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskScheduler. + var epoch: Long = -1 var metrics: Option[TaskMetrics] = None + // Task context, to be initialized in run(). + @transient protected var context: TaskContext = _ + + // A flag to indicate whether the task is killed. This is used in case context is not yet + // initialized when kill() is invoked. + @volatile @transient private var _killed = false + + /** + * Whether the task has been killed. + */ + def killed: Boolean = _killed + + /** + * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark + * code and user code to properly handle the flag. This function should be idempotent so it can + * be called multiple times. + */ + def kill() { + _killed = true + if (context != null) { + context.interrupted = true + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 7c2a422aff..4bae26f3a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -31,9 +31,25 @@ class TaskInfo( val host: String, val taskLocality: TaskLocality.TaskLocality) { + /** + * The time when the task started remotely getting the result. Will not be set if the + * task result was sent immediately when the task finished (as opposed to sending an + * IndirectTaskResult and later fetching the result from the block manager). + */ + var gettingResultTime: Long = 0 + + /** + * The time when the task completed successfully (including the time to remotely fetch + * results, if necessary). + */ var finishTime: Long = 0 + var failed = false + def markGettingResult(time: Long = System.currentTimeMillis) { + gettingResultTime = time + } + def markSuccessful(time: Long = System.currentTimeMillis) { finishTime = time } @@ -43,6 +59,8 @@ class TaskInfo( failed = true } + def gettingResult: Boolean = gettingResultTime != 0 + def finished: Boolean = finishTime != 0 def successful: Boolean = finished && !failed @@ -52,6 +70,8 @@ class TaskInfo( def status: String = { if (running) "RUNNING" + else if (gettingResult) + "GET RESULT" else if (failed) "FAILED" else if (successful) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index db3954a9d3..7e468d0d67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -24,13 +24,14 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.{SparkEnv} import java.nio.ByteBuffer import org.apache.spark.util.Utils +import org.apache.spark.storage.BlockId // Task result. Also contains updates to accumulator variables.
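Results reach the driver in one of the two shapes defined below: small results are serialized directly into the status update, while large ones are stashed in the block manager and only a (now strongly typed) block reference travels over the wire. A sketch of the driver-side resolution, where fetchRemoteBytes and deserialize are hypothetical glue standing in for blockManager.getRemoteBytes and the TaskResultGetter's serializer:

```scala
import java.nio.ByteBuffer
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.storage.BlockId

// Sketch of how the driver unpacks the two TaskResult shapes; not the real
// TaskResultGetter, just the shape of the decision it makes.
def resolveResult[T](
    result: TaskResult[T],
    fetchRemoteBytes: BlockId => Option[ByteBuffer],
    deserialize: ByteBuffer => DirectTaskResult[T]): Option[DirectTaskResult[T]] = {
  result match {
    case direct: DirectTaskResult[T] => Some(direct)                 // sent inline
    case IndirectTaskResult(blockId) => fetchRemoteBytes(blockId).map(deserialize)
  }
}
```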
private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ private[spark] -case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable +case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 7c2a9f03d7..10e0478108 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * Each TaskScheduler schedules tasks for a single SparkContext. * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler through - * the TaskSchedulerListener interface. + * are failures, and mitigating stragglers. They return events to the DAGScheduler. */ private[spark] trait TaskScheduler { @@ -45,8 +44,11 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit - // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. - def setListener(listener: TaskSchedulerListener): Unit + // Cancel a stage. + def cancelTasks(stageId: Int) + + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. + def setDAGScheduler(dagScheduler: DAGScheduler): Unit // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala deleted file mode 100644 index 593fa9fb93..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import scala.collection.mutable.Map - -import org.apache.spark.TaskEndReason -import org.apache.spark.executor.TaskMetrics - -/** - * Interface for getting events back from the TaskScheduler. - */ -private[spark] trait TaskSchedulerListener { - // A task has started. - def taskStarted(task: Task[_], taskInfo: TaskInfo) - - // A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit - - // A node was added to the cluster. - def executorGained(execId: String, host: String): Unit - - // A node was lost from the cluster. - def executorLost(execId: String): Unit - - // The TaskScheduler wants to abort an entire task set. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156..03bf760837 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,5 +31,9 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." + attempt + def kill() { + tasks.foreach(_.kill()) + } + override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 1a844b7e7e..85033958ef 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler.cluster -import java.lang.{Boolean => JBoolean} import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong import java.util.{TimerTask, Timer} @@ -79,14 +78,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) private val executorIdToHost = new HashMap[String, String] - // JAR server, if any JARs were added by the user to the SparkContext - var jarServer: HttpServer = null - - // URIs of JARs to pass to executor - var jarUris: String = "" - // Listener object to pass upcalls into - var listener: TaskSchedulerListener = null + var dagScheduler: DAGScheduler = null var backend: SchedulerBackend = null @@ -101,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } def initialize(context: SchedulerBackend) { @@ -171,8 +164,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } - def taskSetFinished(manager: TaskSetManager) { - this.synchronized { + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the tasks and then abort + // the stage. + // 2. The task set manager has been created but no tasks have been scheduled. In this case, + // simply abort the stage.
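Case 1 above relies on a kill path that this change threads through the whole stack, from TaskSet.kill() down to the cooperative flag on the running task; the code that follows implements the scheduler end of it. A condensed, hypothetical wiring of the chain (real signatures appear in the Task, TaskSet, and SchedulerBackend diffs in this section):

```scala
import scala.collection.mutable

// Hypothetical stubs showing the kill chain: scheduler -> backend -> executor
// -> Task.kill(), which flips the cooperative `interrupted` flag on the context.
trait KillableTask { def kill(): Unit }

class ExecutorStub(running: mutable.Map[Long, KillableTask]) {
  def killTask(taskId: Long): Unit = running.get(taskId).foreach(_.kill())
}

class BackendStub(executors: Map[String, ExecutorStub]) {
  // Mirrors the new SchedulerBackend.killTask(taskId, executorId).
  def killTask(taskId: Long, executorId: String): Unit =
    executors.get(executorId).foreach(_.killTask(taskId))
}
```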
+ val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } + + def taskSetFinished(manager: TaskSetManager): Unit = synchronized { + // Check to see if the given task set has been removed. This is possible in the case of + // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has + // more than one running task). + if (activeTaskSets.contains(manager.taskSet.id)) { activeTaskSets -= manager.taskSet.id manager.parent.removeSchedulable(manager) logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) @@ -281,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Update the DAGScheduler without holding a lock on this, since that can deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -290,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } } + def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) { + taskSetManager.handleTaskGettingResult(tid) + } + def handleSuccessfulTask( taskSetManager: ClusterTaskSetManager, tid: Long, @@ -334,9 +354,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (backend != null) { backend.stop() } - if (jarServer != null) { - jarServer.stop() - } if (taskResultGetter != null) { taskResultGetter.stop() } @@ -384,9 +401,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) logError("Lost an executor " + executorId + " (already removed): " + reason) } } - // Call listener.executorLost without holding the lock on this to prevent deadlock + // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } } @@ -405,7 +422,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } def executorGained(execId: String, host: String) { - listener.executorGained(execId, host) + dagScheduler.executorGained(execId, host) } def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 194ab55102..ee47aaffca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -17,18 +17,16 @@ package org.apache.spark.scheduler.cluster -import java.nio.ByteBuffer -import java.util.{Arrays, NoSuchElementException} +import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import scala.Some import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - SparkException, Success, TaskEndReason, TaskResultLost, TaskState} + Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler._ import org.apache.spark.util.{SystemClock, Clock} @@ -417,11 +415,17 @@ private[spark] class
ClusterTaskSetManager( } private def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) + } + + def handleTaskGettingResult(tid: Long) = { + val info = taskInfos(tid) + info.markGettingResult() + sched.dagScheduler.taskGettingResult(tasks(info.index), info) } /** - * Marks the task as successful and notifies the listener that a task has ended. + * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { val info = taskInfos(tid) @@ -431,7 +435,7 @@ private[spark] class ClusterTaskSetManager( if (!successful(index)) { logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.listener.taskEnded( + sched.dagScheduler.taskEnded( tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) // Mark successful and stop if all the tasks have succeeded. @@ -447,7 +451,8 @@ private[spark] class ClusterTaskSetManager( } /** - * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener. + * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the + * DAG Scheduler. */ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { val info = taskInfos(tid) @@ -458,54 +463,57 @@ private[spark] class ClusterTaskSetManager( val index = info.index info.markFailed() if (!successful(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) copiesRunning(index) -= 1 // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. 
reason.foreach { - _ match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - successful(index) = true - tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case ef: ExceptionFailure => - sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) + successful(index) = true + tasksSuccessful += 1 + sched.taskSetFinished(this) + removeAllRunningTasks() + return + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) + return + + case ef: ExceptionFailure => + sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + val key = ef.description + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { recentExceptions(key) = (0, now) (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + recentExceptions(key) = (0, now) + (true, 0) } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } - case TaskResultLost => - logInfo("Lost result for TID %s on host %s".format(tid, info.host)) - sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + case TaskResultLost => + logWarning("Lost result for TID %s on host %s".format(tid, info.host)) + sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) - case _ => {} - } + case _ => {} } // On non-fetch failures, re-enqueue the task as pending for a max number of retries addPendingTask(index) @@ -532,7 +540,7 @@ private[spark] class ClusterTaskSetManager( failed = true causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message) removeAllRunningTasks() sched.taskSetFinished(this) } @@ -605,7 +613,7 @@ private[spark] class ClusterTaskSetManager( addPendingTask(index) // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. 
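Stepping back from the resubmission bookkeeping that continues below: the recentExceptions map in the ExceptionFailure branch above throttles log noise, printing a full stack trace for a given exception description at most once per EXCEPTION_PRINT_INTERVAL and otherwise only bumping a duplicate counter. The same logic, pulled out as a standalone sketch:

```scala
import scala.collection.mutable

// Standalone sketch of the recentExceptions throttle used above: print the
// full trace at most once per interval, otherwise just count duplicates.
class ExceptionThrottle(printIntervalMs: Long,
    clock: () => Long = () => System.currentTimeMillis) {
  // description -> (duplicates since last full print, time of last full print)
  private val recent = mutable.Map.empty[String, (Int, Long)]

  /** Returns (printFullTrace, duplicateCount). */
  def record(description: String): (Boolean, Int) = {
    val now = clock()
    recent.get(description) match {
      case Some((dupCount, printTime)) if now - printTime <= printIntervalMs =>
        recent(description) = (dupCount + 1, printTime)
        (false, dupCount + 1)
      case _ =>
        recent(description) = (0, now)
        (true, 0)
    }
  }
}
```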
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) } } } @@ -630,11 +638,11 @@ private[spark] class ClusterTaskSetManager( var foundTasks = false val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksSuccessful >= minFinishedForSpeculation) { + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { val time = clock.getTime() val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) + val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1)) val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index c0b836bf1a..53316dae2a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -24,26 +24,28 @@ import org.apache.spark.scheduler.TaskDescription import org.apache.spark.util.{Utils, SerializableBuffer} -private[spark] sealed trait StandaloneClusterMessage extends Serializable +private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable -private[spark] object StandaloneClusterMessages { +private[spark] object CoarseGrainedClusterMessages { // Driver to executors - case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage + case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage + + case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) - extends StandaloneClusterMessage + extends CoarseGrainedClusterMessage - case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage + case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) - extends StandaloneClusterMessage { + extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends StandaloneClusterMessage + data: SerializableBuffer) extends CoarseGrainedClusterMessage object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ @@ -54,10 +56,14 @@ private[spark] object StandaloneClusterMessages { } // Internal messages in driver - case object ReviveOffers extends StandaloneClusterMessage + case object ReviveOffers extends CoarseGrainedClusterMessage + + case object StopDriver extends CoarseGrainedClusterMessage + + case object StopExecutor extends CoarseGrainedClusterMessage - case object StopDriver extends StandaloneClusterMessage + case object StopExecutors extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage + 
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index b6f0ec961a..3ccc38d72b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -29,16 +29,19 @@ import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycle import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.scheduler.TaskDescription -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.Utils /** - * A standalone scheduler backend, which waits for standalone executors to connect to it through - * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained - * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*). + * A scheduler backend that waits for coarse grained executors to connect to it through Akka. + * This backend holds onto each executor for the duration of the Spark job rather than relinquishing + * executors whenever a task is done and asking the scheduler to launch a new executor for + * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the + * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode + * (spark.deploy.*). */ private[spark] -class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -84,17 +87,33 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - freeCores(executorId) += 1 - makeOffers(executorId) + if (executorActor.contains(executorId)) { + freeCores(executorId) += 1 + makeOffers(executorId) + } else { + // Ignoring the update since we don't know about the executor. + val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s" + logWarning(msg.format(taskId, state, sender, executorId)) + } } case ReviveOffers => makeOffers() + case KillTask(taskId, executorId) => + executorActor(executorId) ! KillTask(taskId, executorId) + case StopDriver => sender ! true context.stop(self) + case StopExecutors => + logInfo("Asking each executor to shut down") + for (executor <- executorActor.values) { + executor ! StopExecutor + } + sender ! true + case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) sender ! 
true @@ -159,16 +178,31 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) + Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) } - private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + private val timeout = { + Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + } + + def stopExecutors() { + try { + if (driverActor != null) { + logInfo("Shutting down all executors") + val future = driverActor.ask(StopExecutors)(timeout) + Await.ready(future, timeout) + } + } catch { + case e: Exception => + throw new SparkException("Error asking standalone scheduler to shut down executors", e) + } + } override def stop() { try { if (driverActor != null) { val future = driverActor.ask(StopDriver)(timeout) - Await.result(future, timeout) + Await.ready(future, timeout) } } catch { case e: Exception => @@ -180,6 +214,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor driverActor ! ReviveOffers } + override def killTask(taskId: Long, executorId: String) { + driverActor ! KillTask(taskId, executorId) + } + override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) @@ -187,7 +225,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor def removeExecutor(executorId: String, reason: String) { try { val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.result(future, timeout) + Await.ready(future, timeout) } catch { case e: Exception => throw new SparkException("Error notifying standalone scheduler's driver actor", e) @@ -195,6 +233,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } -private[spark] object StandaloneSchedulerBackend { - val ACTOR_NAME = "StandaloneScheduler" +private[spark] object CoarseGrainedSchedulerBackend { + val ACTOR_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index d57eb3276f..5367218faa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{SparkContext} +import org.apache.spark.SparkContext /** * A backend interface for cluster scheduling systems that allows plugging in different ones under @@ -30,8 +30,8 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int + def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException + // Memory used by each executor (in megabytes) protected val executorMemory: Int = SparkContext.executorMemoryRequested - - // TODO: Probably want to add a killTask too } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala new file mode 100644 index 0000000000..d78bdbaa7a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.spark.{Logging, SparkContext} + +private[spark] class SimrSchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext, + driverFilePath: String) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + with Logging { + + val tmpPath = new Path(driverFilePath + "_tmp") + val filePath = new Path(driverFilePath) + + val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt + + override def start() { + super.start() + + val driverUrl = "akka://spark@%s:%s/user/%s".format( + System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + val conf = new Configuration() + val fs = FileSystem.get(conf) + + logInfo("Writing to HDFS file: " + driverFilePath) + logInfo("Writing Akka address: " + driverUrl) + + // Create temporary file to prevent race condition where executors get empty driverUrl file + val temp = fs.create(tmpPath, true) + temp.writeUTF(driverUrl) + temp.writeInt(maxCores) + temp.close() + + // "Atomic" rename + fs.rename(tmpPath, filePath) + } + + override def stop() { + val conf = new Configuration() + val fs = FileSystem.get(conf) + fs.delete(new Path(driverFilePath), false) + super.stopExecutors() + super.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index fa83ae19d6..7127a72d6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -26,9 +26,9 @@ import org.apache.spark.util.Utils private[spark] class SparkDeploySchedulerBackend( scheduler: ClusterScheduler, sc: SparkContext, - master: String, + masters: Array[String], appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with ClientListener with Logging { @@ -44,15 +44,15 @@ private[spark] class SparkDeploySchedulerBackend( // The endpoint for executors to talk to us val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command( - "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) + 
"org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(null) val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, "http://" + sc.ui.appUIAddress) - client = new Client(sc.env.actorSystem, master, appDesc, this) + client = new Client(sc.env.actorSystem, masters, appDesc, this) client.start() } @@ -71,8 +71,14 @@ private[spark] class SparkDeploySchedulerBackend( override def disconnected() { if (!stopping) { - logError("Disconnected from Spark cluster!") - scheduler.error("Disconnected from Spark cluster") + logWarning("Disconnected from Spark cluster! Waiting for reconnection...") + } + } + + override def dead() { + if (!stopping) { + logError("Spark cluster looks dead, giving up.") + scheduler.error("Spark cluster looks down") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index b2a8f06472..e68c527713 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -24,33 +24,16 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.Utils /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. */ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) extends Logging { - private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt - private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt - private val getTaskResultExecutor = new ThreadPoolExecutor( - MIN_THREADS, - MAX_THREADS, - 0L, - TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable], - new ResultResolverThreadFactory) - - class ResultResolverThreadFactory extends ThreadFactory { - private var counter = 0 - private var PREFIX = "Result resolver thread" - - override def newThread(r: Runnable): Thread = { - val thread = new Thread(r, "%s-%s".format(PREFIX, counter)) - counter += 1 - thread.setDaemon(true) - return thread - } - } + private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt + private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + THREADS, "Result resolver thread") protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { @@ -67,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche case directResult: DirectTaskResult[_] => directResult case IndirectTaskResult(blockId) => logDebug("Fetching indirect task result for TID %s".format(tid)) + scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) if (!serializedTaskResult.isDefined) { /* We won't be able to get the task result if the machine that ran the task failed diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index bf4040fafc..8de9b72b2f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,13 +30,14 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend} +import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the - * StandaloneBackend mechanism. This class is useful for lower and more predictable latency. + * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable + * latency. * * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to * remove this. @@ -46,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( sc: SparkContext, master: String, appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with MScheduler with Logging { @@ -122,20 +123,20 @@ private[spark] class CoarseMesosSchedulerBackend( val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ACTOR_NAME) val uri = System.getProperty("spark.executor.uri") if (uri == null) { val runScript = new File(sparkHome, "spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( + "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
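A concrete value makes the globbing comment above easier to follow: the executor URI's last path segment is truncated at its first dot, and the trailing '*' in the launch command then matches the extracted directory. For example (URI value hypothetical):

```scala
// Worked example of the basename extraction used in the next line.
val uri = "hdfs://host/dist/spark-0.8.0.tar.gz"   // hypothetical value
val basename = uri.split('/').last.split('.').head
// uri.split('/').last      == "spark-0.8.0.tar.gz"
// ...then .split('.').head == "spark-0"
// so the launch command becomes: cd spark-0*; ./spark-class ...
```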
val basename = uri.split('/').last.split('.').head command.setValue( - "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d" + .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } return command.build() diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 4d1bb1c639..2699f0b33e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -17,23 +17,19 @@ package org.apache.spark.scheduler.local -import java.io.File -import java.lang.management.ManagementFactory -import java.util.concurrent.atomic.AtomicInteger import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import akka.actor._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.ExecutorURLClassLoader +import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import akka.actor._ -import org.apache.spark.util.Utils + /** * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally @@ -41,52 +37,57 @@ import org.apache.spark.util.Utils * testing fault recovery. 
*/ -private[spark] +private[local] case class LocalReviveOffers() -private[spark] +private[local] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) +private[local] +case class KillTask(taskId: Long) + private[spark] -class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { +class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) + extends Actor with Logging { + + val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true) def receive = { case LocalReviveOffers => launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => - freeCores += 1 - localScheduler.statusUpdate(taskId, state, serializeData) - launchTask(localScheduler.resourceOffer(freeCores)) + if (TaskState.isFinished(state)) { + freeCores += 1 + launchTask(localScheduler.resourceOffer(freeCores)) + } + + case KillTask(taskId) => + executor.killTask(taskId) } - def launchTask(tasks : Seq[TaskDescription]) { + private def launchTask(tasks: Seq[TaskDescription]) { for (task <- tasks) { freeCores -= 1 - localScheduler.threadPool.submit(new Runnable { - def run() { - localScheduler.runTask(task.taskId, task.serializedTask) - } - }) + executor.launchTask(localScheduler, task.taskId, task.serializedTask) } } } private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler + with ExecutorBackend with Logging { - var attemptId = new AtomicInteger(0) - var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get - var listener: TaskSchedulerListener = null + val attemptId = new AtomicInteger + var dagScheduler: DAGScheduler = null // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - var schedulableBuilder: SchedulableBuilder = null var rootPool: Pool = null val schedulingMode: SchedulingMode = SchedulingMode.withName( @@ -113,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") } - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } override def submitTasks(taskSet: TaskSet) { @@ -127,6 +128,26 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the tasks and then abort + // the stage. + // 2. The task set manager has been created but no tasks have been scheduled. In this case, + // simply abort the stage. + val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + localActor !
KillTask(tid) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { synchronized { var freeCpuCores = freeCores @@ -166,107 +187,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - def runTask(taskId: Long, bytes: ByteBuffer) { - logInfo("Running " + taskId) - val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objectSer = SparkEnv.get.serializer.newInstance() - var attemptedTask: Option[Task[_]] = None - val start = System.currentTimeMillis() - var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime - - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - attemptedTask = Some(deserializedTask) - val deserTime = System.currentTimeMillis() - start - taskStart = System.currentTimeMillis() - - // Run it - val result: Any = deserializedTask.run(taskId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val serResult = objectSer.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = objectSer.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - val serviceTime = System.currentTimeMillis() - taskStart - logInfo("Finished " + taskId) - deserializedTask.metrics.get.executorRunTime = serviceTime.toInt - deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - val taskResult = new DirectTaskResult( - result, accumUpdates, deserializedTask.metrics.getOrElse(null)) - val serializedResult = ser.serialize(taskResult) - localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) - } catch { - case t: Throwable => { - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime.toInt - m.jvmGCTime = getTotalGCTime - startGCTime - } - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) - } - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. 
- */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { + if (TaskState.isFinished(state)) { + synchronized { + taskIdToTaskSetId.get(taskId) match { + case Some(taskSetId) => + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + + state match { + case TaskState.FINISHED => + taskSetManager.taskEnded(taskId, state, serializedData) + case TaskState.FAILED => + taskSetManager.taskFailed(taskId, state, serializedData) + case TaskState.KILLED => + taskSetManager.error("Task %d was killed".format(taskId)) + case _ => {} + } + case None => + logInfo("Ignoring update from TID " + taskId + " because its task set is gone") } } + localActor ! LocalStatusUpdate(taskId, state, serializedData) } } - def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { - synchronized { - val taskSetId = taskIdToTaskSetId(taskId) - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - taskSetManager.statusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - threadPool.shutdownNow() + override def stop() { } override def defaultParallelism() = threads diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala index c2e2399ccb..53bf78267e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala @@ -132,19 +132,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas return None } - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - SparkEnv.set(env) - state match { - case TaskState.FINISHED => - taskEnded(tid, state, serializedData) - case TaskState.FAILED => - taskFailed(tid, state, serializedData) - case _ => {} - } - } - def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) } def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { @@ -159,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } } result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info, + result.metrics) numFinished += 1 decreaseRunningTasks(1) finished(index) = true @@ -176,7 +166,7 @@ private[spark] class 
LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas decreaseRunningTasks(1) val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( serializedData, getClass.getClassLoader) - sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) + sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) if (!finished(index)) { copiesRunning(index) -= 1 numFailures(index) += 1 @@ -185,9 +175,9 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas reason.className, reason.description, locs.mkString("\n"))) if (numFailures(index) > MAX_TASK_FAILURES) { val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( - taskSet.id, index, 4, reason.description) + taskSet.id, index, MAX_TASK_FAILURES, reason.description) decreaseRunningTasks(runningTasks) - sched.listener.taskSetFailed(taskSet, errorMessage) + sched.dagScheduler.taskSetFailed(taskSet, errorMessage) // need to delete failed Taskset from schedule queue sched.taskSetFinished(this) } @@ -195,5 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } override def error(message: String) { + sched.dagScheduler.taskSetFailed(taskSet, message) + sched.taskSetFinished(this) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e936b1cfed..55b25f145a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -26,9 +26,8 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} import org.apache.spark.{SerializableWritable, Logging} -import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel} - import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel, TestBlockId} /** * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
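For readers following the serializer changes: in this Spark generation, Kryo was opted into through system properties, and user classes were registered via a KryoRegistrator. A sketch under that assumption; Point and MyRegistrator are hypothetical:

    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.serializer.KryoRegistrator

    // Hypothetical user type to register with Kryo.
    case class Point(x: Double, y: Double)

    class MyRegistrator extends KryoRegistrator {
      override def registerClasses(kryo: Kryo) {
        kryo.register(classOf[Point])
      }
    }

    // Opt in before creating the SparkContext (0.8-style properties):
    // System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    // System.setProperty("spark.kryo.registrator", "MyRegistrator")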
@@ -43,13 +42,14 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging val kryo = instantiator.newKryo() val classLoader = Thread.currentThread.getContextClassLoader + val blockId = TestBlockId("1") // Register some commonly used classes val toRegister: Seq[AnyRef] = Seq( ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1"), + PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), + GotBlock(blockId, ByteBuffer.allocate(1)), + GetBlock(blockId), 1 to 10, 1 until 10, 1L to 10L, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala index 290dbce4f5..0d0a2dadc7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala @@ -18,5 +18,5 @@ package org.apache.spark.storage private[spark] -case class BlockException(blockId: String, message: String) extends Exception(message) +case class BlockException(blockId: BlockId, message: String) extends Exception(message) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 3aeda3879d..e51c5b30a3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils */ private[storage] -trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] +trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging with BlockFetchTracker { def initialize() } @@ -57,20 +57,20 @@ private[storage] object BlockFetcherIterator { // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize // the block (since we want all deserialization to happen in the calling thread); can also // represent a fetch failure if size == -1. - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } class BasicBlockFetcherIterator( private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) extends BlockFetcherIterator { @@ -92,12 +92,12 @@ object BlockFetcherIterator { // This represents the number of local blocks, also counting zero-sized blocks private var numLocal = 0 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[String]() + protected val localBlocksToFetch = new ArrayBuffer[BlockId]() // This represents the number of remote blocks, also counting zero-sized blocks private var numRemote = 0 // BlockIds for remote blocks that need to be fetched.
Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[String]() + protected val remoteBlocksToFetch = new HashSet[BlockId]() // A queue to hold our results. protected val results = new LinkedBlockingQueue[FetchResult] @@ -167,7 +167,7 @@ object BlockFetcherIterator { logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) val iterator = blockInfos.iterator var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] + var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() // Skip empty blocks @@ -183,7 +183,7 @@ object BlockFetcherIterator { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] + curBlocks = new ArrayBuffer[(BlockId, Long)] } } // Add in the final request @@ -241,7 +241,7 @@ object BlockFetcherIterator { override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - override def next(): (String, Option[Iterator[Any]]) = { + override def next(): (BlockId, Option[Iterator[Any]]) = { resultsGotten += 1 val startFetchWait = System.currentTimeMillis() val result = results.take() @@ -267,7 +267,7 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { @@ -303,7 +303,7 @@ object BlockFetcherIterator { override protected def sendRequest(req: FetchRequest) { - def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { + def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) { val fetchResult = new FetchResult(blockId, blockSize, () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) results.put(fetchResult) @@ -337,7 +337,7 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def next(): (String, Option[Iterator[Any]]) = { + override def next(): (BlockId, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() // If all the results have been retrieved, copiers will exit automatically diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala new file mode 100644 index 0000000000..7156d855d8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +/** + * Identifies a particular Block of data, usually associated with a single file.
+ * A Block can be uniquely identified by its filename, but each type of Block has a different + * set of keys which produce its unique name. + * + * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method. + */ +private[spark] sealed abstract class BlockId { + /** A globally unique identifier for this Block. Can be used for ser/de. */ + def name: String + + // convenience methods + def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD = isInstanceOf[RDDBlockId] + def isShuffle = isInstanceOf[ShuffleBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] + + override def toString = name + override def hashCode = name.hashCode + override def equals(other: Any): Boolean = other match { + case o: BlockId => getClass == o.getClass && name.equals(o.name) + case _ => false + } +} + +private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { + def name = "rdd_" + rddId + "_" + splitIndex +} + +private[spark] +case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { + def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +} + +private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { + def name = "broadcast_" + broadcastId +} + +private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { + def name = broadcastId.name + "_" + hType +} + +private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { + def name = "taskresult_" + taskId +} + +private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { + def name = "input-" + streamId + "-" + uniqueId +} + +// Intended only for testing purposes +private[spark] case class TestBlockId(id: String) extends BlockId { + def name = "test_" + id +} + +private[spark] object BlockId { + val RDD = "rdd_([0-9]+)_([0-9]+)".r + val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)".r + val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val TASKRESULT = "taskresult_([0-9]+)".r + val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEST = "test_(.*)".r + + /** Converts a BlockId "name" String back into a BlockId. */ + def apply(id: String) = id match { + case RDD(rddId, splitIndex) => + RDDBlockId(rddId.toInt, splitIndex.toInt) + case SHUFFLE(shuffleId, mapId, reduceId) => + ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case BROADCAST(broadcastId) => + BroadcastBlockId(broadcastId.toLong) + case BROADCAST_HELPER(broadcastId, hType) => + BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) + case TASKRESULT(taskId) => + TaskResultBlockId(taskId.toLong) + case STREAM(streamId, uniqueId) => + StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEST(value) => + TestBlockId(value) + case _ => + throw new IllegalStateException("Unrecognized BlockId: " + id) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala new file mode 100644 index 0000000000..c8f397609a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
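The BlockId.apply() parser above is the inverse of each id's name; a quick round-trip check (a sketch that ignores the private[spark] visibility for illustration):

    import org.apache.spark.storage._

    val ids = Seq(
      RDDBlockId(3, 14),        // name = "rdd_3_14"
      ShuffleBlockId(1, 2, 3),  // name = "shuffle_1_2_3"
      BroadcastBlockId(42),     // name = "broadcast_42"
      TaskResultBlockId(7),     // name = "taskresult_7"
      StreamBlockId(0, 1234L),  // name = "input-0-1234"
      TestBlockId("abc"))       // name = "test_abc"

    // Parsing each name must reproduce an equal BlockId.
    ids.foreach(id => assert(BlockId(id.name) == id))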
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.concurrent.ConcurrentHashMap + +private[storage] class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { + // To save space, 'pending' and 'failed' are encoded as special sizes: + @volatile var size: Long = BlockInfo.BLOCK_PENDING + private def pending: Boolean = size == BlockInfo.BLOCK_PENDING + private def failed: Boolean = size == BlockInfo.BLOCK_FAILED + private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this) + + setInitThread() + + private def setInitThread() { + // Set current thread as init thread - waitForReady will not block this thread + // (in case there is non trivial initialization which ends up calling waitForReady as part of + // initialization itself) + BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread()) + } + + /** + * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). + * Return true if the block is available, false otherwise. + */ + def waitForReady(): Boolean = { + if (pending && initThread != Thread.currentThread()) { + synchronized { + while (pending) this.wait() + } + } + !failed + } + + /** Mark this BlockInfo as ready (i.e. block is finished writing) */ + def markReady(sizeInBytes: Long) { + require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes) + assert (pending) + size = sizeInBytes + BlockInfo.blockInfoInitThreads.remove(this) + synchronized { + this.notifyAll() + } + } + + /** Mark this BlockInfo as ready but failed */ + def markFailure() { + assert (pending) + size = BlockInfo.BLOCK_FAILED + BlockInfo.blockInfoInitThreads.remove(this) + synchronized { + this.notifyAll() + } + } +} + +private object BlockInfo { + // initThread is logically a BlockInfo field, but we store it here because + // it's only needed while this block is in the 'pending' state and we want + // to minimize BlockInfo's memory footprint. 
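BlockInfo's markReady/markFailure pair with waitForReady is a one-shot wait/notify handshake; stripped of the init-thread shortcut and the size-encoding trick, the underlying pattern is just this (a simplified sketch, not the class above):

    // One writer marks the outcome; readers block until it is known.
    class SimpleBlockInfo {
      private val PENDING = -1L
      private val FAILED = -2L
      @volatile private var size: Long = PENDING

      // Returns true if the block became available, false if the write failed.
      def waitForReady(): Boolean = synchronized {
        while (size == PENDING) wait()
        size != FAILED
      }

      def markReady(sizeInBytes: Long) = synchronized {
        require(sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes)
        size = sizeInBytes
        notifyAll()
      }

      def markFailure() = synchronized {
        size = FAILED
        notifyAll()
      }
    }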
+ private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread] + + private val BLOCK_PENDING: Long = -1L + private val BLOCK_FAILED: Long = -2L +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 7852849ce5..252329c4e1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,17 +17,18 @@ package org.apache.spark.storage -import java.io.{InputStream, OutputStream} +import java.io.{File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} +import scala.collection.mutable.{HashMap, ArrayBuffer} +import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration import scala.concurrent.duration._ -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec @@ -37,7 +38,6 @@ import org.apache.spark.util._ import sun.nio.ch.DirectBuffer - private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -46,74 +46,20 @@ private[spark] class BlockManager( maxMemory: Long) extends Logging { - private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) { - @volatile var pending: Boolean = true - @volatile var size: Long = -1L - @volatile var initThread: Thread = null - @volatile var failed = false - - setInitThread() - - private def setInitThread() { - // Set current thread as init thread - waitForReady will not block this thread - // (in case there is non trivial initialization which ends up calling waitForReady as part of - // initialization itself) - this.initThread = Thread.currentThread() - } - - /** - * Wait for this BlockInfo to be marked as ready (i.e. block is finished writing). - * Return true if the block is available, false otherwise. - */ - def waitForReady(): Boolean = { - if (initThread != Thread.currentThread() && pending) { - synchronized { - while (pending) this.wait() - } - } - !failed - } - - /** Mark this BlockInfo as ready (i.e. block is finished writing) */ - def markReady(sizeInBytes: Long) { - assert (pending) - size = sizeInBytes - initThread = null - failed = false - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - - /** Mark this BlockInfo as ready but failed */ - def markFailure() { - assert (pending) - size = 0 - initThread = null - failed = true - initThread = null - pending = false - synchronized { - this.notifyAll() - } - } - } - val shuffleBlockManager = new ShuffleBlockManager(this) + val diskBlockManager = new DiskBlockManager(shuffleBlockManager, + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) - private val blockInfo = new TimeStampedHashMap[String, BlockInfo] + private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: DiskStore = - new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. 
private val nettyPort: Int = { val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt - if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } val connectionManager = new ConnectionManager(0) @@ -154,7 +100,8 @@ private[spark] class BlockManager( var heartBeatTask: Cancellable = null - val metadataCleaner = new MetadataCleaner("BlockManager", this.dropOldBlocks) + private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks) + private val broadcastCleaner = new MetadataCleaner(MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks) initialize() // The compression codec to use. Note that the "lazy" val is necessary because we want to delay @@ -248,7 +195,7 @@ private[spark] class BlockManager( /** * Get storage level of local block. If no info exists for the block, then returns null. */ - def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull /** * Tell the master about the current storage status of a block. This will send a block update @@ -258,7 +205,7 @@ private[spark] class BlockManager( * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). * This ensures that update in master will compensate for the increase in memory on slave. */ - def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) { val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) @@ -269,11 +216,11 @@ private[spark] class BlockManager( } /** - * Actually send a UpdateBlockInfo message. Returns the mater's response, + * Actually send a UpdateBlockInfo message. Returns the master's response, * which will be true if the block was successfully recorded and false if * the slave needs to re-register. */ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { + private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -298,7 +245,7 @@ private[spark] class BlockManager( /** * Get locations of an array of blocks. */ - def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { + def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) @@ -310,7 +257,7 @@ private[spark] class BlockManager( * shuffle blocks. It is safe to do so without a lock on block info since disk store * never deletes (recent) items. 
*/ - def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { diskStore.getValues(blockId, serializer).orElse( sys.error("Block " + blockId + " not found on disk, though it should be")) } @@ -318,94 +265,19 @@ private[spark] class BlockManager( /** * Get block from local block manager. */ - def getLocal(blockId: String): Option[Iterator[Any]] = { + def getLocal(blockId: BlockId): Option[Iterator[Any]] = { logDebug("Getting local block " + blockId) - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - - // In the another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed. - logWarning("Block " + blockId + " was marked as failure.") - return None - } - - val level = info.level - logDebug("Level for block " + blockId + " is " + level) - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - logDebug("Block " + blockId + " not found in memory") - } - } - - // Look for block on disk, potentially loading it back into memory if required - if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") - if (level.useMemory && level.deserialized) { - diskStore.getValues(blockId) match { - case Some(iterator) => - // Put the block back in memory before returning it - // TODO: Consider creating a putValues that also takes in a iterator ? - val elements = new ArrayBuffer[Any] - elements ++= iterator - memoryStore.putValues(blockId, elements, level, true).data match { - case Left(iterator2) => - return Some(iterator2) - case _ => - throw new Exception("Memory store did not return back an iterator") - } - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else if (level.useMemory && !level.deserialized) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - // Put a copy of the block back in memory before returning it. Note that we can't - // put the ByteBuffer returned by the disk store as that's a memory-mapped file. - // The use of rewind assumes this. - assert (0 == bytes.position()) - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - bytes.rewind() - return Some(dataDeserialize(blockId, bytes)) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else { - diskStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - } - } - } else { - logDebug("Block " + blockId + " not registered locally") - } - return None + doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } /** * Get block from the local block manager as serialized bytes. 
*/ - def getLocalBytes(blockId: String): Option[ByteBuffer] = { - // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow + def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { logDebug("Getting local block " + blockId + " as bytes") - // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (ShuffleBlockManager.isShuffle(blockId)) { + if (blockId.isShuffle) { return diskStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) @@ -413,12 +285,15 @@ private[spark] class BlockManager( throw new Exception("Block " + blockId + " not found on disk, though it should be") } } + doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] + } + private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = { val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { - // In the another thread is writing the block, wait for it to become ready. + // If another thread is writing the block, wait for it to become ready. if (!info.waitForReady()) { // If we get here, the block write failed. logWarning("Block " + blockId + " was marked as failure.") @@ -431,62 +306,104 @@ private[spark] class BlockManager( // Look for the block in memory if (level.useMemory) { logDebug("Getting block " + blockId + " from memory") - memoryStore.getBytes(blockId) match { - case Some(bytes) => - return Some(bytes) + val result = if (asValues) { + memoryStore.getValues(blockId) + } else { + memoryStore.getBytes(blockId) + } + result match { + case Some(values) => + return Some(values) case None => logDebug("Block " + blockId + " not found in memory") } } - // Look for block on disk + // Look for block on disk, potentially storing it back into memory if required: if (level.useDisk) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - assert (0 == bytes.position()) - if (level.useMemory) { - if (level.deserialized) { - memoryStore.putBytes(blockId, bytes, level) - } else { - // The memory store will hang onto the ByteBuffer, so give it a copy instead of - // the memory-mapped file buffer we got from the disk store - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - } - } - bytes.rewind() - return Some(bytes) + logDebug("Getting block " + blockId + " from disk") + val bytes: ByteBuffer = diskStore.getBytes(blockId) match { + case Some(bytes) => bytes case None => throw new Exception("Block " + blockId + " not found on disk, though it should be") } + assert (0 == bytes.position()) + + if (!level.useMemory) { + // If the block shouldn't be stored in memory, we can just return it: + if (asValues) { + return Some(dataDeserialize(blockId, bytes)) + } else { + return Some(bytes) + } + } else { + // Otherwise, we also have to store something in the memory store: + if (!level.deserialized || !asValues) { + // We'll store the bytes in memory if the block's storage level includes + // "memory serialized", or if it should be cached as objects in memory + // but we only requested its serialized bytes: + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + memoryStore.putBytes(blockId, copyForMemory, level) + bytes.rewind() + } + if (!asValues) { + return Some(bytes) + } else { + val values = dataDeserialize(blockId, bytes) + if 
(level.deserialized) { + // Cache the values before returning them: + // TODO: Consider creating a putValues that also takes in an iterator? + val valuesBuffer = new ArrayBuffer[Any] + valuesBuffer ++= values + memoryStore.putValues(blockId, valuesBuffer, level, true).data match { + case Left(values2) => + return Some(values2) + case _ => + throw new Exception("Memory store did not return back an iterator") + } + } else { + return Some(values) + } + } + } } } } else { logDebug("Block " + blockId + " not registered locally") } - return None + None } /** * Get block from remote block managers. */ - def getRemote(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } + def getRemote(blockId: BlockId): Option[Iterator[Any]] = { logDebug("Getting remote block " + blockId) - // Get locations of block - val locations = master.getLocations(blockId) + doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } - // Get block from remote locations + /** + * Get block from remote block managers as serialized bytes. + */ + def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { + logDebug("Getting remote block " + blockId + " as bytes") + doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] + } + + private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = { + require(blockId != null, "BlockId is null") + val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug("Getting remote block " + blockId + " from " + loc) val data = BlockManagerWorker.syncGetBlock( GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { - return Some(dataDeserialize(blockId, data)) + if (asValues) { + return Some(dataDeserialize(blockId, data)) + } else { + return Some(data) + } } logDebug("The value of block " + blockId + " is null") } @@ -495,34 +412,9 @@ private[spark] class BlockManager( } /** - * Get block from remote block managers as serialized bytes. - */ - def getRemoteBytes(blockId: String): Option[ByteBuffer] = { - // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be - // refactored. - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - logDebug("Getting remote block " + blockId + " as bytes") - - val locations = master.getLocations(blockId) - for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) - if (data != null) { - return Some(data) - } - logDebug("The value of block " + blockId + " is null") - } - logDebug("Block " + blockId + " not found") - return None - } - - /** * Get a block from the block manager (either local or remote). */ - def get(blockId: String): Option[Iterator[Any]] = { + def get(blockId: BlockId): Option[Iterator[Any]] = { val local = getLocal(blockId) if (local.isDefined) { logInfo("Found block %s locally".format(blockId)) @@ -543,7 +435,7 @@ private[spark] class BlockManager( * so that we can control the maxMegabytesInFlight for the fetch.
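doGetLocal and doGetRemote collapse each values/bytes twin into a single code path selected by an asValues flag, with the thin public methods casting the result back to its precise type. The shape in miniature (a sketch with hypothetical names):

    // One private lookup; typed facades choose the representation and cast.
    class MiniStore {
      private val store = Map("k" -> Array[Byte](1, 2, 3))

      def getValues(key: String): Option[Iterator[Any]] =
        doGet(key, asValues = true).asInstanceOf[Option[Iterator[Any]]]

      def getBytes(key: String): Option[Array[Byte]] =
        doGet(key, asValues = false).asInstanceOf[Option[Array[Byte]]]

      private def doGet(key: String, asValues: Boolean): Option[Any] =
        store.get(key).map { bytes =>
          if (asValues) bytes.iterator.map(_.toInt) else bytes
        }
    }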
*/ def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) : BlockFetcherIterator = { val iter = @@ -557,7 +449,7 @@ private[spark] class BlockManager( iter } - def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) + def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) : Long = { val elements = new ArrayBuffer[Any] elements ++= values @@ -566,35 +458,38 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. + * The Block will be appended to the File specified by filename. * This is currently used for writing shuffle files out. Callers should handle error * cases. */ - def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) - writer.registerCloseEventHandler(() => { - val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) - blockInfo.put(blockId, myInfo) - myInfo.markReady(writer.size()) - }) - writer + val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) + new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream) } /** * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ - def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, - tellMaster: Boolean = true) : Long = { + def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, + tellMaster: Boolean = true) : Long = { + require(values != null, "Values is null") + doPut(blockId, Left(values), level, tellMaster) + } - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (values == null) { - throw new IllegalArgumentException("Values is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } + /** + * Put a new block of serialized bytes to the block manager. + */ + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, + tellMaster: Boolean = true) { + require(bytes != null, "Bytes is null") + doPut(blockId, Right(bytes), level, tellMaster) + } + + private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer], + level: StorageLevel, tellMaster: Boolean = true): Long = { + require(blockId != null, "BlockId is null") + require(level != null && level.isValid, "StorageLevel is null or invalid") // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will @@ -610,7 +505,8 @@ private[spark] class BlockManager( return oldBlockOpt.get.size } - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + // TODO: So the block info exists - but previous attempt to load it (?) failed. + // What do we do now ? Retry on it ? oldBlockOpt.get } else { tinfo @@ -619,10 +515,10 @@ private[spark] class BlockManager( val startTimeMs = System.currentTimeMillis - // If we need to replicate the data, we'll want access to the values, but because our - // put will read the whole iterator, there will be no values left. 
For the case where - // the put serializes data, we'll remember the bytes, above; but for the case where it - // doesn't, such as deserialized storage, let's rely on the put returning an Iterator. + // If we're storing values and we need to replicate the data, we'll want access to the values, + // but because our put will read the whole iterator, there will be no values left. For the + // case where the put serializes data, we'll remember the bytes, above; but for the case where + // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. var valuesAfterPut: Iterator[Any] = null // Ditto for the bytes after the put @@ -631,30 +527,51 @@ private[spark] class BlockManager( // Size of the block in bytes (to return to caller) var size = 0L + // If we're storing bytes, then initiate the replication before storing them locally. + // This is faster as data is already serialized and ready to send. + val replicationFuture = if (data.isRight && level.replication > 1) { + val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper + Future { + replicate(blockId, bufferView, level) + } + } else { + null + } + myInfo.synchronized { logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") var marked = false try { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will later - // drop it to disk if the memory store can't hold it. - val res = memoryStore.putValues(blockId, values, level, true) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator + data match { + case Left(values) => { + if (level.useMemory) { + // Save it just to memory first, even if it also has useDisk set to true; we will + // drop it to disk later if the memory store can't hold it. + val res = memoryStore.putValues(blockId, values, level, true) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case Left(newIterator) => valuesAfterPut = newIterator + } + } else { + // Save directly to disk. + // Don't get back the bytes unless we replicate them. + val askForBytes = level.replication > 1 + val res = diskStore.putValues(blockId, values, level, askForBytes) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => + } + } } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. 
- val askForBytes = level.replication > 1 - val res = diskStore.putValues(blockId, values, level, askForBytes) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => + case Right(bytes) => { + bytes.rewind() + // Store it only in memory at first, even if useDisk is also set to true + (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level) + size = bytes.limit } } @@ -679,132 +596,46 @@ private[spark] class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - // Replicate block if required + // Either we're storing bytes and we asynchronously started replication, or we're storing + // values and need to serialize and replicate them now: if (level.replication > 1) { - val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") - } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) - } - replicate(blockId, bytesAfterPut, level) - logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) - } - BlockManager.dispose(bytesAfterPut) - - return size - } - - - /** - * Put a new block of serialized bytes to the block manager. - */ - def putBytes( - blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { - - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (bytes == null) { - throw new IllegalArgumentException("Bytes is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? - oldBlockOpt.get - } else { - tinfo - } - } - - val startTimeMs = System.currentTimeMillis - - // Initiate the replication before storing it locally. This is faster as - // data is already serialized and ready for sending - val replicationFuture = if (level.replication > 1) { - val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper - Future { - replicate(blockId, bufferView, level) - } - } else { - null - } - - myInfo.synchronized { - logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - var marked = false - try { - if (level.useMemory) { - // Store it only in memory at first, even if useDisk is also set to true - bytes.rewind() - memoryStore.putBytes(blockId, bytes, level) - } else { - bytes.rewind() - diskStore.putBytes(blockId, bytes, level) - } - - // assert (0 == bytes.position(), "" + bytes) - - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. 
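When doPut is handed already-serialized bytes, replication is kicked off before the local write and awaited at the end, overlapping network and storage work. The skeleton of that pattern (a sketch using the global ExecutionContext; replicate stands in for the real peer send):

    import java.nio.ByteBuffer
    import scala.concurrent.{Await, Future}
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration.Duration

    def replicate(data: ByteBuffer) { /* send to peer block managers */ }

    def putBytes(bytes: ByteBuffer, replication: Int) {
      // Start replication first: the data is already serialized, so peers can
      // receive it while the local write proceeds. duplicate() shares the bytes
      // but gives the replication thread its own position and limit.
      val replicationFuture =
        if (replication > 1) Future { replicate(bytes.duplicate()) } else null

      // ... write bytes to the local memory or disk store ...

      if (replication > 1) Await.ready(replicationFuture, Duration.Inf)
    }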
- marked = true - myInfo.markReady(bytes.limit) - if (tellMaster) { - reportBlockStatus(blockId, myInfo) - } - } finally { - // If we failed at putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (! marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - myInfo.markFailure() - logWarning("Putting block " + blockId + " failed") + data match { + case Right(bytes) => Await.ready(replicationFuture, Duration.Inf) + case Left(values) => { + val remoteStartTime = System.currentTimeMillis + // Serialize the block if not already done + if (bytesAfterPut == null) { + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytesAfterPut = dataSerialize(blockId, valuesAfterPut) + } + replicate(blockId, bytesAfterPut, level) + logDebug("Put block " + blockId + " remotely took " + + Utils.getUsedTimeMs(remoteStartTime)) } } } - // If replication had started, then wait for it to finish - if (level.replication > 1) { - Await.ready(replicationFuture, Duration.Inf) - } + BlockManager.dispose(bytesAfterPut) if (level.replication > 1) { - logDebug("PutBytes for block " + blockId + " with replication took " + + logDebug("Put for block " + blockId + " with replication took " + Utils.getUsedTimeMs(startTimeMs)) } else { - logDebug("PutBytes for block " + blockId + " without replication took " + + logDebug("Put for block " + blockId + " without replication took " + Utils.getUsedTimeMs(startTimeMs)) } + + size } /** * Replicate block to another node. */ var cachedPeers: Seq[BlockManagerId] = null - private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { + private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) { val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) @@ -827,14 +658,14 @@ private[spark] class BlockManager( /** * Read a block consisting of a single object. */ - def getSingle(blockId: String): Option[Any] = { + def getSingle(blockId: BlockId): Option[Any] = { get(blockId).map(_.next()) } /** * Write a block consisting of a single object. */ - def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) { + def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) { put(blockId, Iterator(value), level, tellMaster) } @@ -842,7 +673,7 @@ private[spark] class BlockManager( * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. */ - def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { + def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") val info = blockInfo.get(blockId).orNull if (info != null) { @@ -891,16 +722,15 @@ private[spark] class BlockManager( // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps // from RDD.id to blocks. 
logInfo("Removing RDD " + rddId) - val rddPrefix = "rdd_" + rddId + "_" - val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) - blocksToRemove.foreach(blockId => removeBlock(blockId, false)) + val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) blocksToRemove.size } /** * Remove a block from both memory and disk. */ - def removeBlock(blockId: String, tellMaster: Boolean = true) { + def removeBlock(blockId: BlockId, tellMaster: Boolean = true) { logInfo("Removing block " + blockId) val info = blockInfo.get(blockId).orNull if (info != null) info.synchronized { @@ -921,13 +751,22 @@ private[spark] class BlockManager( } } - def dropOldBlocks(cleanupTime: Long) { - logInfo("Dropping blocks older than " + cleanupTime) + private def dropOldNonBroadcastBlocks(cleanupTime: Long) { + logInfo("Dropping non broadcast blocks older than " + cleanupTime) + dropOldBlocks(cleanupTime, !_.isBroadcast) + } + + private def dropOldBroadcastBlocks(cleanupTime: Long) { + logInfo("Dropping broadcast blocks older than " + cleanupTime) + dropOldBlocks(cleanupTime, _.isBroadcast) + } + + private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { val iterator = blockInfo.internalMap.entrySet().iterator() while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime) { + if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level if (level.useMemory) { @@ -944,39 +783,45 @@ private[spark] class BlockManager( } } - def shouldCompress(blockId: String): Boolean = { - if (ShuffleBlockManager.isShuffle(blockId)) { - compressShuffle - } else if (blockId.startsWith("broadcast_")) { - compressBroadcast - } else if (blockId.startsWith("rdd_")) { - compressRdds - } else { - false // Won't happen in a real cluster, but it can in tests - } + def shouldCompress(blockId: BlockId): Boolean = blockId match { + case ShuffleBlockId(_, _, _) => compressShuffle + case BroadcastBlockId(_) => compressBroadcast + case RDDBlockId(_, _) => compressRdds + case _ => false } /** * Wrap an output stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: String, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } + /** Serializes into a stream. */ + def dataSerializeStream( + blockId: BlockId, + outputStream: OutputStream, + values: Iterator[Any], + serializer: Serializer = defaultSerializer) { + val byteStream = new FastBufferedOutputStream(outputStream) + val ser = serializer.newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + } + + /** Serializes into a byte buffer. 
*/ def dataSerialize( - blockId: String, + blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) - val ser = serializer.newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + dataSerializeStream(blockId, byteStream, values, serializer) byteStream.trim() ByteBuffer.wrap(byteStream.array) } @@ -986,7 +831,7 @@ private[spark] class BlockManager( * the iterator is reached. */ def dataDeserialize( - blockId: String, + blockId: BlockId, bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() @@ -1004,6 +849,7 @@ private[spark] class BlockManager( memoryStore.clear() diskStore.clear() metadataCleaner.cancel() + broadcastCleaner.cancel() logInfo("BlockManager stopped") } } @@ -1041,10 +887,10 @@ private[spark] object BlockManager extends Logging { } def blockIdsToBlockManagers( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[BlockManagerId]] = + : Map[BlockId, Seq[BlockManagerId]] = { // env == null and blockManagerMaster != null is used in tests assert (env != null || blockManagerMaster != null) @@ -1054,7 +900,7 @@ private[spark] object BlockManager extends Logging { blockManagerMaster.getLocations(blockIds) } - val blockManagers = new HashMap[String, Seq[BlockManagerId]] + val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]] for (i <- 0 until blockIds.length) { blockManagers(blockIds(i)) = blockLocations(i) } @@ -1062,19 +908,19 @@ private[spark] object BlockManager extends Logging { } def blockIdsToExecutorIds( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = + : Map[BlockId, Seq[String]] = { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) } def blockIdsToHosts( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = + : Map[BlockId, Seq[String]] = { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 0c977f05d1..48d7101b0a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -69,7 +69,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi def updateBlockInfo( blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { @@ -80,12 +80,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi } /** Get locations of the blockId from the driver */ - def getLocations(blockId: String): Seq[BlockManagerId] = { + def getLocations(blockId: BlockId): Seq[BlockManagerId] = { askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } @@ -103,7 +103,7 @@ 
private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. */ - def removeBlock(blockId: String) { + def removeBlock(blockId: BlockId) { askDriverWithReply(RemoveBlock(blockId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 3776951782..154a3980e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] + private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] val akkaTimeout = Duration.create( System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") @@ -130,10 +130,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // First remove the metadata for the given RDD, and then asynchronously remove the blocks // from the slaves. - val prefix = "rdd_" + rddId + "_" // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) + val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -199,7 +198,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: String) { + private def removeBlockFromWorkers(blockId: BlockId) { val locations = blockLocations.get(blockId) if (locations != null) { locations.foreach { blockManagerId: BlockManagerId => @@ -229,9 +228,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - if (id.executorId == "<driver>" && !isLocal) { - // Got a register message from the master node; don't register it - } else if (!blockManagerInfo.contains(id)) { + if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => // A block manager of the same executor already exists. @@ -248,7 +245,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def updateBlockInfo( blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) { @@ -293,11 +290,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! 
true } - private def getLocations(blockId: String): Seq[BlockManagerId] = { + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) } @@ -331,7 +328,7 @@ object BlockManagerMasterActor { private var _remainingMem: Long = maxMem // Mapping from block id to its status. - private val _blocks = new JHashMap[String, BlockStatus] + private val _blocks = new JHashMap[BlockId, BlockStatus] logInfo("Registering block manager %s with %s RAM".format( blockManagerId.hostPort, Utils.bytesToString(maxMem))) @@ -340,7 +337,7 @@ object BlockManagerMasterActor { _lastSeenMs = System.currentTimeMillis() } - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, + def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) { updateLastSeenMs() @@ -384,7 +381,7 @@ object BlockManagerMasterActor { } } - def removeBlock(blockId: String) { + def removeBlock(blockId: BlockId) { if (_blocks.containsKey(blockId)) { _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) @@ -395,7 +392,7 @@ object BlockManagerMasterActor { def lastSeenMs: Long = _lastSeenMs - def blocks: JHashMap[String, BlockStatus] = _blocks + def blocks: JHashMap[BlockId, BlockStatus] = _blocks override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 24333a179c..45f51da288 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - case class RemoveBlock(blockId: String) extends ToBlockManagerSlave + case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave // Remove all blocks belonging to a specific RDD. 
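As the UpdateBlockInfo changes below show, a BlockId crosses the wire as its name string and is reparsed on arrival. The same scheme in isolation (a sketch assuming the BlockId(...) parser from this patch and ignoring its private[spark] visibility):

    import java.io._
    import org.apache.spark.storage._

    // Serialize: write the stable name. Deserialize: reparse it.
    val out = new ByteArrayOutputStream()
    new DataOutputStream(out).writeUTF(RDDBlockId(3, 14).name) // "rdd_3_14"

    val in = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
    assert(BlockId(in.readUTF()) == RDDBlockId(3, 14))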
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave @@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages { class UpdateBlockInfo( var blockManagerId: BlockManagerId, - var blockId: String, + var blockId: BlockId, var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long) @@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages { override def writeExternal(out: ObjectOutput) { blockManagerId.writeExternal(out) - out.writeUTF(blockId) + out.writeUTF(blockId.name) storageLevel.writeExternal(out) out.writeLong(memSize) out.writeLong(diskSize) @@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages { override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) - blockId = in.readUTF() + blockId = BlockId(in.readUTF()) storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() @@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages { object UpdateBlockInfo { def apply(blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long): UpdateBlockInfo = { @@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages { } // For pattern-matching - def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = { Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } - case class GetLocations(blockId: String) extends ToBlockManagerMaster + case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster - case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster + case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 951503019f..3a65e55733 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._ * An actor to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. 
*/ +private[storage] class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { override def receive = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 678c38203c..0c66addf9d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends } } - private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) blockManager.putBytes(id, bytes, level) @@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends + " with data size: " + bytes.limit) } - private def getBlock(id: String): ByteBuffer = { + private def getBlock(id: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + id + " started from " + startTimeMs) val buffer = blockManager.getLocalBytes(id) match { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index d8fa6a91d1..80dcb5a207 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.network._ -private[spark] case class GetBlock(id: String) -private[spark] case class GotBlock(id: String, data: ByteBuffer) -private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) +private[spark] case class GetBlock(id: BlockId) +private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) +private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) private[spark] class BlockMessage() { // Un-initialized: typ = 0 @@ -34,7 +34,7 @@ private[spark] class BlockMessage() { // GotBlock: typ = 2 // PutBlock: typ = 3 private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: String = null + private var id: BlockId = null private var data: ByteBuffer = null private var level: StorageLevel = null @@ -74,7 +74,7 @@ private[spark] class BlockMessage() { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - id = idBuilder.toString() + id = BlockId(idBuilder.toString) if (typ == BlockMessage.TYPE_PUT_BLOCK) { @@ -109,28 +109,17 @@ private[spark] class BlockMessage() { set(buffer) } - def getType: Int = { - return typ - } - - def getId: String = { - return id - } - - def getData: ByteBuffer = { - return data - } - - def getLevel: StorageLevel = { - return level - } + def getType: Int = typ + def getId: BlockId = id + def getData: ByteBuffer = data + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) - buffer.putInt(typ).putInt(id.length()) - id.foreach((x: Char) => buffer.putChar(x)) + var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) + buffer.putInt(typ).putInt(id.name.length) + id.name.foreach((x: Char) => buffer.putChar(x)) buffer.flip() buffers += buffer @@ -212,7 +201,8 @@ 
private[spark] object BlockMessage { def main(args: Array[String]) { val B = new BlockMessage() - B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) + val blockId = TestBlockId("ABC") + B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index 0aaf846b5b..6ce9127c74 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -111,14 +111,15 @@ private[spark] object BlockMessageArray { } def main(args: Array[String]) { - val blockMessages = + val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { val buffer = ByteBuffer.allocate(100) buffer.clear - BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) + BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, + StorageLevel.MEMORY_ONLY_SER)) } else { - BlockMessage.fromGetBlock(GetBlock(i.toString)) + BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) } } val blockMessageArray = new BlockMessageArray(blockMessages) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 39f103297f..469e68fed7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -17,6 +17,13 @@ package org.apache.spark.storage +import java.io.{FileOutputStream, File, OutputStream} +import java.nio.channels.FileChannel + +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +import org.apache.spark.Logging +import org.apache.spark.serializer.{SerializationStream, Serializer} /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -25,22 +32,14 @@ package org.apache.spark.storage * * This interface does not support concurrent writes. */ -abstract class BlockObjectWriter(val blockId: String) { - - var closeEventHandler: () => Unit = _ +abstract class BlockObjectWriter(val blockId: BlockId) { def open(): BlockObjectWriter - def close() { - closeEventHandler() - } + def close() def isOpen: Boolean - def registerCloseEventHandler(handler: () => Unit) { - closeEventHandler = handler - } - /** * Flush the partial writes and commit them as a single atomic block. Return the * number of bytes written for this commit. @@ -59,7 +58,126 @@ abstract class BlockObjectWriter(val blockId: String) { def write(value: Any) /** - * Size of the valid writes, in bytes. + * Returns the file segment of committed data that this Writer has written. + */ + def fileSegment(): FileSegment + + /** + * Cumulative time spent performing blocking writes, in ns. */ - def size(): Long + def timeWriting(): Long +} + +/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */ +class DiskBlockObjectWriter( + blockId: BlockId, + file: File, + serializer: Serializer, + bufferSize: Int, + compressStream: OutputStream => OutputStream) + extends BlockObjectWriter(blockId) + with Logging +{ + + /** Intercepts write calls and tracks total time spent writing. Not thread safe. 
*/ + private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { + def timeWriting = _timeWriting + private var _timeWriting = 0L + + private def callWithTiming(f: => Unit) = { + val start = System.nanoTime() + f + _timeWriting += (System.nanoTime() - start) + } + + def write(i: Int): Unit = callWithTiming(out.write(i)) + override def write(b: Array[Byte]) = callWithTiming(out.write(b)) + override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) + } + + private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean + + /** The file channel, used for repositioning / truncating the file. */ + private var channel: FileChannel = null + private var bs: OutputStream = null + private var fos: FileOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var objOut: SerializationStream = null + private val initialPosition = file.length() + private var lastValidPosition = initialPosition + private var initialized = false + private var _timeWriting = 0L + + override def open(): BlockObjectWriter = { + fos = new FileOutputStream(file, true) + ts = new TimeTrackingOutputStream(fos) + channel = fos.getChannel() + lastValidPosition = initialPosition + bs = compressStream(new FastBufferedOutputStream(ts, bufferSize)) + objOut = serializer.newInstance().serializeStream(bs) + initialized = true + this + } + + override def close() { + if (initialized) { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + val start = System.nanoTime() + fos.getFD.sync() + _timeWriting += System.nanoTime() - start + } + objOut.close() + + _timeWriting += ts.timeWriting + + channel = null + bs = null + fos = null + ts = null + objOut = null + } + } + + override def isOpen: Boolean = objOut != null + + override def commit(): Long = { + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } + } + + override def revertPartialWrites() { + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. + objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } + } + + override def write(value: Any) { + if (!initialized) { + open() + } + objOut.writeObject(value) + } + + override def fileSegment(): FileSegment = { + val bytesWritten = lastValidPosition - initialPosition + new FileSegment(file, initialPosition, bytesWritten) + } + + // Only valid if called after close() + override def timeWriting() = _timeWriting } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index fa834371f4..ea42656240 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -27,7 +27,7 @@ import org.apache.spark.Logging */ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) /** * Put in a block and, possibly, also return its content as either bytes or another Iterator. 
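DiskBlockObjectWriter above keeps a lastValidPosition so that commit() and revertPartialWrites() give atomic batch appends on top of a plain FileOutputStream. A hedged sketch of the intended calling pattern (the caller name is hypothetical; the writer methods are the ones defined above):

    // Write one batch of records atomically: either the whole batch becomes
    // part of the committed segment, or the file is truncated back to the
    // last commit point.
    def writeBatch(writer: BlockObjectWriter, records: Iterator[Any]): Long = {
      try {
        records.foreach(writer.write)  // write() opens the stream lazily
        writer.commit()                // returns the bytes added by this batch
      } catch {
        case t: Throwable =>
          writer.revertPartialWrites() // truncate back to lastValidPosition
          throw t
      }
    }

Note that fileSegment() spans everything committed since open(), which is what ShuffleBlockManager later records as a map task's output.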
@@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { * @return a PutResult that contains the size of the data, as well as the values put if * returnValues is true (if not, the result's data field can be null) */ - def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) : PutResult /** * Return the size of a block in bytes. */ - def getSize(blockId: String): Long + def getSize(blockId: BlockId): Long - def getBytes(blockId: String): Option[ByteBuffer] + def getBytes(blockId: BlockId): Option[ByteBuffer] - def getValues(blockId: String): Option[Iterator[Any]] + def getValues(blockId: BlockId): Option[Iterator[Any]] /** * Remove a block, if it exists. * @param blockId the block to remove. * @return True if the block was found and removed, False otherwise. */ - def remove(blockId: String): Boolean + def remove(blockId: BlockId): Boolean - def contains(blockId: String): Boolean + def contains(blockId: BlockId): Boolean def clear() { } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala new file mode 100644 index 0000000000..fcd2e97982 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.File +import java.text.SimpleDateFormat +import java.util.{Date, Random} + +import org.apache.spark.Logging +import org.apache.spark.executor.ExecutorExitCode +import org.apache.spark.network.netty.{PathResolver, ShuffleSender} +import org.apache.spark.util.Utils + +/** + * Creates and maintains the logical mapping between logical blocks and physical on-disk + * locations. By default, one block is mapped to one file with a name given by its BlockId. + * However, it is also possible to have a block map to only a segment of a file, by calling + * mapBlockToFileSegment(). + * + * @param rootDirs The directories to use for storing block files. Data will be hashed among these. + */ +private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String) + extends PathResolver with Logging { + + private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt + + // Create one local directory for each path mentioned in spark.local.dir; then, inside this + // directory, create multiple subdirectories that we will hash files into, in order to avoid + // having really large inodes at the top level. 
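The class comment above pins down DiskBlockManager's default contract: one block, one file, named by the BlockId. A hypothetical usage sketch of that default (non-consolidated) path, assuming a wired-up instance:

    // `diskManager` would come from the BlockManager plumbing in this patch.
    def whereIs(diskManager: DiskBlockManager): FileSegment = {
      val seg = diskManager.getBlockLocation(RDDBlockId(2, 3))
      // No explicit segment mapping, so the block occupies a whole file
      // whose name is the BlockId's string form:
      assert(seg.offset == 0L && seg.length == seg.file.length)
      assert(seg.file.getName == "rdd_2_3")
      seg
    }

Only shuffle blocks with consolidation enabled take the other branch of getBlockLocation below, where ShuffleBlockManager supplies a sub-file FileSegment.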
+  private val localDirs: Array[File] = createLocalDirs()
+  private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+  private var shuffleSender : ShuffleSender = null
+
+  addShutdownHook()
+
+  /**
+   * Returns the physical file segment in which the given BlockId is located.
+   * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+   * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+   */
+  def getBlockLocation(blockId: BlockId): FileSegment = {
+    if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+      shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+    } else {
+      val file = getFile(blockId.name)
+      new FileSegment(file, 0, file.length())
+    }
+  }
+
+  def getFile(filename: String): File = {
+    // Figure out which local directory it hashes to, and which subdirectory in that
+    val hash = Utils.nonNegativeHash(filename)
+    val dirId = hash % localDirs.length
+    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+    // Create the subdirectory if it doesn't already exist
+    var subDir = subDirs(dirId)(subDirId)
+    if (subDir == null) {
+      subDir = subDirs(dirId).synchronized {
+        val old = subDirs(dirId)(subDirId)
+        if (old != null) {
+          old
+        } else {
+          val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+          newDir.mkdir()
+          subDirs(dirId)(subDirId) = newDir
+          newDir
+        }
+      }
+    }
+
+    new File(subDir, filename)
+  }
+
+  def getFile(blockId: BlockId): File = getFile(blockId.name)
+
+  private def createLocalDirs(): Array[File] = {
+    logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+    val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+    rootDirs.split(",").map { rootDir =>
+      var foundLocalDir = false
+      var localDir: File = null
+      var localDirId: String = null
+      var tries = 0
+      val rand = new Random()
+      while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+        tries += 1
+        try {
+          localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+          localDir = new File(rootDir, "spark-local-" + localDirId)
+          if (!localDir.exists) {
+            foundLocalDir = localDir.mkdirs()
+          }
+        } catch {
+          case e: Exception =>
+            logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+        }
+      }
+      if (!foundLocalDir) {
+        logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+          " attempts to create local dir in " + rootDir)
+        System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+      }
+      logInfo("Created local directory at " + localDir)
+      localDir
+    }
+  }
+
+  private def addShutdownHook() {
+    localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+    Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+      override def run() {
+        logDebug("Shutdown hook called")
+        localDirs.foreach { localDir =>
+          try {
+            if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+          } catch {
+            case t: Throwable =>
+              logError("Exception while deleting local spark dir: " + localDir, t)
+          }
+        }
+
+        if (shuffleSender != null) {
+          shuffleSender.stop()
+        }
+      }
+    })
+  }
+
+  private[storage] def startShuffleBlockSender(port: Int): Int = {
+    shuffleSender = new ShuffleSender(port, this)
+    logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+    shuffleSender.port
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 63447baf8c..5a1e7b4444 100644
---
a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,153 +17,46 @@ package org.apache.spark.storage -import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} +import java.io.{FileOutputStream, RandomAccessFile} import java.nio.ByteBuffer -import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode -import java.util.{Random, Date} -import java.text.SimpleDateFormat import scala.collection.mutable.ArrayBuffer -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.serializer.{Serializer, SerializationStream} import org.apache.spark.Logging -import org.apache.spark.network.netty.ShuffleSender -import org.apache.spark.network.netty.PathResolver +import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils /** * Stores BlockManager blocks on disk. */ -private class DiskStore(blockManager: BlockManager, rootDirs: String) +private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) - extends BlockObjectWriter(blockId) { - - private val f: File = createFile(blockId /*, allowAppendExisting */) - - // The file channel, used for repositioning / truncating the file. - private var channel: FileChannel = null - private var bs: OutputStream = null - private var objOut: SerializationStream = null - private var lastValidPosition = 0L - private var initialized = false - - override def open(): DiskBlockObjectWriter = { - val fos = new FileOutputStream(f, true) - channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true - this - } - - override def close() { - if (initialized) { - objOut.close() - channel = null - bs = null - objOut = null - } - // Invoke the close callback handler. - super.close() - } - - override def isOpen: Boolean = objOut != null - - // Flush the partial writes, and set valid length to be the length of the entire file. - // Return the number of bytes written for this commit. - override def commit(): Long = { - if (initialized) { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition - } - } - - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. 
- objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) - } - } - - override def write(value: Any) { - if (!initialized) { - open() - } - objOut.writeObject(value) - } - - override def size(): Long = lastValidPosition - } - - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - - private var shuffleSender : ShuffleSender = null - // Create one local directory for each path mentioned in spark.local.dir; then, inside this - // directory, create multiple subdirectories that we will hash files into, in order to avoid - // having really large inodes at the top level. - private val localDirs: Array[File] = createLocalDirs() - private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - - addShutdownHook() - - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) - : BlockObjectWriter = { - new DiskBlockObjectWriter(blockId, serializer, bufferSize) + override def getSize(blockId: BlockId): Long = { + diskManager.getBlockLocation(blockId).length } - override def getSize(blockId: String): Long = { - getFile(blockId).length() - } - - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { // So that we do not modify the input offsets ! // duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis - val file = createFile(blockId) - val channel = new RandomAccessFile(file, "rw").getChannel() + val file = diskManager.getFile(blockId) + val channel = new FileOutputStream(file).getChannel() while (bytes.remaining > 0) { channel.write(bytes) } channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime))) - } - - private def getFileBytes(file: File): ByteBuffer = { - val length = file.length() - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = try { - channel.map(MapMode.READ_ONLY, 0, length) - } finally { - channel.close() - } - - buffer + file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime))) } override def putValues( - blockId: String, + blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) @@ -171,159 +64,62 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) logDebug("Attempting to write values for block " + blockId) val startTime = System.currentTimeMillis - val file = createFile(blockId) - val fileOut = blockManager.wrapForCompression(blockId, - new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) - objOut.writeAll(values.iterator) - objOut.close() - val length = file.length() + val file = diskManager.getFile(blockId) + val outputStream = new FileOutputStream(file) + blockManager.dataSerializeStream(blockId, outputStream, values.iterator) + val length = file.length val timeTaken = System.currentTimeMillis - startTime logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(length), timeTaken)) + file.getName, Utils.bytesToString(length), timeTaken)) if (returnValues) { // Return a byte buffer for the contents of the file - val buffer = getFileBytes(file) + val buffer = 
getBytes(blockId).get PutResult(length, Right(buffer)) } else { PutResult(length, null) } } - override def getBytes(blockId: String): Option[ByteBuffer] = { - val file = getFile(blockId) - val bytes = getFileBytes(file) - Some(bytes) + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + val segment = diskManager.getBlockLocation(blockId) + val channel = new RandomAccessFile(segment.file, "r").getChannel() + val buffer = try { + channel.map(MapMode.READ_ONLY, segment.offset, segment.length) + } finally { + channel.close() + } + Some(buffer) } - override def getValues(blockId: String): Option[Iterator[Any]] = { - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } /** * A version of getValues that allows a custom serializer. This is used as part of the * shuffle short-circuit code. */ - def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) } - override def remove(blockId: String): Boolean = { - val file = getFile(blockId) - if (file.exists()) { + override def remove(blockId: BlockId): Boolean = { + val fileSegment = diskManager.getBlockLocation(blockId) + val file = fileSegment.file + if (file.exists() && file.length() == fileSegment.length) { file.delete() } else { - false - } - } - - override def contains(blockId: String): Boolean = { - getFile(blockId).exists() - } - - private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) - if (!allowAppendExisting && file.exists()) { - // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task. - logWarning("File for block " + blockId + " already exists on disk: " + file + ". 
Deleting") - file.delete() - } - file - } - - private def getFile(blockId: String): File = { - logDebug("Getting file for block " + blockId) - - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(blockId) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - - // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - newDir.mkdir() - subDirs(dirId)(subDirId) = newDir - newDir - } - } - } - - new File(subDir, blockId) - } - - private def createLocalDirs(): Array[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map { rootDir => - var foundLocalDir = false - var localDir: File = null - var localDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, "spark-local-" + localDirId) - if (!localDir.exists) { - foundLocalDir = localDir.mkdirs() - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) - } + if (fileSegment.length < file.length()) { + logWarning("Could not delete block associated with only a part of a file: " + blockId) } - if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) - System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) - } - logInfo("Created local directory at " + localDir) - localDir + false } } - private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) - Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run() { - logDebug("Shutdown hook called") - localDirs.foreach { localDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case t: Throwable => - logError("Exception while deleting local spark dir: " + localDir, t) - } - } - if (shuffleSender != null) { - shuffleSender.stop - } - } - }) - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - val pResolver = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - return null - } - DiskStore.this.getFile(blockId).getAbsolutePath() - } - } - shuffleSender = new ShuffleSender(port, pResolver) - logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) - shuffleSender.port + override def contains(blockId: BlockId): Boolean = { + val file = diskManager.getBlockLocation(blockId).file + file.exists() } } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala new file mode 100644 index 0000000000..555486830a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.File + +/** + * References a particular segment of a file (potentially the entire file), + * based off an offset and a length. + */ +private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) { + override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) +} diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 77a39c71ed..05f676c6e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -32,7 +32,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) case class Entry(value: Any, size: Long, deserialized: Boolean) - private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) + private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true) @volatile private var currentMemory = 0L // Object used to ensure that only one thread is putting blocks and if necessary, dropping // blocks from the memory store. @@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) def freeMemory: Long = maxMemory - currentMemory - override def getSize(blockId: String): Long = { + override def getSize(blockId: BlockId): Long = { entries.synchronized { entries.get(blockId).size } } - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { // Work on a duplicate - since the original input might be used elsewhere. 
val bytes = _bytes.duplicate() bytes.rewind() @@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def putValues( - blockId: String, + blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) @@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def getBytes(blockId: String): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val entry = entries.synchronized { entries.get(blockId) } @@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def getValues(blockId: String): Option[Iterator[Any]] = { + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { val entry = entries.synchronized { entries.get(blockId) } @@ -108,7 +108,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def remove(blockId: String): Boolean = { + override def remove(blockId: BlockId): Boolean = { entries.synchronized { val entry = entries.remove(blockId) if (entry != null) { @@ -131,14 +131,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Return the RDD ID that a given block ID is from, or null if it is not an RDD block. + * Return the RDD ID that a given block ID is from, or None if it is not an RDD block. */ - private def getRddId(blockId: String): String = { - if (blockId.startsWith("rdd_")) { - blockId.split('_')(1) - } else { - null - } + private def getRddId(blockId: BlockId): Option[Int] = { + blockId.asRDDId.map(_.rddId) } /** @@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * blocks to free memory for one block, another thread may use up the freed space for * another block. */ - private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { + private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = { // TODO: Its possible to optimize the locking by locking entries only when selecting blocks // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been // released, it must be ensured that those to-be-dropped blocks are not double counted for @@ -195,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. * Otherwise, the freed space may fill up before the caller puts in their new value. 
*/ - private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { + private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = { logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( space, currentMemory, maxMemory)) @@ -207,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - currentMemory < space) { val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[String]() + val selectedBlocks = new ArrayBuffer[BlockId]() var selectedMemory = 0L // This is synchronized to ensure that the set of entries is not changed @@ -218,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + if (rddToAdd != None && rddToAdd == getRddId(blockId)) { logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + "block from the same RDD") return false @@ -252,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return true } - override def contains(blockId: String): Boolean = { + override def contains(blockId: BlockId): Boolean = { entries.synchronized { entries.containsKey(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 9da11efb57..2f1b049ce4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -17,51 +17,199 @@ package org.apache.spark.storage -import org.apache.spark.serializer.Serializer +import java.io.File +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConversions._ -private[spark] -class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) +import org.apache.spark.serializer.Serializer +import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup +/** A group of writers for a ShuffleMapTask, one writer per reducer. */ +private[spark] trait ShuffleWriterGroup { + val writers: Array[BlockObjectWriter] -private[spark] -trait ShuffleBlocks { - def acquireWriters(mapId: Int): ShuffleWriterGroup - def releaseWriters(group: ShuffleWriterGroup) + /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ + def releaseWriters(success: Boolean) } - +/** + * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file + * per reducer (this set of files is called a ShuffleFileGroup). + * + * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle + * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer + * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle + * files, it releases them for another task. + * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: + * - shuffleId: The unique id given to the entire shuffle stage. 
+ *   - bucketId: The id of the output partition (i.e., reducer id)
+ *   - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ *     time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within all ShuffleFileGroups associated with the block's reducer.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+  // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
+  // TODO: Remove this once the shuffle file consolidation feature is stable.
+  val consolidateShuffleFiles =
+    System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+
+  private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+  /**
+   * Contains all the state related to a particular shuffle. This includes a pool of unused
+   * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+   */
+  private class ShuffleState() {
+    val nextFileId = new AtomicInteger(0)
+    val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+    val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+  }

-  def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
-    new ShuffleBlocks {
-      // Get a group of writers for a map task.
-      override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
-        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
-          val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
-          blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+  type ShuffleId = Int
+  private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
+
+  private
+  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
+
+  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+    new ShuffleWriterGroup {
+      shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+      private val shuffleState = shuffleStates(shuffleId)
+      private var fileGroup: ShuffleFileGroup = null
+
+      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+        fileGroup = getUnusedFileGroup()
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+        }
+      } else {
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+          val blockFile = blockManager.diskBlockManager.getFile(blockId)
+          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+        }
+      }
+
+      override def releaseWriters(success: Boolean) {
+        if (consolidateShuffleFiles) {
+          if (success) {
+            val offsets = writers.map(_.fileSegment().offset)
+            fileGroup.recordMapOutput(mapId, offsets)
+          }
+          recycleFileGroup(fileGroup)
        }
-        new ShuffleWriterGroup(mapId, writers)
      }
-      override def releaseWriters(group: ShuffleWriterGroup) = {
-        // Nothing really to release here.
+      private def getUnusedFileGroup(): ShuffleFileGroup = {
+        val fileGroup = shuffleState.unusedFileGroups.poll()
+        if (fileGroup != null) fileGroup else newFileGroup()
+      }
+
+      private def newFileGroup(): ShuffleFileGroup = {
+        val fileId = shuffleState.nextFileId.getAndIncrement()
+        val files = Array.tabulate[File](numBuckets) { bucketId =>
+          val filename = physicalFileName(shuffleId, bucketId, fileId)
+          blockManager.diskBlockManager.getFile(filename)
+        }
+        val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files)
+        shuffleState.allFileGroups.add(fileGroup)
+        fileGroup
+      }
+
+      private def recycleFileGroup(group: ShuffleFileGroup) {
+        shuffleState.unusedFileGroups.add(group)
+      }
    }
  }
-}

+  /**
+   * Returns the physical file segment in which the given BlockId is located.
+   * This function should only be called if shuffle file consolidation is enabled, as it is
+   * an error condition if we don't find the expected block.
+   */
+  def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+    // Search all file groups associated with this shuffle.
+    val shuffleState = shuffleStates(id.shuffleId)
+    for (fileGroup <- shuffleState.allFileGroups) {
+      val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+      if (segment.isDefined) { return segment.get }
+    }
+    throw new IllegalStateException("Failed to find shuffle block: " + id)
+  }
+
+  private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+    "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+  }
+
+  private def cleanup(cleanupTime: Long) {
+    shuffleStates.clearOldValues(cleanupTime)
+  }
+}

private[spark]
object ShuffleBlockManager {
+  /**
+   * A group of shuffle files, one per reducer.
+ * A particular mapper will be assigned a single ShuffleFileGroup to write its output to. + */ + private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) { + /** + * Stores the absolute index of each mapId in the files of this group. For instance, + * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0. + */ + private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]() - // Returns the block id for a given shuffle block. - def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { - "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId - } + /** + * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file. + * This ordering allows us to compute block lengths by examining the following block offset. + * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every + * reducer. + */ + private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { + new PrimitiveVector[Long]() + } + + def numBlocks = mapIdToIndex.size + + def apply(bucketId: Int) = files(bucketId) + + def recordMapOutput(mapId: Int, offsets: Array[Long]) { + mapIdToIndex(mapId) = numBlocks + for (i <- 0 until offsets.length) { + blockOffsetsByReducer(i) += offsets(i) + } + } - // Returns true if the block is a shuffle block. - def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") + /** Returns the FileSegment associated with the given map task, or None if no entry exists. */ + def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = { + val file = files(reducerId) + val blockOffsets = blockOffsetsByReducer(reducerId) + val index = mapIdToIndex.getOrElse(mapId, -1) + if (index >= 0) { + val offset = blockOffsets(index) + val length = + if (index + 1 < numBlocks) { + blockOffsets(index + 1) - offset + } else { + file.length() - offset + } + assert(length >= 0) + Some(new FileSegment(file, offset, length)) + } else { + None + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala new file mode 100644 index 0000000000..1e4db4f66b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -0,0 +1,86 @@ +package org.apache.spark.storage + +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{CountDownLatch, Executors} + +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.SparkContext +import org.apache.spark.util.Utils + +/** + * Utility for micro-benchmarking shuffle write performance. + * + * Writes simulated shuffle output from several threads and records the observed throughput. + */ +object StoragePerfTester { + def main(args: Array[String]) = { + /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ + val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) + + /** Number of map tasks. All tasks execute concurrently. */ + val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8) + + /** Number of reduce splits for each map task. 
*/ + val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500) + + val recordLength = 1000 // ~1KB records + val totalRecords = dataSizeMb * 1000 + val recordsPerMap = totalRecords / numMaps + + val writeData = "1" * recordLength + val executor = Executors.newFixedThreadPool(numMaps) + + System.setProperty("spark.shuffle.compress", "false") + System.setProperty("spark.shuffle.sync", "true") + + // This is only used to instantiate a BlockManager. All thread scheduling is done manually. + val sc = new SparkContext("local[4]", "Write Tester") + val blockManager = sc.env.blockManager + + def writeOutputBytes(mapId: Int, total: AtomicLong) = { + val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, + new KryoSerializer()) + val writers = shuffle.writers + for (i <- 1 to recordsPerMap) { + writers(i % numOutputSplits).write(writeData) + } + writers.map {w => + w.commit() + total.addAndGet(w.fileSegment().length) + w.close() + } + + shuffle.releaseWriters(true) + } + + val start = System.currentTimeMillis() + val latch = new CountDownLatch(numMaps) + val totalBytes = new AtomicLong() + for (task <- 1 to numMaps) { + executor.submit(new Runnable() { + override def run() = { + try { + writeOutputBytes(task, totalBytes) + latch.countDown() + } catch { + case e: Exception => + println("Exception in child thread: " + e + " " + e.getMessage) + System.exit(1) + } + } + }) + } + latch.await() + val end = System.currentTimeMillis() + val time = (end - start) / 1000.0 + val bytesPerSecond = totalBytes.get() / time + val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + + System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) + System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) + System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + + executor.shutdown() + sc.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 2bb7715696..1720007e4e 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -23,20 +23,24 @@ import org.apache.spark.util.Utils private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, - blocks: Map[String, BlockStatus]) { + blocks: Map[BlockId, BlockStatus]) { - def memUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). - reduceOption(_+_).getOrElse(0l) - } + def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L) - def diskUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). 
-      reduceOption(_+_).getOrElse(0l)
-  }
+  def memUsedByRDD(rddId: Int) =
+    rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
+
+  def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
+
+  def diskUsedByRDD(rddId: Int) =
+    rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)

  def memRemaining : Long = maxMem - memUsed()

+  def rddBlocks = blocks.flatMap {
+    case (rdd: RDDBlockId, status) => Some(rdd, status)
+    case _ => None
+  }
}

case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
@@ -60,7 +64,7 @@ object StorageUtils {

  /* Returns RDD-level information, compiled from a list of StorageStatus objects */
  def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus], sc: SparkContext)
    : Array[RDDInfo] = {
-    rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+    rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc)
  }

  /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
@@ -71,26 +75,21 @@ object StorageUtils {
  }

  /* Given a list of BlockStatus objects, returns information for each RDD */
-  def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+  def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus],
    sc: SparkContext) : Array[RDDInfo] = {

    // Group by rddId, ignore the partition name
-    val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
-      k.substring(0,k.lastIndexOf('_'))
-    }.mapValues(_.values.toArray)
+    val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray)

    // For each RDD, generate an RDDInfo object
-    val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
+    val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) =>
      // Add up memory and disk sizes
      val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
      val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)

-      // Find the id of the RDD, e.g. rdd_1 => 1
-      val rddId = rddKey.split("_").last.toInt
-
      // Get the friendly name and storage level for the RDD, if available
      sc.persistentRdds.get(rddId).map { r =>
-        val rddName = Option(r.name).getOrElse(rddKey)
+        val rddName = Option(r.name).getOrElse(rddId.toString)
        val rddStorageLevel = r.getStorageLevel
        RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
      }
@@ -101,16 +100,14 @@ object StorageUtils {
    rddInfos
  }

-  /* Removes all BlockStatus object that are not part of a block prefix */
-  def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
-    prefix: String) : Array[StorageStatus] = {
+  /* Filters storage status by a given RDD id.
*/ + def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int) + : Array[StorageStatus] = { storageStatusList.map { status => - val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) + val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus] //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) StorageStatus(status.blockManagerId, status.maxMem, newBlocks) } - } - } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index f2ae8dd97d..860e680576 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -36,11 +36,11 @@ private[spark] object ThreadingTest { val numBlocksPerProducer = 20000 private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { - val queue = new ArrayBlockingQueue[(String, Seq[Int])](100) + val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100) override def run() { for (i <- 1 to numBlocksPerProducer) { - val blockId = "b-" + id + "-" + i + val blockId = TestBlockId("b-" + id + "-" + i) val blockSize = Random.nextInt(1000) val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() @@ -64,7 +64,7 @@ private[spark] object ThreadingTest { private[spark] class ConsumerThread( manager: BlockManager, - queue: ArrayBlockingQueue[(String, Seq[Int])] + queue: ArrayBlockingQueue[(BlockId, Seq[Int])] ) extends Thread { var numBlockConsumed = 0 diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 453394dfda..fcd1b518d0 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -35,7 +35,7 @@ private[spark] object UIWorkloadGenerator { def main(args: Array[String]) { if (args.length < 2) { - println("usage: ./spark-class spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") + println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") System.exit(1) } val master = args(0) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala index b39c0e9769..ca5a28625b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala @@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) { val now = System.currentTimeMillis() var activeTime = 0L - for (tasks <- listener.stageToTasksActive.values; t <- tasks) { + for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) { activeTime += t.timeRunning(now) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index eb3b4e8522..6b854740d6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt val DEFAULT_POOL_NAME = "default" - val stageToPool = new HashMap[Stage, String]() - val stageToDescription = new HashMap[Stage, String]() - val 
poolToActiveStages = new HashMap[String, HashSet[Stage]]() + val stageIdToPool = new HashMap[Int, String]() + val stageIdToDescription = new HashMap[Int, String]() + val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]() - val activeStages = HashSet[Stage]() - val completedStages = ListBuffer[Stage]() - val failedStages = ListBuffer[Stage]() + val activeStages = HashSet[StageInfo]() + val completedStages = ListBuffer[StageInfo]() + val failedStages = ListBuffer[StageInfo]() // Total metrics reflect metrics only for completed tasks var totalTime = 0L var totalShuffleRead = 0L var totalShuffleWrite = 0L - val stageToTime = HashMap[Int, Long]() - val stageToShuffleRead = HashMap[Int, Long]() - val stageToShuffleWrite = HashMap[Int, Long]() - val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]() - val stageToTasksComplete = HashMap[Int, Int]() - val stageToTasksFailed = HashMap[Int, Int]() - val stageToTaskInfos = + val stageIdToTime = HashMap[Int, Long]() + val stageIdToShuffleRead = HashMap[Int, Long]() + val stageIdToShuffleWrite = HashMap[Int, Long]() + val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]() + val stageIdToTasksComplete = HashMap[Int, Int]() + val stageIdToTasksFailed = HashMap[Int, Int]() + val stageIdToTaskInfos = HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]() override def onJobStart(jobStart: SparkListenerJobStart) {} override def onStageCompleted(stageCompleted: StageCompleted) = synchronized { - val stage = stageCompleted.stageInfo.stage - poolToActiveStages(stageToPool(stage)) -= stage + val stage = stageCompleted.stage + poolToActiveStages(stageIdToPool(stage.stageId)) -= stage activeStages -= stage completedStages += stage trimIfNecessary(completedStages) } /** If stages is too large, remove and garbage collect old stages */ - def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized { + def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > RETAINED_STAGES) { val toRemove = RETAINED_STAGES / 10 stages.takeRight(toRemove).foreach( s => { - stageToTaskInfos.remove(s.id) - stageToTime.remove(s.id) - stageToShuffleRead.remove(s.id) - stageToShuffleWrite.remove(s.id) - stageToTasksActive.remove(s.id) - stageToTasksComplete.remove(s.id) - stageToTasksFailed.remove(s.id) - stageToPool.remove(s) - if (stageToDescription.contains(s)) {stageToDescription.remove(s)} + stageIdToTaskInfos.remove(s.stageId) + stageIdToTime.remove(s.stageId) + stageIdToShuffleRead.remove(s.stageId) + stageIdToShuffleWrite.remove(s.stageId) + stageIdToTasksActive.remove(s.stageId) + stageIdToTasksComplete.remove(s.stageId) + stageIdToTasksFailed.remove(s.stageId) + stageIdToPool.remove(s.stageId) + if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)} }) stages.trimEnd(toRemove) } @@ -95,63 +95,69 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val poolName = Option(stageSubmitted.properties).map { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) - stageToPool(stage) = poolName + stageIdToPool(stage.stageId) = poolName val description = Option(stageSubmitted.properties).flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } - description.map(d => stageToDescription(stage) = d) + description.map(d => stageIdToDescription(stage.stageId) = d) - val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]()) + val stages = 
poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
     stages += stage
   }
 
   override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
     val sid = taskStart.task.stageId
-    val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+    val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
     tasksActive += taskStart.taskInfo
-    val taskList = stageToTaskInfos.getOrElse(
+    val taskList = stageIdToTaskInfos.getOrElse(
       sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
     taskList += ((taskStart.taskInfo, None, None))
-    stageToTaskInfos(sid) = taskList
+    stageIdToTaskInfos(sid) = taskList
   }
-
+
+  override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
+    = synchronized {
+    // Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
+    // stageIdToTaskInfos already has the updated status.
+  }
+
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
     val sid = taskEnd.task.stageId
-    val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+    val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
     tasksActive -= taskEnd.taskInfo
     val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
       taskEnd.reason match {
         case e: ExceptionFailure =>
-          stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+          stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
           (Some(e), e.metrics)
         case _ =>
-          stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+          stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
           (None, Option(taskEnd.taskMetrics))
       }
 
-    stageToTime.getOrElseUpdate(sid, 0L)
+    stageIdToTime.getOrElseUpdate(sid, 0L)
     val time = metrics.map(m => m.executorRunTime).getOrElse(0)
-    stageToTime(sid) += time
+    stageIdToTime(sid) += time
     totalTime += time
 
-    stageToShuffleRead.getOrElseUpdate(sid, 0L)
+    stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
     val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
       s.remoteBytesRead).getOrElse(0L)
-    stageToShuffleRead(sid) += shuffleRead
+    stageIdToShuffleRead(sid) += shuffleRead
     totalShuffleRead += shuffleRead
 
-    stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+    stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
     val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
       s.shuffleBytesWritten).getOrElse(0L)
-    stageToShuffleWrite(sid) += shuffleWrite
+    stageIdToShuffleWrite(sid) += shuffleWrite
     totalShuffleWrite += shuffleWrite
 
-    val taskList = stageToTaskInfos.getOrElse(
+    val taskList = stageIdToTaskInfos.getOrElse(
       sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
     taskList -= ((taskEnd.taskInfo, None, None))
     taskList += ((taskEnd.taskInfo, metrics, failureInfo))
-    stageToTaskInfos(sid) = taskList
+    stageIdToTaskInfos(sid) = taskList
   }
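The accounting above follows a single pattern: every per-stage structure is keyed by the immutable stage ID and updated with getOrElseUpdate plus an increment. A minimal stand-alone sketch of that pattern (the StageCounts class is hypothetical, not part of this patch):

import scala.collection.mutable.HashMap

// Stand-alone sketch of the accounting pattern used by JobProgressListener:
// key every map by the immutable stage ID, never by a mutable Stage object.
class StageCounts {
  val stageIdToTasksComplete = HashMap[Int, Int]()
  val stageIdToTasksFailed = HashMap[Int, Int]()
  val stageIdToTime = HashMap[Int, Long]()

  def onTaskEnd(stageId: Int, succeeded: Boolean, runTimeMs: Long) {
    val counts = if (succeeded) stageIdToTasksComplete else stageIdToTasksFailed
    counts(stageId) = counts.getOrElse(stageId, 0) + 1
    stageIdToTime(stageId) = stageIdToTime.getOrElse(stageId, 0L) + runTimeMs
  }
}

Keying by Int rather than by a mutable Stage object is also what lets trimIfNecessary retire old stages without leaking map entries.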
 
   override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@@ -159,10 +165,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
       case end: SparkListenerJobEnd =>
         end.jobResult match {
           case JobFailed(ex, Some(stage)) =>
-            activeStages -= stage
-            poolToActiveStages(stageToPool(stage)) -= stage
-            failedStages += stage
-            trimIfNecessary(failedStages)
+            /* If two jobs share a stage we could get this failure message twice. So we first
+             * check whether we've already retired this stage. */
+            val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
+            stageInfo.foreach { s =>
+              activeStages -= s
+              poolToActiveStages(stageIdToPool(stage.id)) -= s
+              failedStages += s
+              trimIfNecessary(failedStages)
+            }
           case _ =>
         }
       case _ =>
     }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 06810d8dbc..cfeeccda41 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.xml.Node
 
-import org.apache.spark.scheduler.{Schedulable, Stage}
+import org.apache.spark.scheduler.{Schedulable, StageInfo}
 import org.apache.spark.ui.UIUtils
 
 /** Table showing list of pools */
 private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
 
-  var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
+  var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
 
   def toNodeSeq(): Seq[Node] = {
     listener.synchronized {
@@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
     }
   }
 
-  private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
+  private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
     rows: Seq[Schedulable]
     ): Seq[Node] = {
     <table class="table table-bordered table-striped table-condensed sortable table-fixed">
@@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
     </table>
   }
 
-  private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
+  private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
     : Seq[Node] = {
     val activeStages = poolToActiveStages.get(p.name) match {
       case Some(stages) => stages.size
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 163a3746ea..35b5d5fd59 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
     val stageId = request.getParameter("id").toInt
     val now = System.currentTimeMillis()
 
-    if (!listener.stageToTaskInfos.contains(stageId)) {
+    if (!listener.stageIdToTaskInfos.contains(stageId)) {
       val content =
         <div>
           <h4>Summary Metrics</h4> No tasks have started yet
@@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
       return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
     }
 
-    val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
+    val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
 
     val numCompleted = tasks.count(_._1.finished)
-    val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
+    val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
     val hasShuffleRead = shuffleReadBytes > 0
-    val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
+    val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
     val hasShuffleWrite = shuffleWriteBytes > 0
 
     var activeTime = 0L
-    listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+
listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) val summary = <div> <ul class="unstyled"> <li> <strong>CPU time: </strong> - {parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)} + {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} </li> {if (hasShuffleRead) <li> @@ -83,10 +83,10 @@ private[spark] class StagePage(parent: JobProgressUI) { </div> val taskHeaders: Seq[String] = - Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ - Seq("GC Time") ++ + Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++ + Seq("Duration", "GC Time") ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ + {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ Seq("Errors") val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) @@ -153,6 +153,7 @@ private[spark] class StagePage(parent: JobProgressUI) { val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) <tr> + <td>{info.index}</td> <td>{info.taskId}</td> <td>{info.status}</td> <td>{info.taskLocality}</td> @@ -169,6 +170,8 @@ private[spark] class StagePage(parent: JobProgressUI) { Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td> }} {if (shuffleWrite) { + <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td> <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td> }} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 07db8622da..d7d0441c38 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,13 +22,13 @@ import java.util.Date import scala.xml.Node import scala.collection.mutable.HashSet -import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo} +import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo} import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished stages */ -private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) { +private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) { val listener = parent.listener val dateFmt = parent.dateFmt @@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU } - private def stageRow(s: Stage): Seq[Node] = { + private def stageRow(s: StageInfo): Seq[Node] = { val submissionTime = s.submissionTime match { case Some(t) => dateFmt.format(new Date(t)) case None => "Unknown" } - val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match { + val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match { case 0 => "" case b => Utils.bytesToString(b) } - val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match { + val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match { case 0 => "" case b => Utils.bytesToString(b) } - val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size - val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0) - val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) 
match { + val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size + val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0) + val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match { case f if f > 0 => "(%s failed)".format(f) case _ => "" } - val totalTasks = s.numPartitions + val totalTasks = s.numTasks - val poolName = listener.stageToPool.get(s) + val poolName = listener.stageIdToPool.get(s.stageId) val nameLink = - <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a> - val description = listener.stageToDescription.get(s) + <a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.name}</a> + val description = listener.stageIdToDescription.get(s.stageId) .map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink) val finishTime = s.completionTime.getOrElse(System.currentTimeMillis()) val duration = s.submissionTime.map(t => finishTime - t) <tr> - <td>{s.id}</td> + <td>{s.stageId}</td> {if (isFairScheduler) { <td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}> {poolName.get}</a></td>} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 43c1257677..b83cd54f3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.{StorageStatus, StorageUtils} +import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils} import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui.Page._ @@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) { val sc = parent.sc def render(request: HttpServletRequest): Seq[Node] = { - val id = request.getParameter("id") - val prefix = "rdd_" + id.toString + val id = request.getParameter("id").toInt val storageStatusList = sc.getExecutorStorageStatus - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) + val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id) val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") - val workers = filteredStorageStatusList.map((prefix, _)) + val workers = filteredStorageStatusList.map((id, _)) val workerTable = listingTable(workerHeaders, workerRow, workers) val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", "Executors") - val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) + val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray. 
+ sortWith(_._1.name < _._1.name) val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) val blocks = blockStatuses.map { case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) @@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) { headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) } - def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { + def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = { val (id, block, locations) = row <tr> <td>{id}</td> @@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) { </tr> } - def workerRow(worker: (String, StorageStatus)): Seq[Node] = { - val (prefix, status) = worker + def workerRow(worker: (Int, StorageStatus)): Seq[Node] = { + val (rddId, status) = worker <tr> <td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td> <td> - {Utils.bytesToString(status.memUsed(prefix))} + {Utils.bytesToString(status.memUsedByRDD(rddId))} ({Utils.bytesToString(status.memRemaining)} Remaining) </td> - <td>{Utils.bytesToString(status.diskUsed(prefix))}</td> + <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td> </tr> } } diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala new file mode 100644 index 0000000000..f60deafc6f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * A simple open hash table optimized for the append-only use case, where keys + * are never removed, but the value for each key may be changed. + * + * This implementation uses quadratic probing with a power-of-2 hash table + * size, which is guaranteed to explore all spaces for each key (see + * http://en.wikipedia.org/wiki/Quadratic_probing). + * + * TODO: Cache the hash values of each key? java.util.HashMap does that. + */ +private[spark] +class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable { + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + + private var capacity = nextPowerOf2(initialCapacity) + private var mask = capacity - 1 + private var curSize = 0 + + // Holds keys and values in the same array for memory locality; specifically, the order of + // elements is key0, value0, key1, value1, key2, value2, etc. + private var data = new Array[AnyRef](2 * capacity) + + // Treat the null key differently so we can use nulls in "data" to represent empty items. 
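The coverage guarantee cited in the class comment above is easy to check: with a power-of-2 capacity and increments 1, 2, 3, ..., the probe offsets are the triangular numbers, which visit every slot exactly once per cycle, so a lookup can never loop forever while an empty slot exists. A stand-alone sketch (not part of the patch) of the probe sequence that apply and changeValue use below:

// Probe sequence used by AppendOnlyMap: pos, pos+1, pos+1+2, pos+1+2+3, ...
// (mod capacity). For any power-of-2 capacity, the first `capacity` probes
// visit every slot exactly once.
def probeSequence(start: Int, capacity: Int): Seq[Int] = {
  val mask = capacity - 1
  var pos = start & mask
  var i = 1
  (0 until capacity).map { _ =>
    val cur = pos
    pos = (pos + i) & mask
    i += 1
    cur
  }
}

// All 64 slots are distinct for a table of capacity 64:
assert(probeSequence(37, 64).distinct.size == 64)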
+ private var haveNullValue = false + private var nullValue: V = null.asInstanceOf[V] + + private val LOAD_FACTOR = 0.7 + + /** Get the value for a given key */ + def apply(key: K): V = { + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + return nullValue + } + var pos = rehash(k.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (k.eq(curKey) || k == curKey) { + return data(2 * pos + 1).asInstanceOf[V] + } else if (curKey.eq(null)) { + return null.asInstanceOf[V] + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + return null.asInstanceOf[V] + } + + /** Set the value for a key */ + def update(key: K, value: V): Unit = { + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + if (!haveNullValue) { + incrementSize() + } + nullValue = value + haveNullValue = true + return + } + val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef]) + if (isNewEntry) { + incrementSize() + } + } + + /** + * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value + * for key, if any, or null otherwise. Returns the newly updated value. + */ + def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + if (!haveNullValue) { + incrementSize() + } + nullValue = updateFunc(haveNullValue, nullValue) + haveNullValue = true + return nullValue + } + var pos = rehash(k.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (k.eq(curKey) || k == curKey) { + val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) + data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] + return newValue + } else if (curKey.eq(null)) { + val newValue = updateFunc(false, null.asInstanceOf[V]) + data(2 * pos) = k + data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] + incrementSize() + return newValue + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + null.asInstanceOf[V] // Never reached but needed to keep compiler happy + } + + /** Iterator method from Iterable */ + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { + var pos = -1 + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def nextValue(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + return (null.asInstanceOf[K], nullValue) + } + pos += 1 + } + while (pos < capacity) { + if (!data(2 * pos).eq(null)) { + return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + } + pos += 1 + } + null + } + + override def hasNext: Boolean = nextValue() != null + + override def next(): (K, V) = { + val value = nextValue() + if (value == null) { + throw new NoSuchElementException("End of iterator") + } + pos += 1 + value + } + } + + override def size: Int = curSize + + /** Increase table size by 1, rehashing if necessary */ + private def incrementSize() { + curSize += 1 + if (curSize > LOAD_FACTOR * capacity) { + growTable() + } + } + + /** + * Re-hash a value to deal better with hash functions that don't differ + * in the lower bits, similar to java.util.HashMap + */ + private def rehash(h: Int): Int = { + val r = h ^ (h >>> 20) ^ (h >>> 12) + r ^ (r >>> 7) ^ (r >>> 4) + } + + /** + * Put an entry into a table represented by data, returning true if + * this increases the size of the table or false otherwise. Assumes + * that "data" has at least one empty slot. 
+ */ + private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = { + val mask = (data.length / 2) - 1 + var pos = rehash(key.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + data(2 * pos) = key + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return true + } else if (curKey.eq(key) || curKey == key) { + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return false + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + return false // Never reached but needed to keep compiler happy + } + + /** Double the table's size and re-hash everything */ + private def growTable() { + val newCapacity = capacity * 2 + if (newCapacity >= (1 << 30)) { + // We can't make the table this big because we want an array of 2x + // that size for our data, but array sizes are at most Int.MaxValue + throw new Exception("Can't make capacity bigger than 2^29 elements") + } + val newData = new Array[AnyRef](2 * newCapacity) + var pos = 0 + while (pos < capacity) { + if (!data(2 * pos).eq(null)) { + putInto(newData, data(2 * pos), data(2 * pos + 1)) + } + pos += 1 + } + data = newData + capacity = newCapacity + mask = newCapacity - 1 + } + + private def nextPowerOf2(n: Int): Int = { + val highBit = Integer.highestOneBit(n) + if (highBit == n) n else highBit << 1 + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index a430a75451..67a7f87a5c 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -17,7 +17,6 @@ package org.apache.spark.util -import java.util.concurrent.{TimeUnit, ScheduledFuture, Executors} import java.util.{TimerTask, Timer} import org.apache.spark.Logging @@ -25,11 +24,14 @@ import org.apache.spark.Logging /** * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) */ -class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging { +class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging { + val name = cleanerType.toString + private val delaySeconds = MetadataCleaner.getDelaySeconds private val periodSeconds = math.max(10, delaySeconds / 10) private val timer = new Timer(name + " cleanup timer", true) + private val task = new TimerTask { override def run() { try { @@ -53,9 +55,38 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging } } +object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask", + "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") { + + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, + SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value + + type MetadataCleanerType = Value + + def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = "spark.cleaner.ttl." + which.toString +} object MetadataCleaner { + + // using only sys props for now : so that workers can also get to it while preserving earlier behavior. 
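To make the property scheme concrete: a per-cleaner key spark.cleaner.ttl.<Type> overrides the global spark.cleaner.ttl, and setDelaySeconds with resetAll clears every override. A usage sketch against the accessors of this object; the values noted in comments follow from the definitions here:

// Global TTL of 3600s; DAG scheduler metadata expires faster.
MetadataCleaner.setDelaySeconds(3600)
MetadataCleaner.setDelaySeconds(MetadataCleanerType.DAG_SCHEDULER, 600)

MetadataCleaner.getDelaySeconds(MetadataCleanerType.DAG_SCHEDULER)  // 600 (per-type override)
MetadataCleaner.getDelaySeconds(MetadataCleanerType.BLOCK_MANAGER)  // 3600 (falls back to spark.cleaner.ttl)

// setDelaySeconds(delay, resetAll = true) clears every per-type override.
MetadataCleaner.setDelaySeconds(7200)
MetadataCleaner.getDelaySeconds(MetadataCleanerType.DAG_SCHEDULER)  // 7200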
def getDelaySeconds = System.getProperty("spark.cleaner.ttl", "-1").toInt - def setDelaySeconds(delay: Int) { System.setProperty("spark.cleaner.ttl", delay.toString) } + + def getDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { + System.getProperty(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds.toString).toInt + } + + def setDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) { + System.setProperty(MetadataCleanerType.systemProperty(cleanerType), delay.toString) + } + + def setDelaySeconds(delay: Int, resetAll: Boolean = true) { + // override for all ? + System.setProperty("spark.cleaner.ttl", delay.toString) + if (resetAll) { + for (cleanerType <- MetadataCleanerType.values) { + System.clearProperty(MetadataCleanerType.systemProperty(cleanerType)) + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 94ce50e964..7557ddab19 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,16 +18,12 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket} +import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address} import java.util.{Locale, Random, UUID} +import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} -import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor} -import java.util.regex.Pattern -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} - -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ import scala.collection.Map import scala.io.Source @@ -43,7 +39,7 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkEnv, SparkException, Logging} +import org.apache.spark.{SparkException, Logging} /** @@ -155,7 +151,7 @@ private[spark] object Utils extends Logging { return buf } - private val shutdownDeletePaths = new collection.mutable.HashSet[String]() + private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { @@ -287,9 +283,8 @@ private[spark] object Utils extends Logging { } case _ => // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others - val env = SparkEnv.get val uri = new URI(url) - val conf = env.hadoop.newConfiguration() + val conf = SparkHadoopUtil.get.newConfiguration() val fs = FileSystem.get(uri, conf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) @@ -454,14 +449,17 @@ private[spark] object Utils extends Logging { hostPortParseResults.get(hostPort) } - private[spark] val daemonThreadFactory: ThreadFactory = - new ThreadFactoryBuilder().setDaemon(true).build() + private val daemonThreadFactoryBuilder: ThreadFactoryBuilder = + new ThreadFactoryBuilder().setDaemon(true) /** - * Wrapper over newCachedThreadPool. + * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. 
 */
-  def newDaemonCachedThreadPool(): ThreadPoolExecutor =
-    Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+  def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
+    val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+    Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+  }
 
   /**
    * Return the string to tell how long has passed in seconds. The passing parameter should be in
@@ -472,10 +470,13 @@
   }
 
   /**
-   * Wrapper over newFixedThreadPool.
+   * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+   * unique, sequentially assigned integer.
    */
-  def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
-    Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+  def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
+    val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+    Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor]
+  }
 
   private def listFilesSafely(file: File): Seq[File] = {
     val files = file.listFiles()
@@ -820,4 +821,10 @@
     // Nothing else to guard against?
     hashAbs
   }
+
+  /** Returns a copy of the system properties that is thread-safe to iterate over. */
+  def getSystemProperties(): Map[String, String] = {
+    return System.getProperties().clone()
+      .asInstanceOf[java.util.Properties].toMap[String, String]
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
new file mode 100644
index 0000000000..a1a452315d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/BitSet.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+
+/**
+ * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
+ * safety/bound checking.
+ */
+class BitSet(numBits: Int) {
+
+  private[this] val words = new Array[Long](bit2words(numBits))
+  private[this] val numWords = words.length
+
+  /**
+   * Sets the bit at the specified index to true.
+   * @param index the bit index
+   */
+  def set(index: Int) {
+    val bitmask = 1L << (index & 0x3f)  // mod 64 and shift
+    words(index >> 6) |= bitmask        // div by 64 and mask
+  }
+
+  /**
+   * Return the value of the bit with the specified index. The value is true if the bit with
+   * the index is currently set in this BitSet; otherwise, the result is false.
+ * + * @param index the bit index + * @return the value of the bit with the specified index + */ + def get(index: Int): Boolean = { + val bitmask = 1L << (index & 0x3f) // mod 64 and shift + (words(index >> 6) & bitmask) != 0 // div by 64 and mask + } + + /** Return the number of bits set to true in this BitSet. */ + def cardinality(): Int = { + var sum = 0 + var i = 0 + while (i < numWords) { + sum += java.lang.Long.bitCount(words(i)) + i += 1 + } + sum + } + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then -1 is returned. + * + * To iterate over the true bits in a BitSet, use the following loop: + * + * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) { + * // operate on index i here + * } + * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + def nextSetBit(fromIndex: Int): Int = { + var wordIndex = fromIndex >> 6 + if (wordIndex >= numWords) { + return -1 + } + + // Try to find the next set bit in the current word + val subIndex = fromIndex & 0x3f + var word = words(wordIndex) >> subIndex + if (word != 0) { + return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word) + } + + // Find the next set bit in the rest of the words + wordIndex += 1 + while (wordIndex < numWords) { + word = words(wordIndex) + if (word != 0) { + return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word) + } + wordIndex += 1 + } + + -1 + } + + /** Return the number of longs it would take to hold numBits. */ + private def bit2words(numBits: Int) = ((numBits - 1) >> 6) + 1 +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala new file mode 100644 index 0000000000..45849b3380 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.reflect.ClassTag + + +/** + * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, + * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less + * space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. 
+ */ +private[spark] +class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + protected var _keySet = new OpenHashSet[K](initialCapacity) + + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). + private var _values: Array[V] = _ + _values = new Array[V](_keySet.capacity) + + @transient private var _oldValues: Array[V] = null + + // Treat the null key differently so we can use nulls in "data" to represent empty items. + private var haveNullValue = false + private var nullValue: V = null.asInstanceOf[V] + + override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + if (k == null) { + nullValue + } else { + val pos = _keySet.getPos(k) + if (pos < 0) { + null.asInstanceOf[V] + } else { + _values(pos) + } + } + } + + /** Set the value for a key */ + def update(k: K, v: V) { + if (k == null) { + haveNullValue = true + nullValue = v + } else { + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + if (k == null) { + if (haveNullValue) { + nullValue = mergeValue(nullValue) + } else { + haveNullValue = true + nullValue = defaultValue + } + nullValue + } else { + val pos = _keySet.addWithoutResize(k) + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = -1 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + pos += 1 + return (null.asInstanceOf[K], nullValue) + } + pos += 1 + } + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + // They also should have been val's. We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. 
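A usage sketch for the changeValue contract above (assuming the class is visible from the caller's package); note that the null key is stored outside the underlying OpenHashSet:

val counts = new OpenHashMap[String, Int]
for (word <- Seq("a", "b", "a", null, "a")) {
  // defaultValue = 1 on first sight; mergeValue increments afterwards.
  counts.changeValue(word, 1, _ + 1)
}
assert(counts("a") == 3 && counts("b") == 1 && counts(null) == 1)
assert(counts.size == 3)  // two regular keys plus the null key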
+ protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala new file mode 100644 index 0000000000..49d95afdb9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.reflect._ + +/** + * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never + * removed. + * + * The underlying implementation uses Scala compiler's specialization to generate optimized + * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet + * while incurring much less memory overhead. This can serve as building blocks for higher level + * data structures such as an optimized HashMap. + * + * This OpenHashSet is designed to serve as building blocks for higher level data structures + * such as an optimized hash map. Compared with standard hash set implementations, this class + * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to + * retrieve the position of a key in the underlying array. + * + * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed + * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). + */ +private[spark] +class OpenHashSet[@specialized(Long, Int) T: ClassTag]( + initialCapacity: Int, + loadFactor: Double) + extends Serializable { + + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + require(loadFactor < 1.0, "Load factor must be less than 1.0") + require(loadFactor > 0.0, "Load factor must be greater than 0.0") + + import OpenHashSet._ + + def this(initialCapacity: Int) = this(initialCapacity, 0.7) + + def this() = this(64) + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + + protected val hasher: Hasher[T] = { + // It would've been more natural to write the following using pattern matching. 
But Scala 2.9.x
+  // compiler has a bug when specialization is used together with this pattern matching, and
+  // throws:
+  // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
+  //  found   : scala.reflect.AnyValManifest[Long]
+  //  required: scala.reflect.ClassTag[Int]
+  //        at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
+  //        at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
+  //        ...
+    val mt = classTag[T]
+    if (mt == ClassTag.Long) {
+      (new LongHasher).asInstanceOf[Hasher[T]]
+    } else if (mt == ClassTag.Int) {
+      (new IntHasher).asInstanceOf[Hasher[T]]
+    } else {
+      new Hasher[T]
+    }
+  }
+
+  protected var _capacity = nextPowerOf2(initialCapacity)
+  protected var _mask = _capacity - 1
+  protected var _size = 0
+
+  protected var _bitset = new BitSet(_capacity)
+
+  // Init of the array in constructor (instead of in declaration) to work around a Scala compiler
+  // specialization bug that would generate two arrays (one for Object and one for specialized T).
+  protected var _data: Array[T] = _
+  _data = new Array[T](_capacity)
+
+  /** Number of elements in the set. */
+  def size: Int = _size
+
+  /** The capacity of the set (i.e. size of the underlying array). */
+  def capacity: Int = _capacity
+
+  /** Return true if this set contains the specified element. */
+  def contains(k: T): Boolean = getPos(k) != INVALID_POS
+
+  /**
+   * Add an element to the set. If the set is over capacity after the insertion, grow the set
+   * and rehash all elements.
+   */
+  def add(k: T) {
+    addWithoutResize(k)
+    rehashIfNeeded(k, grow, move)
+  }
+
+  /**
+   * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
+   * The caller is responsible for calling rehashIfNeeded.
+   *
+   * Use (retval & POSITION_MASK) to get the actual position, and
+   * (retval & NONEXISTENCE_MASK) != 0 to test whether the key was newly inserted.
+   *
+   * @return The position where the key is placed, plus the highest order bit is set if the key
+   *         did not exist previously.
+   */
+  def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+
+  /**
+   * Rehash the set if it is overloaded.
+   * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+   *          this method.
+   * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+   * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+   *                 to a new position (in the new data array).
+   */
+  def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+    if (_size > loadFactor * _capacity) {
+      rehash(k, allocateFunc, moveFunc)
+    }
+  }
+
+  /**
+   * Return the position of the element in the underlying array, or INVALID_POS if it is not found.
+   */
+  def getPos(k: T): Int = {
+    var pos = hashcode(hasher.hash(k)) & _mask
+    var i = 1
+    while (true) {
+      if (!_bitset.get(pos)) {
+        return INVALID_POS
+      } else if (k == _data(pos)) {
+        return pos
+      } else {
+        val delta = i
+        pos = (pos + delta) & _mask
+        i += 1
+      }
+    }
+    // Never reached here
+    INVALID_POS
+  }
+
+  /** Return the value at the specified position. */
+  def getValue(pos: Int): T = _data(pos)
+
+  /**
+   * Return the next position with an element stored, starting from the given position inclusively.
+   */
+  def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
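The bit-packed return value of addWithoutResize deserves a concrete reading: the low 31 bits carry the slot position, and the sign bit flags a newly inserted key. A caller-side sketch of the same pattern OpenHashMap.changeValue uses above:

val set = new OpenHashSet[Long](64)
val ret = set.addWithoutResize(42L)
val pos = ret & OpenHashSet.POSITION_MASK               // slot index in the data array
val isNew = (ret & OpenHashSet.NONEXISTENCE_MASK) != 0  // true: 42L was not present before
set.rehashIfNeeded(42L, _ => (), (_, _) => ())          // caller must trigger the rehash check
assert(isNew && set.contains(42L) && set.getValue(pos) == 42L)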
+
+  /**
+   * Put an entry into the set. Return the position where the key is placed. In addition, the
+   * highest bit in the returned position is set if the key did not exist prior to this put.
+   *
+   * This function assumes the data array has at least one empty slot.
+   */
+  private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
+    val mask = data.length - 1
+    var pos = hashcode(hasher.hash(k)) & mask
+    var i = 1
+    while (true) {
+      if (!bitset.get(pos)) {
+        // This is a new key.
+        data(pos) = k
+        bitset.set(pos)
+        _size += 1
+        return pos | NONEXISTENCE_MASK
+      } else if (data(pos) == k) {
+        // Found an existing key.
+        return pos
+      } else {
+        val delta = i
+        pos = (pos + delta) & mask
+        i += 1
+      }
+    }
+    // Never reached here
+    assert(INVALID_POS != INVALID_POS)
+    INVALID_POS
+  }
+
+  /**
+   * Double the table's size and re-hash everything. We are not really using k, but it is declared
+   * so the Scala compiler can specialize this method (which leads to calling the specialized
+   * version of putInto).
+   *
+   * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+   *          this method.
+   * @param allocateFunc Callback invoked when we are allocating a new, larger array.
+   * @param moveFunc Callback invoked when we move the key from one position (in the old data array)
+   *                 to a new position (in the new data array).
+   */
+  private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+    val newCapacity = _capacity * 2
+    require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+
+    allocateFunc(newCapacity)
+    val newData = new Array[T](newCapacity)
+    val newBitset = new BitSet(newCapacity)
+    var pos = 0
+    _size = 0
+    while (pos < _capacity) {
+      if (_bitset.get(pos)) {
+        val newPos = putInto(newBitset, newData, _data(pos))
+        moveFunc(pos, newPos & POSITION_MASK)
+      }
+      pos += 1
+    }
+    _bitset = newBitset
+    _data = newData
+    _capacity = newCapacity
+    _mask = newCapacity - 1
+  }
+
+  /**
+   * Re-hash a value to deal better with hash functions that don't differ
+   * in the lower bits, similar to java.util.HashMap
+   */
+  private def hashcode(h: Int): Int = {
+    val r = h ^ (h >>> 20) ^ (h >>> 12)
+    r ^ (r >>> 7) ^ (r >>> 4)
+  }
+
+  private def nextPowerOf2(n: Int): Int = {
+    val highBit = Integer.highestOneBit(n)
+    if (highBit == n) n else highBit << 1
+  }
+}
+
+
+private[spark]
+object OpenHashSet {
+
+  val INVALID_POS = -1
+  val NONEXISTENCE_MASK = 0x80000000
+  val POSITION_MASK = 0x7FFFFFFF
+
+  /**
+   * A set of specialized hash function implementations to avoid boxing hash code computation
+   * in the specialized implementation of OpenHashSet.
+   */
+  sealed class Hasher[@specialized(Long, Int) T] {
+    def hash(o: T): Int = o.hashCode()
+  }
+
+  class LongHasher extends Hasher[Long] {
+    override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
+  }
+
+  class IntHasher extends Hasher[Int] {
+    override def hash(o: Int): Int = o
+  }
+
+  private def grow1(newSize: Int) {}
+  private def move1(oldPos: Int, newPos: Int) { }
+
+  private val grow = grow1 _
+  private val move = move1 _
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000..2e1ef06cbc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.reflect._ + +/** + * A fast hash map implementation for primitive, non-null keys. This hash map supports + * insertions and updates, but not deletions. This map is about an order of magnitude + * faster than java.util.HashMap, while using much less space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. + */ +private[spark] +class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, + @specialized(Long, Int, Double) V: ClassTag]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int]) + + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). + protected var _keySet: OpenHashSet[K] = _ + private var _values: Array[V] = _ + _keySet = new OpenHashSet[K](initialCapacity) + _values = new Array[V](_keySet.capacity) + + private var _oldValues: Array[V] = null + + override def size = _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + val pos = _keySet.getPos(k) + _values(pos) + } + + /** Get the value for a given key, or returns elseValue if it doesn't exist. */ + def getOrElse(k: K, elseValue: V): V = { + val pos = _keySet.getPos(k) + if (pos >= 0) _values(pos) else elseValue + } + + /** Set the value for a key */ + def update(k: K, v: V) { + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + val pos = _keySet.addWithoutResize(k) + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = 0 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the unspecialized one and needs access + // to the "private" variables). 
+  // They also should have been val's. We use var's because there is a Scala compiler bug that
+  // would throw illegal access error at runtime if they are declared as val's.
+  protected var grow = (newCapacity: Int) => {
+    _oldValues = _values
+    _values = new Array[V](newCapacity)
+  }
+
+  protected var move = (oldPos: Int, newPos: Int) => {
+    _values(newPos) = _oldValues(oldPos)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000000..465c221d5f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.reflect.ClassTag
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) {
+  private var numElements = 0
+  private var array: Array[V] = _
+
+  // NB: This must be separate from the declaration, otherwise the specialized parent class
+  // will get its own array with the same initial size. TODO: Figure out why...
+  array = new Array[V](initialSize)
+
+  def apply(index: Int): V = {
+    require(index < numElements)
+    array(index)
+  }
+
+  def +=(value: V) {
+    if (numElements == array.length) { resize(array.length * 2) }
+    array(numElements) = value
+    numElements += 1
+  }
+
+  def length = numElements
+
+  def getUnderlyingArray = array
+
+  /** Resizes the array, dropping elements if the total length decreases. */
+  def resize(newLength: Int) {
+    val newArray = new Array[V](newLength)
+    array.copyToArray(newArray)
+    array = newArray
+  }
+}
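Finally, a usage sketch for PrimitiveVector: the backing array doubles when full, getUnderlyingArray exposes it unwrapped (so it may be longer than length), and resize can trim the slack:

val vec = new PrimitiveVector[Long](initialSize = 4)
for (i <- 1L to 5L) { vec += i }          // the 5th append doubles capacity 4 -> 8
assert(vec.length == 5 && vec(4) == 5L)
assert(vec.getUnderlyingArray.length == 8)
vec.resize(vec.length)                    // trim the spare slots
assert(vec.getUnderlyingArray.length == 5)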