129 files changed, 3674 insertions, 4009 deletions
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index c7819d4932..c16afd6b36 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -32,12 +32,26 @@ fi

# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
-if [ -f "$FWDIR/RELEASE" ]; then
-  ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+
+# First check if we have a dependencies jar. If so, include binary classes with the deps jar
+if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
+  CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+  CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
+  CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
+  CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
+  CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
+
+  DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
+  CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
-  ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+  # Else use spark-assembly jar from either RELEASE or assembly directory
+  if [ -f "$FWDIR/RELEASE" ]; then
+    ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+  else
+    ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+  fi
+  CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
-CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"

# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
diff --git a/bin/slaves.sh b/bin/slaves.sh
index 752565b759..c367c2fd8e 100755
--- a/bin/slaves.sh
+++ b/bin/slaves.sh
@@ -28,7 +28,7 @@
#   SPARK_SSH_OPTS  Options passed to ssh when running remote commands.
##

-usage="Usage: slaves.sh [--config confdir] command..."
+usage="Usage: slaves.sh [--config <conf-dir>] command..."

# if no args specified, show usage
if [ $# -le 0 ]; then
@@ -46,6 +46,23 @@ bin=`cd "$bin"; pwd`
# spark-env.sh. Save it here.
HOSTLIST=$SPARK_SLAVES

+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+if [ "$1" == "--config" ]
+then
+  shift
+  conf_dir=$1
+  if [ ! -d "$conf_dir" ]
+  then
+    echo "ERROR : $conf_dir is not a directory"
+    echo $usage
+    exit 1
+  else
+    export SPARK_CONF_DIR=$conf_dir
+  fi
+  shift
+fi
+
if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
  . "${SPARK_CONF_DIR}/spark-env.sh"
fi
diff --git a/bin/spark-daemon.sh b/bin/spark-daemon.sh
index 5bfe967fbf..a0c0d44b58 100755
--- a/bin/spark-daemon.sh
+++ b/bin/spark-daemon.sh
@@ -29,7 +29,7 @@
#   SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##

-usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
+usage="Usage: spark-daemon.sh [--config <conf-dir>] (start|stop) <spark-command> <spark-instance-number> <args...>"

# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -43,6 +43,25 @@ bin=`cd "$bin"; pwd`

. "$bin/spark-config.sh"

# get arguments
+
+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+
+if [ "$1" == "--config" ]
+then
+  shift
+  conf_dir=$1
+  if [ ! -d "$conf_dir" ]
+  then
+    echo "ERROR : $conf_dir is not a directory"
+    echo $usage
+    exit 1
+  else
+    export SPARK_CONF_DIR=$conf_dir
+  fi
+  shift
+fi
+
startStop=$1
shift
command=$1
diff --git a/bin/spark-daemons.sh b/bin/spark-daemons.sh
index 354eb905a1..64286cb2da 100755
--- a/bin/spark-daemons.sh
+++ b/bin/spark-daemons.sh
@@ -19,7 +19,7 @@

# Run a Spark command on all slave hosts.

-usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
+usage="Usage: spark-daemons.sh [--config <conf-dir>] [start|stop] command instance-number args..."

# if no args specified, show usage
if [ $# -le 1 ]; then
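With this change each of these launcher scripts accepts an optional configuration directory as its first argument, validates that it is a directory, and exports it as SPARK_CONF_DIR before spark-env.sh is sourced. An illustrative invocation (the path and daemon class here are only an example): bin/spark-daemon.sh --config /opt/spark/conf start org.apache.spark.deploy.master.Master 1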
-d "$conf_dir" ] + then + echo "ERROR : $conf_dir is not a directory" + echo $usage + exit 1 + else + export SPARK_CONF_DIR=$conf_dir + fi + shift +fi + startStop=$1 shift command=$1 diff --git a/bin/spark-daemons.sh b/bin/spark-daemons.sh index 354eb905a1..64286cb2da 100755 --- a/bin/spark-daemons.sh +++ b/bin/spark-daemons.sh @@ -19,7 +19,7 @@ # Run a Spark command on all slave hosts. -usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..." +usage="Usage: spark-daemons.sh [--config <conf-dir>] [start|stop] command instance-number args..." # if no args specified, show usage if [ $# -le 1 ]; then diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java index c4aa2669e0..8a09210245 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java @@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundByteHandlerAdapter; +import org.apache.spark.storage.BlockId; abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { @@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter { } public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header); - public abstract void handleError(String blockId); + public abstract void handleError(BlockId blockId); @Override public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) { diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java index d3d57a0255..172c6e4b1c 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java @@ -24,6 +24,8 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundMessageHandlerAdapter; import io.netty.channel.DefaultFileRegion; +import org.apache.spark.storage.BlockId; +import org.apache.spark.storage.FileSegment; class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { @@ -34,41 +36,36 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> { } @Override - public void messageReceived(ChannelHandlerContext ctx, String blockId) { - String path = pResolver.getAbsolutePath(blockId); - // if getFilePath returns null, close the channel - if (path == null) { + public void messageReceived(ChannelHandlerContext ctx, String blockIdString) { + BlockId blockId = BlockId.apply(blockIdString); + FileSegment fileSegment = pResolver.getBlockLocation(blockId); + // if getBlockLocation returns null, close the channel + if (fileSegment == null) { //ctx.close(); return; } - File file = new File(path); + File file = fileSegment.file(); if (file.exists()) { if (!file.isFile()) { - //logger.info("Not a file : " + file.getAbsolutePath()); ctx.write(new FileHeader(0, blockId).buffer()); ctx.flush(); return; } - long length = file.length(); + long length = fileSegment.length(); if (length > Integer.MAX_VALUE || length <= 0) { - //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length); ctx.write(new FileHeader(0, blockId).buffer()); ctx.flush(); return; } int len = new Long(length).intValue(); - //logger.info("Sending block "+blockId+" 
filelen = "+len); - //logger.info("header = "+ (new FileHeader(len, blockId)).buffer()); ctx.write((new FileHeader(len, blockId)).buffer()); try { ctx.sendFile(new DefaultFileRegion(new FileInputStream(file) - .getChannel(), 0, file.length())); + .getChannel(), fileSegment.offset(), fileSegment.length())); } catch (Exception e) { - //logger.warning("Exception when sending file : " + file.getAbsolutePath()); e.printStackTrace(); } } else { - //logger.warning("File not found: " + file.getAbsolutePath()); ctx.write(new FileHeader(0, blockId).buffer()); } ctx.flush(); diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java index 94c034cad0..9f7ced44cf 100755 --- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java +++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java @@ -17,13 +17,10 @@ package org.apache.spark.network.netty;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
public interface PathResolver {
- /**
- * Get the absolute path of the file
- *
- * @param fileId
- * @return the absolute path of file
- */
- public String getAbsolutePath(String fileId);
+ /** Get the file segment in which the given block resides. */
+ public FileSegment getBlockLocation(BlockId blockId);
}
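The PathResolver contract now returns a FileSegment (a file plus an offset and length) instead of a bare path, so a single on-disk file can hold many blocks and the Netty server can send just the relevant byte range. A minimal sketch of an implementation, assuming a FileSegment(file, offset, length) constructor and a BlockId.name member, neither of which is shown in this section:

    import java.io.File
    import org.apache.spark.network.netty.PathResolver
    import org.apache.spark.storage.{BlockId, FileSegment}

    // Toy resolver: every block lives in its own file, so each block's segment
    // spans that whole file. Returning null signals "not found", which
    // FileServerHandler above answers with a zero-length FileHeader.
    class WholeFileResolver(rootDir: String) extends PathResolver {
      override def getBlockLocation(blockId: BlockId): FileSegment = {
        val file = new File(rootDir, blockId.name)
        if (file.exists()) new FileSegment(file, 0, file.length()) else null
      }
    }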
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
index f87460039b..0c47afae54 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
@@ -17,20 +17,29 @@
package org.apache.hadoop.mapred

+private[apache]
trait SparkHadoopMapRedUtil {
  def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
-    val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
-    val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
+    val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
+      "org.apache.hadoop.mapred.JobContext")
+    val ctor = klass.getDeclaredConstructor(classOf[JobConf],
+      classOf[org.apache.hadoop.mapreduce.JobID])
    ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
  }

  def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
-    val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
+    val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
+      "org.apache.hadoop.mapred.TaskAttemptContext")
    val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
    ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
  }

-  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
+  def newTaskAttemptID(
+      jtIdentifier: String,
+      jobId: Int,
+      isMap: Boolean,
+      taskId: Int,
+      attemptId: Int) = {
    new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
  }
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
index 93180307fa..32429f01ac 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce

-import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}

+import org.apache.hadoop.conf.Configuration
+
+private[apache]
trait SparkHadoopMapReduceUtil {
  def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
    val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
    ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
  }

-  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
-    val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
+  def newTaskAttemptID(
+      jtIdentifier: String,
+      jobId: Int,
+      isMap: Boolean,
+      taskId: Int,
+      attemptId: Int) = {
+    val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
    try {
-      // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
+      // First, attempt to use the old-style constructor that takes a boolean isMap
+      // (not available in YARN)
      val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
-        classOf[Int], classOf[Int])
-      ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
-        JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+        classOf[Int], classOf[Int])
+      ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
+        new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
    } catch {
      case exc: NoSuchMethodException => {
-        // failed, look for the new ctor that takes a TaskType (not available in 1.x)
-        val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
-        val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
+        // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
+        val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+          .asInstanceOf[Class[Enum[_]]]
+        val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
+          taskTypeClass, if(isMap) "MAP" else "REDUCE")
        val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
          classOf[Int], classOf[Int])
-        ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
-          JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+        ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
+          new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
      }
    }
  }
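Both utilities depend on a firstAvailableClass helper whose body lies outside these hunks; presumably it is along these lines, loading the first class name that resolves so the same code runs against both the old and new Hadoop APIs:

    // Hypothetical reconstruction, not shown in this diff: try the newer
    // Hadoop class first and fall back to the older one.
    private def firstAvailableClass(first: String, second: String): Class[_] = {
      try {
        Class.forName(first)
      } catch {
        case e: ClassNotFoundException => Class.forName(second)
      }
    }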
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 908ff56a6b..d9ed572da6 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -22,13 +22,17 @@ import scala.collection.mutable.HashMap

import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator

private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {

-  override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+  override def fetch[T](
+      shuffleId: Int,
+      reduceId: Int,
+      context: TaskContext,
+      serializer: Serializer)
    : Iterator[T] =
  {
@@ -45,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }

-    val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+    val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
      case (address, splits) =>
-        (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+        (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
    }

-    def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+    def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
      val blockId = blockPair._1
      val blockOption = blockPair._2
      blockOption match {
@@ -58,9 +62,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
          block.asInstanceOf[Iterator[T]]
        }
        case None => {
-          val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
          blockId match {
-            case regex(shufId, mapId, _) =>
+            case ShuffleBlockId(shufId, mapId, _) =>
              val address = statuses(mapId.toInt)._1
              throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
            case _ =>
@@ -74,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
    val itr = blockFetcherItr.flatMap(unpackBlock)

-    CompletionIterator[T, Iterator[T]](itr, {
+    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
      val shuffleMetrics = new ShuffleReadMetrics
      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
      shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
      shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      metrics.shuffleReadMetrics = Some(shuffleMetrics)
+      context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
    })
+
+    new InterruptibleIterator[T](context, completionIter)
  }
}
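The typed BlockId hierarchy replaces the ad-hoc "shuffle_%d_%d_%d" string formatting and regex parsing seen in the deleted lines with case-class construction and pattern matching. A small sketch of the round trip, assuming BlockId.apply(String) (which FileServerHandler above already uses) parses the canonical name and BlockId.name produces it:

    import org.apache.spark.storage.{BlockId, ShuffleBlockId}

    val id: BlockId = ShuffleBlockId(2, 7, 3)
    println(id.name)                 // "shuffle_2_7_3", assuming BlockId.name

    BlockId("shuffle_2_7_3") match {
      case ShuffleBlockId(shuffle, map, reduce) =>
        println("shuffle=" + shuffle + " map=" + map + " reduce=" + reduce)
      case other =>
        println("not a shuffle block: " + other)
    }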
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 4cf7eb96da..519ecde50a 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -18,7 +18,7 @@ package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, HashSet}
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
import org.apache.spark.rdd.RDD

@@ -28,17 +28,17 @@ import org.apache.spark.rdd.RDD
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {

  /** Keys of RDD splits that are being computed/loaded. */
-  private val loading = new HashSet[String]()
+  private val loading = new HashSet[RDDBlockId]()

  /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
  def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
      : Iterator[T] = {
-    val key = "rdd_%d_%d".format(rdd.id, split.index)
+    val key = RDDBlockId(rdd.id, split.index)
    logDebug("Looking for partition " + key)
    blockManager.get(key) match {
      case Some(values) =>
        // Partition is already materialized, so just return its values
-        return values.asInstanceOf[Iterator[T]]
+        return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])

      case None =>
        // Mark the split as loading (unless someone else marks it first)
@@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
          // downside of the current code is that threads wait serially if this does happen.
          blockManager.get(key) match {
            case Some(values) =>
-              return values.asInstanceOf[Iterator[T]]
+              return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
            case None =>
              logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
              loading.add(key)
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
        if (context.runningLocally) { return computedValues }
        val elements = new ArrayBuffer[Any]
        elements ++= computedValues
-        blockManager.put(key, elements, storageLevel, true)
+        blockManager.put(key, elements, storageLevel, tellMaster = true)
        return elements.iterator.asInstanceOf[Iterator[T]]
      } finally {
        loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
new file mode 100644
index 0000000000..1ad9240cfa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobFailed
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * A future for the result of an action. This is an extension of the Scala Future interface to
+ * support cancellation.
+ */
+trait FutureAction[T] extends Future[T] {
+  // Note that we redefine methods of the Future trait here explicitly so we can specify a different
+  // documentation (with reference to the word "action").
+
+  /**
+   * Cancels the execution of this action.
+   */
+  def cancel()
+
+  /**
+   * Blocks until this action completes.
+   * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+   *               for unbounded waiting, or a finite positive duration
+   * @return this FutureAction
+   */
+  override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
+
+  /**
+   * Awaits and returns the result (of type T) of this action.
+   * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+   *               for unbounded waiting, or a finite positive duration
+   * @throws Exception exception during action execution
+   * @return the result value if the action is completed within the specific maximum wait time
+   */
+  @throws(classOf[Exception])
+  override def result(atMost: Duration)(implicit permit: CanAwait): T
+
+  /**
+   * When this action is completed, either through an exception, or a value, applies the provided
+   * function.
+   */
+  def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+
+  /**
+   * Returns whether the action has already been completed with a value or an exception.
+   */
+  override def isCompleted: Boolean
+
+  /**
+   * The value of this Future.
+   *
+   * If the future is not completed the returned value will be None. If the future is completed
+   * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
+   * it contains an exception.
+   */
+  override def value: Option[Try[T]]
+
+  /**
+   * Blocks and returns the result of this job.
+   */
+  @throws(classOf[Exception])
+  def get(): T = Await.result(this, Duration.Inf)
+}
+
+
+/**
+ * The future holding the result of an action that triggers a single job. Examples include
+ * count, collect, reduce.
+ */
+class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+  extends FutureAction[T] {
+
+  override def cancel() {
+    jobWaiter.cancel()
+  }
+
+  override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
+    if (!atMost.isFinite()) {
+      awaitResult()
+    } else {
+      val finishTime = System.currentTimeMillis() + atMost.toMillis
+      while (!isCompleted) {
+        val time = System.currentTimeMillis()
+        if (time >= finishTime) {
+          throw new TimeoutException
+        } else {
+          jobWaiter.wait(finishTime - time)
+        }
+      }
+    }
+    this
+  }
+
+  @throws(classOf[Exception])
+  override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+    ready(atMost)(permit)
+    awaitResult() match {
+      case scala.util.Success(res) => res
+      case scala.util.Failure(e) => throw e
+    }
+  }
+
+  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+    executor.execute(new Runnable {
+      override def run() {
+        func(awaitResult())
+      }
+    })
+  }
+
+  override def isCompleted: Boolean = jobWaiter.jobFinished
+
+  override def value: Option[Try[T]] = {
+    if (jobWaiter.jobFinished) {
+      Some(awaitResult())
+    } else {
+      None
+    }
+  }
+
+  private def awaitResult(): Try[T] = {
+    jobWaiter.awaitResult() match {
+      case JobSucceeded => scala.util.Success(resultFunc)
+      case JobFailed(e: Exception, _) => scala.util.Failure(e)
+    }
+  }
+}
+
+
+/**
+ * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
+ * action thread if it is being blocked by a job.
+ */
+class ComplexFutureAction[T] extends FutureAction[T] {
+
+  // Pointer to the thread that is executing the action. It is set when the action is run.
+  @volatile private var thread: Thread = _
+
+  // A flag indicating whether the future has been cancelled. This is used in case the future
+  // is cancelled before the action was even run (and thus we have no thread to interrupt).
+  @volatile private var _cancelled: Boolean = false
+
+  // A promise used to signal the future.
+  private val p = promise[T]()
+
+  override def cancel(): Unit = this.synchronized {
+    _cancelled = true
+    if (thread != null) {
+      thread.interrupt()
+    }
+  }
+
+  /**
+   * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+   * should use runJob implementation in this promise. See takeAsync for example.
+   */
+  def run(func: => T)(implicit executor: ExecutionContext): this.type = {
+    scala.concurrent.future {
+      thread = Thread.currentThread
+      try {
+        p.success(func)
+      } catch {
+        case e: Exception => p.failure(e)
+      } finally {
+        thread = null
+      }
+    }
+    this
+  }
+
+  /**
+   * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
+   * to enable cancellation.
+   */
+  def runJob[T, U, R](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      partitions: Seq[Int],
+      resultHandler: (Int, U) => Unit,
+      resultFunc: => R) {
+    // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+    // command need to be in an atomic block.
+    val job = this.synchronized {
+      if (!cancelled) {
+        rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+      } else {
+        throw new SparkException("Action has been cancelled")
+      }
+    }
+
+    // Wait for the job to complete. If the action is cancelled (with an interrupt),
+    // cancel the job and stop the execution. This is not in a synchronized block because
+    // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
+    try {
+      Await.ready(job, Duration.Inf)
+    } catch {
+      case e: InterruptedException =>
+        job.cancel()
+        throw new SparkException("Action has been cancelled")
+    }
+  }
+
+  /**
+   * Returns whether the promise has been cancelled.
+   */
+  def cancelled: Boolean = _cancelled
+
+  @throws(classOf[InterruptedException])
+  @throws(classOf[scala.concurrent.TimeoutException])
+  override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+    p.future.ready(atMost)(permit)
+    this
+  }
+
+  @throws(classOf[Exception])
+  override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+    p.future.result(atMost)(permit)
+  }
+
+  override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+    p.future.onComplete(func)(executor)
+  }
+
+  override def isCompleted: Boolean = p.isCompleted
+
+  override def value: Option[Try[T]] = p.future.value
+}
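A sketch of how these futures surface to applications, using the submitJob method added to SparkContext later in this diff (sc is assumed to be an existing SparkContext; the per-partition sum is only an example):

    val rdd = sc.parallelize(1 to 1000, 10)

    val future = sc.submitJob(
      rdd,
      (iter: Iterator[Int]) => iter.sum,   // runs on each partition
      0 until rdd.partitions.size,         // which partitions to run
      (index: Int, sum: Int) => println("partition " + index + " = " + sum),
      ())                                  // resultFunc: the overall result (Unit here)

    future.get()       // block until the job finishes...
    // future.cancel() // ...or cancel it from another thread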
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
new file mode 100644
index 0000000000..56e0b8d2c0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * An iterator that wraps around an existing iterator to provide task killing functionality.
+ * It works by checking the interrupted flag in TaskContext.
+ */
+class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+  extends Iterator[T] {
+
+  def hasNext: Boolean = !context.interrupted && delegate.hasNext
+
+  def next(): T = delegate.next()
+}
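A sketch of the intended behavior: once the task's interrupted flag is set (as the executor does when a task is killed), the wrapper reports exhaustion and downstream consumption stops at the next hasNext check. Constructing a TaskContext directly, as below, is for illustration only:

    val context = new TaskContext(stageId = 0, partitionId = 0, attemptId = 0)
    val iter = new InterruptibleIterator(context, Iterator(1, 2, 3))

    iter.next()                 // 1
    context.interrupted = true  // what happens when the task is killed
    iter.hasNext                // false: consumption is cut short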
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index 307c383a89..a85aa50a9b 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
   * Fetch the shuffle outputs for a given ShuffleDependency.
   * @return An iterator over the elements of the fetched shuffle outputs.
   */
-  def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+  def fetch[T](
+      shuffleId: Int,
+      reduceId: Int,
+      context: TaskContext,
      serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]

  /** Stop the fetcher */
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b884ae5879..564466cfd5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -51,25 +51,20 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor

import org.apache.mesos.MesosNativeLibrary

-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.LocalSparkCluster
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
-  ClusterScheduler}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+  SparkDeploySchedulerBackend, ClusterScheduler}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
+import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util._
-import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
-import scala.Some
-import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
+  TimeStampedHashMap, Utils}
+
+

/**
 * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -125,7 +120,7 @@ class SparkContext(
  private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
  private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)

-  // Initalize the Spark UI
+  // Initialize the Spark UI
  private[spark] val ui = new SparkUI(this)
  ui.bind()

@@ -161,8 +156,8 @@ class SparkContext(
    val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
    // Regular expression for connecting to Spark deploy clusters
    val SPARK_REGEX = """spark://(.*)""".r
-    //Regular expression for connection to Mesos cluster
-    val MESOS_REGEX = """(mesos://.*)""".r
+    // Regular expression for connection to Mesos cluster
+    val MESOS_REGEX = """mesos://(.*)""".r

    master match {
      case "local" =>
@@ -213,25 +208,24 @@ class SparkContext(
            throw new SparkException("YARN mode not available ?", th)
          }
        }
-        val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+        val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
        scheduler.initialize(backend)
        scheduler

-      case _ =>
-        if (MESOS_REGEX.findFirstIn(master).isEmpty) {
-          logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
-        }
+      case MESOS_REGEX(mesosUrl) =>
        MesosNativeLibrary.load()
        val scheduler = new ClusterScheduler(this)
        val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
-        val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
        val backend = if (coarseGrained) {
-          new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+          new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
        } else {
-          new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+          new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
        }
        scheduler.initialize(backend)
        scheduler
+
+      case _ =>
+        throw new SparkException("Could not parse Master URL: '" + master + "'")
    }
  }
  taskScheduler.start()
@@ -288,15 +282,46 @@ class SparkContext(
    Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)

  /** Set a human readable description of the current job. */
+  @deprecated("use setJobGroup", "0.8.1")
  def setJobDescription(value: String) {
-    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+    setJobGroup("", value)
+  }
+
+  /**
+   * Assigns a group id to all the jobs started by this thread until the group id is set to a
+   * different value or cleared.
+   *
+   * Often, a unit of execution in an application consists of multiple Spark actions or jobs.
+   * Application programmers can use this method to group all those jobs together and give a
+   * group description. Once set, the Spark web UI will associate such jobs with this group.
+   *
+   * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
+   * running jobs in this group. For example,
+   * {{{
+   * // In the main thread:
+   * sc.setJobGroup("some_job_to_cancel", "some job description")
+   * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+   *
+   * // In a separate thread:
+   * sc.cancelJobGroup("some_job_to_cancel")
+   * }}}
+   */
+  def setJobGroup(groupId: String, description: String) {
+    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+    setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+  }
+
+  /** Clear the job group id and its description. */
+  def clearJobGroup() {
+    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+    setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
  }

  // Post init
  taskScheduler.postStartHook()

-  val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
-  val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+  private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
+  private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)

  def initDriverMetrics() {
    SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -335,7 +360,7 @@ class SparkContext(
  }

  /**
-   * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+   * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
   * other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
   * etc).
   */
@@ -361,24 +386,15 @@ class SparkContext(
    ): RDD[(K, V)] = {
    // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
    val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
-    hadoopFile(path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
-  }
-
-  /**
-   * Get an RDD for a Hadoop file with an arbitray InputFormat. Accept a Hadoop Configuration
-   * that has already been broadcast, assuming that it's safe to use it to construct a
-   * HadoopFileRDD (i.e., except for file 'path', all other configuration properties can be resued).
-   */
-  def hadoopFile[K, V](
-      path: String,
-      confBroadcast: Broadcast[SerializableWritable[Configuration]],
-      inputFormatClass: Class[_ <: InputFormat[K, V]],
-      keyClass: Class[K],
-      valueClass: Class[V],
-      minSplits: Int
-    ): RDD[(K, V)] = {
-    new HadoopFileRDD(
-      this, path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
+    val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
+    new HadoopRDD(
+      this,
+      confBroadcast,
+      Some(setInputPathsFunc),
+      inputFormatClass,
+      keyClass,
+      valueClass,
+      minSplits)
  }

  /**
@@ -764,10 +780,11 @@ class SparkContext(
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit) {
    val callSite = Utils.formatSparkCallSite
+    val cleanedFunc = clean(func)
    logInfo("Starting job: " + callSite)
    val start = System.nanoTime
-    val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
-      localProperties.get)
+    val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+      resultHandler, localProperties.get)
    logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
    rdd.doCheckpoint()
    result
@@ -857,6 +874,42 @@ class SparkContext(
  }

  /**
+   * Submit a job for execution and return a FutureJob holding the result.
+   */
+  def submitJob[T, U, R](
+      rdd: RDD[T],
+      processPartition: Iterator[T] => U,
+      partitions: Seq[Int],
+      resultHandler: (Int, U) => Unit,
+      resultFunc: => R): SimpleFutureAction[R] =
+  {
+    val cleanF = clean(processPartition)
+    val callSite = Utils.formatSparkCallSite
+    val waiter = dagScheduler.submitJob(
+      rdd,
+      (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
+      partitions,
+      callSite,
+      allowLocal = false,
+      resultHandler,
+      localProperties.get)
+    new SimpleFutureAction(waiter, resultFunc)
+  }
+
+  /**
+   * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
+   * for more information.
+   */
+  def cancelJobGroup(groupId: String) {
+    dagScheduler.cancelJobGroup(groupId)
+  }
+
+  /** Cancel all jobs that have been scheduled or are running. */
+  def cancelAllJobs() {
+    dagScheduler.cancelAllJobs()
+  }
+
+  /**
   * Clean a closure to make it ready to serialized and send to tasks
   * (removes unreferenced variables in $outer's, updates REPL variables)
   */
@@ -912,7 +965,10 @@ class SparkContext(
 * various Spark features.
 */
object SparkContext {
-  val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+  private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+  private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"

  implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
    def addInPlace(t1: Double, t2: Double): Double = t1 + t2
@@ -939,6 +995,8 @@ object SparkContext {
  implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
    new PairRDDFunctions(rdd)

+  implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+
  implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
      rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd)
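The rddToAsyncRDDActions implicit above is what lets applications start jobs without blocking the calling thread. Assuming AsyncRDDActions exposes a countAsync action (the class itself is not shown in this section), usage would look roughly like:

    import org.apache.spark.SparkContext._

    val future = sc.parallelize(1 to 100000, 4).countAsync()
    // ... from any thread:
    future.cancel()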
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2bab9d6e3d..103a1c2051 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -17,14 +17,14 @@

package org.apache.hadoop.mapred

-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
+import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
-import java.io.IOException
import java.util.Date

+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable

@@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable
 * Saves the RDD using a JobConf, which should contain an output key class, an output value class,
 * a filename to write to, etc, exactly like in a Hadoop MapReduce job.
 */
-class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
+private[apache]
+class SparkHadoopWriter(@transient jobConf: JobConf)
+  extends Logging
+  with SparkHadoopMapRedUtil
+  with Serializable {

  private val now = new Date()
  private val conf = new SerializableWritable(jobConf)
@@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
    }

    getOutputCommitter().setupTask(getTaskContext())
-    writer = getOutputFormat().getRecordWriter(
-        fs, conf.value, outputName, Reporter.NULL)
+    writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
  }

  def write(key: AnyRef, value: AnyRef) {
-    if (writer!=null) {
-      //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
+    if (writer != null) {
      writer.write(key, value)
    } else {
      throw new IOException("Writer is null, open() has not been called")
@@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
  }
}

+private[apache]
object SparkHadoopWriter {
  def createJobID(time: Date, id: Int): JobID = {
    val formatter = new SimpleDateFormat("yyyyMMddHHmm")
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index c2c358c7ad..cae983ed4c 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -17,21 +17,30 @@

package org.apache.spark

-import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer

+import org.apache.spark.executor.TaskMetrics
+
class TaskContext(
  val stageId: Int,
-  val splitId: Int,
+  val partitionId: Int,
  val attemptId: Long,
  val runningLocally: Boolean = false,
-  val taskMetrics: TaskMetrics = TaskMetrics.empty()
+  @volatile var interrupted: Boolean = false,
+  private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {

-  @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+  @deprecated("use partitionId", "0.8.1")
+  def splitId = partitionId
+
+  // List of callback functions to execute when the task completes.
+  @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]

-  // Add a callback function to be executed on task completion. An example use
-  // is for HadoopRDD to register a callback to close the input stream.
+  /**
+   * Add a callback function to be executed on task completion. An example use
+   * is for HadoopRDD to register a callback to close the input stream.
+   * @param f Callback function.
+   */
  def addOnCompleteCallback(f: () => Unit) {
    onCompleteCallbacks += f
  }
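A sketch of the callback hook in use, mirroring the HadoopRDD pattern mentioned in the comment (the file path and reader are illustrative):

    def readLines(context: TaskContext): Iterator[String] = {
      val reader = new java.io.BufferedReader(new java.io.FileReader("/tmp/input.txt"))
      // Close the stream when the task completes.
      context.addOnCompleteCallback(() => reader.close())
      Iterator.continually(reader.readLine()).takeWhile(_ != null)
    }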
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8466c2a004..c1e5e04b31 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure(
 */
private[spark] case object TaskResultLost extends TaskEndReason

+private[spark] case object TaskKilled extends TaskEndReason
+
private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 1f8ad688a6..12b4d94a56 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
 * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
 * collects a list of pickled strings that we pass to Python through a socket.
 */
-class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
  extends AccumulatorParam[JList[Array[Byte]]] {

  Utils.checkHost(serverHost, "Expected hostname")
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
deleted file mode 100644
index f82dea9f3a..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1058 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.{BitSet, Comparator, Timer, TimerTask, UUID} -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import org.apache.spark._ -import org.apache.spark.storage.{BlockManager, StorageLevel} -import org.apache.spark.util.Utils - -private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) - with Logging - with Serializable { - - def value = value_ - - def blockId: String = BlockManager.toBroadcastId(id) - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var hasBlocksBitVector: BitSet = null - @transient var numCopiesSent: Array[Int] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = new AtomicInteger(0) - - // Used ONLY by driver to track how many unique blocks have been sent out - @transient var sentBlocks = new AtomicInteger(0) - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - - // Used only in driver - @transient var guideMR: GuideMultipleRequests = null - - // Used only in Workers - @transient var ttGuide: TalkToGuide = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks.set(variableInfo.totalBlocks) - - // Guide has all the blocks - hasBlocksBitVector = new BitSet(totalBlocks) - hasBlocksBitVector.set(0, totalBlocks) - - // Guide still hasn't sent any block - numCopiesSent = new Array[Int](totalBlocks) - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val driverSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - hasBlocksBitVector.synchronized { - driverSource.hasBlocksBitVector = hasBlocksBitVector - } - - // In the beginning, this is the only known source to Guide - listOfSources += driverSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - 
MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - // Start local ServeMultipleRequests thread first - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - // Initialize variables in the worker node. Driver sends everything as 0/null - private def initializeWorkerVariables() { - arrayOfBlocks = null - hasBlocksBitVector = null - numCopiesSent = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = new AtomicInteger(0) - - listenPortLock = new Object - totalBlocksLock = new Object - - serveMR = null - ttGuide = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - listOfSources = ListBuffer[SourceInfo]() - - stopBroadcast = false - } - - private def getLocalSourceInfo: SourceInfo = { - // Wait till hostName and listenPort are OK - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Wait till totalBlocks and totalBytes are OK - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - var localSourceInfo = SourceInfo( - hostAddress, listenPort, totalBlocks, totalBytes) - - localSourceInfo.hasBlocks = hasBlocks.get - - hasBlocksBitVector.synchronized { - localSourceInfo.hasBlocksBitVector = hasBlocksBitVector - } - - return localSourceInfo - } - - // Add new SourceInfo to the listOfSources. Update if it exists already. 
- // Optimizing just by OR-ing the BitVectors was BAD for performance - private def addToListOfSources(newSourceInfo: SourceInfo) { - listOfSources.synchronized { - if (listOfSources.contains(newSourceInfo)) { - listOfSources = listOfSources - newSourceInfo - } - listOfSources += newSourceInfo - } - } - - private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) { - newSourceInfos.foreach { newSourceInfo => - addToListOfSources(newSourceInfo) - } - } - - class TalkToGuide(gInfo: SourceInfo) - extends Thread with Logging { - override def run() { - - // Keep exchaning information until all blocks have been received - while (hasBlocks.get < totalBlocks) { - talkOnce - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - } - - // Talk one more time to let the Guide know of reception completion - talkOnce - } - - // Connect to Guide and send this worker's information - private def talkOnce { - var clientSocketToGuide: Socket = null - var oosGuide: ObjectOutputStream = null - var oisGuide: ObjectInputStream = null - - clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort) - oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream) - oosGuide.flush() - oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream) - - // Send local information - oosGuide.writeObject(getLocalSourceInfo) - oosGuide.flush() - - // Receive source information from Guide - var suitableSources = - oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]] - logDebug("Received suitableSources from Driver " + suitableSources) - - addToListOfSources(suitableSources) - - oisGuide.close() - oosGuide.close() - clientSocketToGuide.close() - } - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Setup initial states of variables - totalBlocks = gInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - hasBlocksBitVector = new BitSet(totalBlocks) - numCopiesSent = new Array[Int](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = gInfo.totalBytes - - // Start ttGuide to periodically talk to the Guide - var ttGuide = new TalkToGuide(gInfo) - ttGuide.setDaemon(true) - ttGuide.start() - logInfo("TalkToGuide started...") - - // Start pController to run TalkToPeer threads - var pcController = new PeerChatterController - pcController.setDaemon(true) - pcController.start() - logInfo("PeerChatterController started...") - - // FIXME: Must fix this. This might never break if broadcast fails. - // We should be able to break and send false. Also need to kill threads - while (hasBlocks.get < totalBlocks) { - Thread.sleep(MultiTracker.MaxKnockInterval) - } - - return true - } - - class PeerChatterController - extends Thread with Logging { - private var peersNowTalking = ListBuffer[SourceInfo]() - // TODO: There is a possible bug with blocksInRequestBitVector when a - // certain bit is NOT unset upon failure resulting in an infinite loop. 
- private var blocksInRequestBitVector = new BitSet(totalBlocks) - - override def run() { - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - while (hasBlocks.get < totalBlocks) { - var numThreadsToCreate = 0 - listOfSources.synchronized { - numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) - - threadPool.getActiveCount - } - - while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) { - var peerToTalkTo = pickPeerToTalkToRandom - - if (peerToTalkTo != null) - logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector) - else - logDebug("No peer chosen...") - - if (peerToTalkTo != null) { - threadPool.execute(new TalkToPeer(peerToTalkTo)) - - // Add to peersNowTalking. Remove in the thread. We have to do this - // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once - peersNowTalking.synchronized { peersNowTalking += peerToTalkTo } - } - - numThreadsToCreate = numThreadsToCreate - 1 - } - - // Sleep for a while before starting some more threads - Thread.sleep(MultiTracker.MinKnockInterval) - } - // Shutdown the thread pool - threadPool.shutdown() - } - - // Right now picking the one that has the most blocks this peer wants - // Also picking peer randomly if no one has anything interesting - private def pickPeerToTalkToRandom: SourceInfo = { - var curPeer: SourceInfo = null - var curMax = 0 - - logDebug("Picking peers to talk to...") - - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Select the peer that has the most blocks that this receiver does not - peersNotInUse.foreach { eachSource => - var tempHasBlocksBitVector: BitSet = null - hasBlocksBitVector.synchronized { - tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size) - tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector) - - if (tempHasBlocksBitVector.cardinality > curMax) { - curPeer = eachSource - curMax = tempHasBlocksBitVector.cardinality - } - } - - // Always picking randomly - if (curPeer == null && peersNotInUse.size > 0) { - // Pick uniformly the i'th required peer - var i = MultiTracker.ranGen.nextInt(peersNotInUse.size) - - var peerIter = peersNotInUse.iterator - curPeer = peerIter.next - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - } - - return curPeer - } - - // Picking peer with the weight of rare blocks it has - private def pickPeerToTalkToRarestFirst: SourceInfo = { - // Find peers that are not connected right now - var peersNotInUse = ListBuffer[SourceInfo]() - listOfSources.synchronized { - peersNowTalking.synchronized { - peersNotInUse = listOfSources -- peersNowTalking - } - } - - // Count the number of copies of each block in the neighborhood - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // A block is considered rare if there are at most 2 copies of that block - // This CONSTANT could be a function of the neighborhood size - var rareBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) { - rareBlocksIndices += i - } - } - - // Find peers 
with rare blocks - var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]() - var totalRareBlocks = 0 - - peersNotInUse.foreach { eachPeer => - var hasRareBlocks = 0 - rareBlocksIndices.foreach { rareBlock => - if (eachPeer.hasBlocksBitVector.get(rareBlock)) { - hasRareBlocks += 1 - } - } - - if (hasRareBlocks > 0) { - peersWithRareBlocks += ((eachPeer, hasRareBlocks)) - } - totalRareBlocks += hasRareBlocks - } - - // Select a peer from peersWithRareBlocks based on weight calculated from - // unique rare blocks - var selectedPeerToTalkTo: SourceInfo = null - - if (peersWithRareBlocks.size > 0) { - // Sort the peers based on how many rare blocks they have - peersWithRareBlocks.sortBy(_._2) - - var randomNumber = MultiTracker.ranGen.nextDouble - var tempSum = 0.0 - - var i = 0 - do { - tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks) - if (tempSum >= randomNumber) { - selectedPeerToTalkTo = peersWithRareBlocks(i)._1 - } - i += 1 - } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null) - } - - if (selectedPeerToTalkTo == null) { - selectedPeerToTalkTo = pickPeerToTalkToRandom - } - - return selectedPeerToTalkTo - } - - class TalkToPeer(peerToTalkTo: SourceInfo) - extends Thread with Logging { - private var peerSocketToSource: Socket = null - private var oosSource: ObjectOutputStream = null - private var oisSource: ObjectInputStream = null - - override def run() { - // TODO: There is a possible bug here regarding blocksInRequestBitVector - var blockToAskFor = -1 - - // Setup the timeout mechanism - var timeOutTask = new TimerTask { - override def run() { - cleanUpConnections() - } - } - - var timeOutTimer = new Timer - timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval) - - logInfo("TalkToPeer started... => " + peerToTalkTo) - - try { - // Connect to the source - peerSocketToSource = - new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort) - oosSource = - new ObjectOutputStream(peerSocketToSource.getOutputStream) - oosSource.flush() - oisSource = - new ObjectInputStream(peerSocketToSource.getInputStream) - - // Receive latest SourceInfo from peerToTalkTo - var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo] - // Update listOfSources - addToListOfSources(newPeerToTalkTo) - - // Turn the timer OFF, if the sender responds before timeout - timeOutTimer.cancel() - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - - var keepReceiving = true - - while (hasBlocks.get < totalBlocks && keepReceiving) { - blockToAskFor = - pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector) - - // No block to request - if (blockToAskFor < 0) { - // Nothing to receive from newPeerToTalkTo - keepReceiving = false - } else { - // Let other threads know that blockToAskFor is being requested - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor) - } - - // Start with sending the blockID - oosSource.writeObject(blockToAskFor) - oosSource.flush() - - // CHANGED: Driver might send some other block than the one - // requested to ensure fast spreading of all blocks. 
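
Two notes on the deleted rarest-first picker above. Its `peersWithRareBlocks.sortBy(_._2)` discards the sorted copy, since `sortBy` returns a new collection rather than sorting in place, so that list was never actually ordered; the roulette-wheel weighting that follows it is correct without sorting anyway. Below is a minimal standalone sketch of that weighted selection, with illustrative names that are not part of this patch:

```scala
import scala.util.Random

object WeightedPick {
  // `weighted` pairs a peer with its rare-block count. Traversal order does
  // not affect correctness, so no sorting is needed before selection.
  def pickWeighted[A](weighted: Seq[(A, Int)], rand: Random = new Random): Option[A] = {
    val total = weighted.iterator.map(_._2).sum
    if (total <= 0) None
    else {
      val r = rand.nextDouble() * total // uniform in [0, total)
      var acc = 0.0
      // Return the first element whose cumulative weight exceeds r.
      weighted.find { case (_, w) => acc += w; acc > r }.map(_._1)
    }
  }
}
```

Given `Seq(("a", 3), ("b", 1))`, this picks "a" with probability 3/4, matching the `tempSum`-based loop above.
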
- val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.") - - if (!hasBlocksBitVector.get(bcBlock.blockID)) { - arrayOfBlocks(bcBlock.blockID) = bcBlock - - // Update the hasBlocksBitVector first - hasBlocksBitVector.synchronized { - hasBlocksBitVector.set(bcBlock.blockID) - hasBlocks.getAndIncrement - } - - // Some block(may NOT be blockToAskFor) has arrived. - // In any case, blockToAskFor is not in request any more - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - - // Reset blockToAskFor to -1. Else it will be considered missing - blockToAskFor = -1 - } - - // Send the latest SourceInfo - oosSource.writeObject(getLocalSourceInfo) - oosSource.flush() - } - } - } catch { - // EOFException is expected to happen because sender can break - // connection due to timeout - case eofe: java.io.EOFException => { } - case e: Exception => { - logError("TalktoPeer had a " + e) - // FIXME: Remove 'newPeerToTalkTo' from listOfSources - // We probably should have the following in some form, but not - // really here. This exception can happen if the sender just breaks connection - // listOfSources.synchronized { - // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo) - // listOfSources = listOfSources - peerToTalkTo - // } - } - } finally { - // blockToAskFor != -1 => there was an exception - if (blockToAskFor != -1) { - blocksInRequestBitVector.synchronized { - blocksInRequestBitVector.set(blockToAskFor, false) - } - } - - cleanUpConnections() - } - } - - // Right now it picks a block uniformly that this peer does not have - private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Pick uniformly the i'th required block - var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality) - var pickedBlockIndex = needBlocksBitVector.nextSetBit(0) - - while (i > 0) { - pickedBlockIndex = - needBlocksBitVector.nextSetBit(pickedBlockIndex + 1) - i -= 1 - } - - return pickedBlockIndex - } - } - - // Pick the block that seems to be the rarest across sources - private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = { - var needBlocksBitVector: BitSet = null - - // Blocks already present - hasBlocksBitVector.synchronized { - needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet] - } - - // Include blocks already in transmission ONLY IF - // MultiTracker.EndGameFraction has NOT been achieved - if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) { - blocksInRequestBitVector.synchronized { - 
needBlocksBitVector.or(blocksInRequestBitVector) - } - } - - // Find blocks that are neither here nor in transit - needBlocksBitVector.flip(0, needBlocksBitVector.size) - - // Blocks that should/can be requested - needBlocksBitVector.and(txHasBlocksBitVector) - - if (needBlocksBitVector.cardinality == 0) { - return -1 - } else { - // Count the number of copies for each block across all sources - var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0) - - listOfSources.synchronized { - listOfSources.foreach { eachSource => - for (i <- 0 until totalBlocks) { - numCopiesPerBlock(i) += - ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 ) - } - } - } - - // Find the minimum - var minVal = Integer.MAX_VALUE - for (i <- 0 until totalBlocks) { - if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) { - minVal = numCopiesPerBlock(i) - } - } - - // Find the blocks with the least copies that this peer does not have - var minBlocksIndices = ListBuffer[Int]() - for (i <- 0 until totalBlocks) { - if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) { - minBlocksIndices += i - } - } - - // Now select a random index from minBlocksIndices - if (minBlocksIndices.size == 0) { - return -1 - } else { - // Pick uniformly the i'th index - var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size) - return minBlocksIndices(i) - } - } - } - - private def cleanUpConnections() { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (peerSocketToSource != null) { - peerSocketToSource.close() - } - - // Delete from peersNowTalking - peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo } - } - } - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. 
stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close the socket here; else, thread will close it - case ioe: IOException => { - clientSocket.close() - } - } - } - } - - // Shutdown the thread pool - threadPool.shutdown() - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - listOfSources.foreach { sourceInfo => - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Throw away whatever comes in - gisSource.readObject.asInstanceOf[SourceInfo] - - // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast - gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sourceInfo: SourceInfo = null - private var selectedSources: ListBuffer[SourceInfo] = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its information - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Select a suitable source and send it back to the worker - selectedSources = selectSuitableSources(sourceInfo) - logDebug("Sending selectedSources:" + selectedSources) - oos.writeObject(selectedSources) - oos.flush() - - // Add this source to the listOfSources - addToListOfSources(sourceInfo) - } catch { - case e: Exception => { - // Assuming exception caused by receiver failure: remove - if (listOfSources != null) { - listOfSources.synchronized { listOfSources -= sourceInfo } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Randomly select some sources to send back - private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = { - var selectedSources = ListBuffer[SourceInfo]() - - // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true' - // then add skipSourceInfo to setOfCompletedSources. Return blank. 
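
For the guide-side peer selection that continues below: `selectSuitableSources` samples up to `MaxPeersInGuideResponse` distinct sources by rejection-sampling indices into a `BitSet` and then walking the list iterator once per pick. Here is a sketch of an equivalent selection using a single shuffle, assuming (as the guide response's consumers do) that peer order is irrelevant; the names are illustrative:

```scala
import scala.util.Random

object GuideSampling {
  // Sample up to k distinct sources, excluding the requester itself.
  def sampleSources[A](sources: Seq[A], k: Int, skip: A): Seq[A] = {
    val candidates = sources.filterNot(_ == skip)
    if (candidates.size <= k) candidates
    else Random.shuffle(candidates).take(k)
  }
}
```
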
- if (skipSourceInfo.hasBlocks == totalBlocks) { - setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo } - return selectedSources - } - - listOfSources.synchronized { - if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) { - selectedSources = listOfSources.clone - } else { - var picksLeft = MultiTracker.MaxPeersInGuideResponse - var alreadyPicked = new BitSet(listOfSources.size) - - while (picksLeft > 0) { - var i = -1 - - do { - i = MultiTracker.ranGen.nextInt(listOfSources.size) - } while (alreadyPicked.get(i)) - - var peerIter = listOfSources.iterator - var curPeer = peerIter.next - - // Set the BitSet before i is decremented - alreadyPicked.set(i) - - while (i > 0) { - curPeer = peerIter.next - i = i - 1 - } - - selectedSources += curPeer - - picksLeft = picksLeft - 1 - } - } - } - - // Remove the receiving source (if present) - selectedSources = selectedSources - skipSourceInfo - - return selectedSources - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - // Server at most MultiTracker.MaxChatSlots peers - var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots) - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { } - } - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection:" + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - logInfo("new ServeSingleRequest is running") - - override def run() { - try { - // Send latest local SourceInfo to the receiver - // In the case of receiver timeout and connection close, this will - // throw a java.net.SocketException: Broken pipe - oos.writeObject(getLocalSourceInfo) - oos.flush() - - // Receive latest SourceInfo from the receiver - var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - addToListOfSources(rxSourceInfo) - } - - val startTime = System.currentTimeMillis - var curTime = startTime - var keepSending = true - var numBlocksToSend = MultiTracker.MaxChatBlocks - - while (!stopBroadcast && keepSending && numBlocksToSend > 0) { - // Receive which block to send - var blockToSend = ois.readObject.asInstanceOf[Int] - - // If it is driver AND at least one copy of each block has not been - // sent out already, MODIFY blockToSend - if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) { - blockToSend = sentBlocks.getAndIncrement - } - - // Send the block - sendBlock(blockToSend) - rxSourceInfo.hasBlocksBitVector.set(blockToSend) - - numBlocksToSend -= 
1 - - // Receive latest SourceInfo from the receiver - rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo] - logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector) - addToListOfSources(rxSourceInfo) - - curTime = System.currentTimeMillis - // Revoke sending only if there is anyone waiting in the queue - if (curTime - startTime >= MultiTracker.MaxChatTime && - threadPool.getQueue.size > 0) { - keepSending = false - } - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendBlock(blockToSend: Int) { - try { - oos.writeObject(arrayOfBlocks(blockToSend)) - oos.flush() - } catch { - case e: Exception => logError("sendBlock had a " + e) - } - logDebug("Sent block: " + blockToSend + " to " + clientSocket) - } - } - } -} - -private[spark] class BitTorrentBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new BitTorrentBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index a4ceb0d6af..609464e38d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import org.apache.spark.{HttpServer, Logging, SparkEnv} import org.apache.spark.io.CompressionCodec -import org.apache.spark.storage.{BlockManager, StorageLevel} -import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashSet} - +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def blockId: String = BlockManager.toBroadcastId(id) + def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging { } def write(id: Long, value: Any) { - val file = new File(broadcastDir, "broadcast-" + id) + val file = new File(broadcastDir, BroadcastBlockId(id).name) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { - val url = serverUri + "/broadcast-" + id + val url = serverUri + "/" + BroadcastBlockId(id).name val in = { if (compress) { compressionCodec.compressedInputStream(new URL(url).openStream()) diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala deleted file mode 100644 index 21ec94659e..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala +++ /dev/null @@ -1,410 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.Random - -import scala.collection.mutable.Map - -import org.apache.spark._ -import org.apache.spark.util.Utils - -private object MultiTracker -extends Logging { - - // Tracker Messages - val REGISTER_BROADCAST_TRACKER = 0 - val UNREGISTER_BROADCAST_TRACKER = 1 - val FIND_BROADCAST_TRACKER = 2 - - // Map to keep track of guides of ongoing broadcasts - var valueToGuideMap = Map[Long, SourceInfo]() - - // Random number generator - var ranGen = new Random - - private var initialized = false - private var _isDriver = false - - private var stopBroadcast = false - - private var trackMV: TrackMultipleValues = null - - def initialize(__isDriver: Boolean) { - synchronized { - if (!initialized) { - _isDriver = __isDriver - - if (isDriver) { - trackMV = new TrackMultipleValues - trackMV.setDaemon(true) - trackMV.start() - - // Set DriverHostAddress to the driver's IP address for the slaves to read - System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress) - } - - initialized = true - } - } - } - - def stop() { - stopBroadcast = true - } - - // Load common parameters - private var DriverHostAddress_ = System.getProperty( - "spark.MultiTracker.DriverHostAddress", "") - private var DriverTrackerPort_ = System.getProperty( - "spark.broadcast.driverTrackerPort", "11111").toInt - private var BlockSize_ = System.getProperty( - "spark.broadcast.blockSize", "4096").toInt * 1024 - private var MaxRetryCount_ = System.getProperty( - "spark.broadcast.maxRetryCount", "2").toInt - - private var TrackerSocketTimeout_ = System.getProperty( - "spark.broadcast.trackerSocketTimeout", "50000").toInt - private var ServerSocketTimeout_ = System.getProperty( - "spark.broadcast.serverSocketTimeout", "10000").toInt - - private var MinKnockInterval_ = System.getProperty( - "spark.broadcast.minKnockInterval", "500").toInt - private var MaxKnockInterval_ = System.getProperty( - "spark.broadcast.maxKnockInterval", "999").toInt - - // Load TreeBroadcast config params - private var MaxDegree_ = System.getProperty( - "spark.broadcast.maxDegree", "2").toInt - - // Load BitTorrentBroadcast config params - private var MaxPeersInGuideResponse_ = System.getProperty( - "spark.broadcast.maxPeersInGuideResponse", "4").toInt - - private var MaxChatSlots_ = System.getProperty( - "spark.broadcast.maxChatSlots", "4").toInt - private var MaxChatTime_ = System.getProperty( - "spark.broadcast.maxChatTime", "500").toInt - private var MaxChatBlocks_ = System.getProperty( - "spark.broadcast.maxChatBlocks", "1024").toInt - - private var EndGameFraction_ = System.getProperty( - "spark.broadcast.endGameFraction", "0.95").toDouble - - def isDriver = _isDriver - - // Common config params - def DriverHostAddress = DriverHostAddress_ - def 
DriverTrackerPort = DriverTrackerPort_ - def BlockSize = BlockSize_ - def MaxRetryCount = MaxRetryCount_ - - def TrackerSocketTimeout = TrackerSocketTimeout_ - def ServerSocketTimeout = ServerSocketTimeout_ - - def MinKnockInterval = MinKnockInterval_ - def MaxKnockInterval = MaxKnockInterval_ - - // TreeBroadcast configs - def MaxDegree = MaxDegree_ - - // BitTorrentBroadcast configs - def MaxPeersInGuideResponse = MaxPeersInGuideResponse_ - - def MaxChatSlots = MaxChatSlots_ - def MaxChatTime = MaxChatTime_ - def MaxChatBlocks = MaxChatBlocks_ - - def EndGameFraction = EndGameFraction_ - - class TrackMultipleValues - extends Thread with Logging { - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(DriverTrackerPort) - logInfo("TrackMultipleValues started at " + serverSocket) - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(TrackerSocketTimeout) - clientSocket = serverSocket.accept() - } catch { - case e: Exception => { - if (stopBroadcast) { - logInfo("Stopping TrackMultipleValues...") - } - } - } - - if (clientSocket != null) { - try { - threadPool.execute(new Thread { - override def run() { - val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - val ois = new ObjectInputStream(clientSocket.getInputStream) - - try { - // First, read message type - val messageType = ois.readObject.asInstanceOf[Int] - - if (messageType == REGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - // Receive hostAddress and listenPort - val gInfo = ois.readObject.asInstanceOf[SourceInfo] - - // Add to the map - valueToGuideMap.synchronized { - valueToGuideMap += (id -> gInfo) - } - - logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == UNREGISTER_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - // Remove from the map - valueToGuideMap.synchronized { - valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault) - } - - logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. 
Ongoing ones: " + valueToGuideMap) - - // Send dummy ACK - oos.writeObject(-1) - oos.flush() - } else if (messageType == FIND_BROADCAST_TRACKER) { - // Receive Long - val id = ois.readObject.asInstanceOf[Long] - - var gInfo = - if (valueToGuideMap.contains(id)) valueToGuideMap(id) - else SourceInfo("", SourceInfo.TxNotStartedRetry) - - logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort) - - // Send reply back - oos.writeObject(gInfo) - oos.flush() - } else { - throw new SparkException("Undefined messageType at TrackMultipleValues") - } - } catch { - case e: Exception => { - logError("TrackMultipleValues had a " + e) - } - } finally { - ois.close() - oos.close() - clientSocket.close() - } - } - }) - } catch { - // In failure, close socket here; else, client thread will close - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - serverSocket.close() - } - // Shutdown the thread pool - threadPool.shutdown() - } - } - - def getGuideInfo(variableLong: Long): SourceInfo = { - var clientSocketToTracker: Socket = null - var oosTracker: ObjectOutputStream = null - var oisTracker: ObjectInputStream = null - - var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry) - - var retriesLeft = MultiTracker.MaxRetryCount - do { - try { - // Connect to the tracker to find out GuideInfo - clientSocketToTracker = - new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort) - oosTracker = - new ObjectOutputStream(clientSocketToTracker.getOutputStream) - oosTracker.flush() - oisTracker = - new ObjectInputStream(clientSocketToTracker.getInputStream) - - // Send messageType/intention - oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER) - oosTracker.flush() - - // Send Long and receive GuideInfo - oosTracker.writeObject(variableLong) - oosTracker.flush() - gInfo = oisTracker.readObject.asInstanceOf[SourceInfo] - } catch { - case e: Exception => logError("getGuideInfo had a " + e) - } finally { - if (oisTracker != null) { - oisTracker.close() - } - if (oosTracker != null) { - oosTracker.close() - } - if (clientSocketToTracker != null) { - clientSocketToTracker.close() - } - } - - Thread.sleep(MultiTracker.ranGen.nextInt( - MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) + - MultiTracker.MinKnockInterval) - - retriesLeft -= 1 - } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry) - - logDebug("Got this guidePort from Tracker: " + gInfo.listenPort) - return gInfo - } - - def registerBroadcast(id: Long, gInfo: SourceInfo) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(REGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Send this tracker's information - oosST.writeObject(gInfo) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - def unregisterBroadcast(id: Long) { - val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort) - val oosST = new ObjectOutputStream(socket.getOutputStream) - oosST.flush() - val oisST = new ObjectInputStream(socket.getInputStream) - - // Send messageType/intention - oosST.writeObject(UNREGISTER_BROADCAST_TRACKER) - oosST.flush() - - // Send 
Long of this broadcast - oosST.writeObject(id) - oosST.flush() - - // Receive ACK and throw it away - oisST.readObject.asInstanceOf[Int] - - // Shut stuff down - oisST.close() - oosST.close() - socket.close() - } - - // Helper method to convert an object to Array[BroadcastBlock] - def blockifyObject[IN](obj: IN): VariableInfo = { - val baos = new ByteArrayOutputStream - val oos = new ObjectOutputStream(baos) - oos.writeObject(obj) - oos.close() - baos.close() - val byteArray = baos.toByteArray - val bais = new ByteArrayInputStream(byteArray) - - var blockNum = (byteArray.length / BlockSize) - if (byteArray.length % BlockSize != 0) - blockNum += 1 - - var retVal = new Array[BroadcastBlock](blockNum) - var blockID = 0 - - for (i <- 0 until (byteArray.length, BlockSize)) { - val thisBlockSize = math.min(BlockSize, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) - - retVal(blockID) = new BroadcastBlock(blockID, tempByteArray) - blockID += 1 - } - bais.close() - - var variableInfo = VariableInfo(retVal, blockNum, byteArray.length) - variableInfo.hasBlocks = blockNum - - return variableInfo - } - - // Helper method to convert Array[BroadcastBlock] to object - def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock], - totalBytes: Int, - totalBlocks: Int): OUT = { - - var retByteArray = new Array[Byte](totalBytes) - for (i <- 0 until totalBlocks) { - System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, - i * BlockSize, arrayOfBlocks(i).byteArray.length) - } - byteArrayToObject(retByteArray) - } - - private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = { - val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){ - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) - } - val retVal = in.readObject.asInstanceOf[OUT] - in.close() - return retVal - } -} - -private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte]) -extends Serializable - -private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock], - totalBlocks: Int, - totalBytes: Int) -extends Serializable { - @transient var hasBlocks = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala deleted file mode 100644 index baa1fd6da4..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.broadcast - -import java.util.BitSet - -import org.apache.spark._ - -/** - * Used to keep and pass around information of peers involved in a broadcast - */ -private[spark] case class SourceInfo (hostAddress: String, - listenPort: Int, - totalBlocks: Int = SourceInfo.UnusedParam, - totalBytes: Int = SourceInfo.UnusedParam) -extends Comparable[SourceInfo] with Logging { - - var currentLeechers = 0 - var receptionFailed = false - - var hasBlocks = 0 - var hasBlocksBitVector: BitSet = new BitSet (totalBlocks) - - // Ascending sort based on leecher count - def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers) -} - -/** - * Helper Object of SourceInfo for its constants - */ -private[spark] object SourceInfo { - // Broadcast has not started yet! Should never happen. - val TxNotStartedRetry = -1 - // Broadcast has already finished. Try default mechanism. - val TxOverGoToDefault = -3 - // Other constants - val StopBroadcast = -2 - val UnusedParam = 0 -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala new file mode 100644 index 0000000000..073a0a5029 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.broadcast + +import java.io._ + +import scala.math +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.util.Utils + + +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +extends Broadcast[T](id) with Logging with Serializable { + + def value = value_ + + def broadcastId = BroadcastBlockId(id) + + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + } + + @transient var arrayOfBlocks: Array[TorrentBlock] = null + @transient var totalBlocks = -1 + @transient var totalBytes = -1 + @transient var hasBlocks = 0 + + if (!isLocal) { + sendBroadcast() + } + + def sendBroadcast() { + var tInfo = TorrentBroadcast.blockifyObject(value_) + + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + hasBlocks = tInfo.totalBlocks + + // Store meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + } + + // Store individual pieces + for (i <- 0 until totalBlocks) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + } + } + } + + // Called by JVM when deserializing an object + private def readObject(in: ObjectInputStream) { + in.defaultReadObject() + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(broadcastId) match { + case Some(x) => + value_ = x.asInstanceOf[T] + + case None => + val start = System.nanoTime + logInfo("Started reading broadcast variable " + id) + + // Initialize @transient variables that will receive garbage values from the master. + resetWorkerVariables() + + if (receiveBroadcast(id)) { + value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) + + // Store the merged copy in cache so that the next worker doesn't need to rebuild it. + // This creates a tradeoff between memory usage and latency. + // Storing copy doubles the memory footprint; not storing doubles deserialization cost. 
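
The tradeoff noted in the comment above is the core of `readObject`: consult the local block store first, otherwise fetch the pieces, rebuild the value, and re-insert the merged copy so that later tasks on the same executor hit the cache instead of re-deserializing. A minimal sketch of that cache-or-reassemble pattern follows; `BlockStore`, `readCached`, and `fetchAndAssemble` are hypothetical stand-ins for the BlockManager calls, not APIs from this patch:

```scala
object BroadcastCache {
  // Hypothetical stand-in for BlockManager's getSingle/putSingle.
  trait BlockStore {
    def get(id: String): Option[Any]
    def put(id: String, value: Any): Unit
  }

  def readCached[T](store: BlockStore, id: String)(fetchAndAssemble: => T): T =
    store.get(id) match {
      case Some(v) => v.asInstanceOf[T]
      case None =>
        val v = fetchAndAssemble // pull the pieces and rebuild the value
        store.put(id, v)         // cache the merged copy for later tasks
        v
    }
}
```
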
+ SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + + // Remove arrayOfBlocks from memory once value_ is on local cache + resetWorkerVariables() + } else { + logError("Reading broadcast variable " + id + " failed") + } + + val time = (System.nanoTime - start) / 1e9 + logInfo("Reading broadcast variable " + id + " took " + time + " s") + } + } + } + + private def resetWorkerVariables() { + arrayOfBlocks = null + totalBytes = -1 + totalBlocks = -1 + hasBlocks = 0 + } + + def receiveBroadcast(variableID: Long): Boolean = { + // Receive meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") + var attemptId = 10 + while (attemptId > 0 && totalBlocks == -1) { + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(metaId) match { + case Some(x) => + val tInfo = x.asInstanceOf[TorrentInfo] + totalBlocks = tInfo.totalBlocks + totalBytes = tInfo.totalBytes + arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + hasBlocks = 0 + + case None => + Thread.sleep(500) + } + } + attemptId -= 1 + } + if (totalBlocks == -1) { + return false + } + + // Receive actual blocks + val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) + for (pid <- recvOrder) { + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + hasBlocks += 1 + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + } + + (hasBlocks == totalBlocks) + } + +} + +private object TorrentBroadcast +extends Logging { + + private var initialized = false + + def initialize(_isDriver: Boolean) { + synchronized { + if (!initialized) { + initialized = true + } + } + } + + def stop() { + initialized = false + } + + val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024 + + def blockifyObject[T](obj: T): TorrentInfo = { + val byteArray = Utils.serialize[T](obj) + val bais = new ByteArrayInputStream(byteArray) + + var blockNum = (byteArray.length / BLOCK_SIZE) + if (byteArray.length % BLOCK_SIZE != 0) + blockNum += 1 + + var retVal = new Array[TorrentBlock](blockNum) + var blockID = 0 + + for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { + val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) + var tempByteArray = new Array[Byte](thisBlockSize) + val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + + retVal(blockID) = new TorrentBlock(blockID, tempByteArray) + blockID += 1 + } + bais.close() + + var tInfo = TorrentInfo(retVal, blockNum, byteArray.length) + tInfo.hasBlocks = blockNum + + return tInfo + } + + def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { + var retByteArray = new Array[Byte](totalBytes) + for (i <- 0 until totalBlocks) { + System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, + i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length) + } + Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) + } + +} + +private[spark] case class TorrentBlock( + blockID: Int, + byteArray: Array[Byte]) + extends Serializable + +private[spark] case class TorrentInfo( + @transient arrayOfBlocks : Array[TorrentBlock], + totalBlocks: Int, + totalBytes: Int) + extends Serializable { + + @transient var 
hasBlocks = 0 +} + +private[spark] class TorrentBroadcastFactory + extends BroadcastFactory { + + def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala deleted file mode 100644 index b664f28e42..0000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala +++ /dev/null @@ -1,603 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import java.io._ -import java.net._ -import java.util.{Comparator, Random, UUID} - -import scala.collection.mutable.{ListBuffer, Map, Set} -import scala.math - -import org.apache.spark._ -import org.apache.spark.storage.{BlockManager, StorageLevel} -import org.apache.spark.util.Utils - -private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { - - def value = value_ - - def blockId = BlockManager.toBroadcastId(id) - - MultiTracker.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } - - @transient var arrayOfBlocks: Array[BroadcastBlock] = null - @transient var totalBytes = -1 - @transient var totalBlocks = -1 - @transient var hasBlocks = 0 - - @transient var listenPortLock = new Object - @transient var guidePortLock = new Object - @transient var totalBlocksLock = new Object - @transient var hasBlocksLock = new Object - - @transient var listOfSources = ListBuffer[SourceInfo]() - - @transient var serveMR: ServeMultipleRequests = null - @transient var guideMR: GuideMultipleRequests = null - - @transient var hostAddress = Utils.localIpAddress - @transient var listenPort = -1 - @transient var guidePort = -1 - - @transient var stopBroadcast = false - - // Must call this after all the variables have been created/initialized - if (!isLocal) { - sendBroadcast() - } - - def sendBroadcast() { - logInfo("Local host address: " + hostAddress) - - // Create a variableInfo object and store it in valueInfos - var variableInfo = MultiTracker.blockifyObject(value_) - - // Prepare the value being broadcasted - arrayOfBlocks = variableInfo.arrayOfBlocks - totalBytes = variableInfo.totalBytes - totalBlocks = variableInfo.totalBlocks - hasBlocks = variableInfo.totalBlocks - - guideMR = new GuideMultipleRequests - guideMR.setDaemon(true) - guideMR.start() - logInfo("GuideMultipleRequests started...") - - // Must always come AFTER guideMR is created - while (guidePort == -1) { - 
guidePortLock.synchronized { guidePortLock.wait() } - } - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - // Must always come AFTER serveMR is created - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - // Must always come AFTER listenPort is created - val masterSource = - SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes) - listOfSources += masterSource - - // Register with the Tracker - MultiTracker.registerBroadcast(id, - SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes)) - } - - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() - MultiTracker.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => - value_ = x.asInstanceOf[T] - - case None => - logInfo("Started reading broadcast variable " + id) - // Initializing everything because Driver will only send null/0 values - // Only the 1st worker in a node can be here. Others will get from cache - initializeWorkerVariables() - - logInfo("Local host address: " + hostAddress) - - serveMR = new ServeMultipleRequests - serveMR.setDaemon(true) - serveMR.start() - logInfo("ServeMultipleRequests started...") - - val start = System.nanoTime - - val receptionSucceeded = receiveBroadcast(id) - if (receptionSucceeded) { - value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, false) - } else { - logError("Reading broadcast variable " + id + " failed") - } - - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - - private def initializeWorkerVariables() { - arrayOfBlocks = null - totalBytes = -1 - totalBlocks = -1 - hasBlocks = 0 - - listenPortLock = new Object - totalBlocksLock = new Object - hasBlocksLock = new Object - - serveMR = null - - hostAddress = Utils.localIpAddress - listenPort = -1 - - stopBroadcast = false - } - - def receiveBroadcast(variableID: Long): Boolean = { - val gInfo = MultiTracker.getGuideInfo(variableID) - - if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) { - return false - } - - // Wait until hostAddress and listenPort are created by the - // ServeMultipleRequests thread - while (listenPort == -1) { - listenPortLock.synchronized { listenPortLock.wait() } - } - - var clientSocketToDriver: Socket = null - var oosDriver: ObjectOutputStream = null - var oisDriver: ObjectInputStream = null - - // Connect and receive broadcast from the specified source, retrying the - // specified number of times in case of failures - var retriesLeft = MultiTracker.MaxRetryCount - do { - // Connect to Driver and send this worker's Information - clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort) - oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream) - oosDriver.flush() - oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream) - - logDebug("Connected to Driver's guiding object") - - // Send local source information - oosDriver.writeObject(SourceInfo(hostAddress, listenPort)) - oosDriver.flush() - - // Receive source information from Driver - var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo] - totalBlocks = sourceInfo.totalBlocks - arrayOfBlocks = new Array[BroadcastBlock](totalBlocks) - totalBlocksLock.synchronized { totalBlocksLock.notifyAll() } - totalBytes = sourceInfo.totalBytes 
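
The `guidePortLock` and `listenPortLock` waits above are the standard monitor handshake: a daemon thread binds a server socket, publishes the port, and wakes any waiters, which re-check the condition in a loop to tolerate spurious wakeups. A self-contained sketch of the idiom under assumed names:

```scala
object PortHandshake {
  class PortHolder {
    private val lock = new Object
    private var port: Int = -1 // guarded by lock

    // Called once by the server thread after binding its socket.
    def publish(p: Int): Unit = lock.synchronized {
      port = p
      lock.notifyAll()
    }

    // Blocks until the port is known; the loop guards against spurious wakeups.
    def await(): Int = lock.synchronized {
      while (port == -1) lock.wait()
      port
    }
  }
}
```
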
- - logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort) - - val start = System.nanoTime - val receptionSucceeded = receiveSingleTransmission(sourceInfo) - val time = (System.nanoTime - start) / 1e9 - - // Updating some statistics in sourceInfo. Driver will be using them later - if (!receptionSucceeded) { - sourceInfo.receptionFailed = true - } - - // Send back statistics to the Driver - oosDriver.writeObject(sourceInfo) - - if (oisDriver != null) { - oisDriver.close() - } - if (oosDriver != null) { - oosDriver.close() - } - if (clientSocketToDriver != null) { - clientSocketToDriver.close() - } - - retriesLeft -= 1 - } while (retriesLeft > 0 && hasBlocks < totalBlocks) - - return (hasBlocks == totalBlocks) - } - - /** - * Tries to receive broadcast from the source and returns Boolean status. - * This might be called multiple times to retry a defined number of times. - */ - private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = { - var clientSocketToSource: Socket = null - var oosSource: ObjectOutputStream = null - var oisSource: ObjectInputStream = null - - var receptionSucceeded = false - try { - // Connect to the source to get the object itself - clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream) - oosSource.flush() - oisSource = new ObjectInputStream(clientSocketToSource.getInputStream) - - logDebug("Inside receiveSingleTransmission") - logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks) - - // Send the range - oosSource.writeObject((hasBlocks, totalBlocks)) - oosSource.flush() - - for (i <- hasBlocks until totalBlocks) { - val recvStartTime = System.currentTimeMillis - val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock] - val receptionTime = (System.currentTimeMillis - recvStartTime) - - logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.") - - arrayOfBlocks(hasBlocks) = bcBlock - hasBlocks += 1 - - // Set to true if at least one block is received - receptionSucceeded = true - hasBlocksLock.synchronized { hasBlocksLock.notifyAll() } - } - } catch { - case e: Exception => logError("receiveSingleTransmission had a " + e) - } finally { - if (oisSource != null) { - oisSource.close() - } - if (oosSource != null) { - oosSource.close() - } - if (clientSocketToSource != null) { - clientSocketToSource.close() - } - } - - return receptionSucceeded - } - - class GuideMultipleRequests - extends Thread with Logging { - // Keep track of sources that have completed reception - private var setOfCompletedSources = Set[SourceInfo]() - - override def run() { - var threadPool = Utils.newDaemonCachedThreadPool() - var serverSocket: ServerSocket = null - - serverSocket = new ServerSocket(0) - guidePort = serverSocket.getLocalPort - logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort) - - guidePortLock.synchronized { guidePortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { - // Stop broadcast if at least one worker has connected and - // everyone connected so far are done. 
Comparing with - // listOfSources.size - 1, because it includes the Guide itself - listOfSources.synchronized { - setOfCompletedSources.synchronized { - if (listOfSources.size > 1 && - setOfCompletedSources.size == listOfSources.size - 1) { - stopBroadcast = true - logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.") - } - } - } - } - } - if (clientSocket != null) { - logDebug("Guide: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new GuideSingleRequest(clientSocket)) - } catch { - // In failure, close() the socket here; else, the thread will close() it - case ioe: IOException => clientSocket.close() - } - } - } - - logInfo("Sending stopBroadcast notifications...") - sendStopBroadcastNotifications - - MultiTracker.unregisterBroadcast(id) - } finally { - if (serverSocket != null) { - logInfo("GuideMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - private def sendStopBroadcastNotifications() { - listOfSources.synchronized { - var listIter = listOfSources.iterator - while (listIter.hasNext) { - var sourceInfo = listIter.next - - var guideSocketToSource: Socket = null - var gosSource: ObjectOutputStream = null - var gisSource: ObjectInputStream = null - - try { - // Connect to the source - guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort) - gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream) - gosSource.flush() - gisSource = new ObjectInputStream(guideSocketToSource.getInputStream) - - // Send stopBroadcast signal - gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast)) - gosSource.flush() - } catch { - case e: Exception => { - logError("sendStopBroadcastNotifications had a " + e) - } - } finally { - if (gisSource != null) { - gisSource.close() - } - if (gosSource != null) { - gosSource.close() - } - if (guideSocketToSource != null) { - guideSocketToSource.close() - } - } - } - } - } - - class GuideSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var selectedSourceInfo: SourceInfo = null - private var thisWorkerInfo:SourceInfo = null - - override def run() { - try { - logInfo("new GuideSingleRequest is running") - // Connecting worker is sending in its hostAddress and listenPort it will - // be listening to. Other fields are invalid (SourceInfo.UnusedParam) - var sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // Select a suitable source and send it back to the worker - selectedSourceInfo = selectSuitableSource(sourceInfo) - logDebug("Sending selectedSourceInfo: " + selectedSourceInfo) - oos.writeObject(selectedSourceInfo) - oos.flush() - - // Add this new (if it can finish) source to the list of sources - thisWorkerInfo = SourceInfo(sourceInfo.hostAddress, - sourceInfo.listenPort, totalBlocks, totalBytes) - logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo) - listOfSources += thisWorkerInfo - } - - // Wait till the whole transfer is done. 
Then receive and update source - // statistics in listOfSources - sourceInfo = ois.readObject.asInstanceOf[SourceInfo] - - listOfSources.synchronized { - // This should work since SourceInfo is a case class - assert(listOfSources.contains(selectedSourceInfo)) - - // Remove first - // (Currently removing a source based on just one failure notification!) - listOfSources = listOfSources - selectedSourceInfo - - // Update sourceInfo and put it back in, IF reception succeeded - if (!sourceInfo.receptionFailed) { - // Add thisWorkerInfo to sources that have completed reception - setOfCompletedSources.synchronized { - setOfCompletedSources += thisWorkerInfo - } - - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - } - } catch { - case e: Exception => { - // Remove failed worker from listOfSources and update leecherCount of - // corresponding source worker - listOfSources.synchronized { - if (selectedSourceInfo != null) { - // Remove first - listOfSources = listOfSources - selectedSourceInfo - // Update leecher count and put it back in - selectedSourceInfo.currentLeechers -= 1 - listOfSources += selectedSourceInfo - } - - // Remove thisWorkerInfo - if (listOfSources != null) { - listOfSources = listOfSources - thisWorkerInfo - } - } - } - } finally { - logInfo("GuideSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - // Assuming the caller to have a synchronized block on listOfSources - // Select one with the most leechers. This will level-wise fill the tree - private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = { - var maxLeechers = -1 - var selectedSource: SourceInfo = null - - listOfSources.foreach { source => - if ((source.hostAddress != skipSourceInfo.hostAddress || - source.listenPort != skipSourceInfo.listenPort) && - source.currentLeechers < MultiTracker.MaxDegree && - source.currentLeechers > maxLeechers) { - selectedSource = source - maxLeechers = source.currentLeechers - } - } - - // Update leecher count - selectedSource.currentLeechers += 1 - return selectedSource - } - } - } - - class ServeMultipleRequests - extends Thread with Logging { - - var threadPool = Utils.newDaemonCachedThreadPool() - - override def run() { - var serverSocket = new ServerSocket(0) - listenPort = serverSocket.getLocalPort - - logInfo("ServeMultipleRequests started with " + serverSocket) - - listenPortLock.synchronized { listenPortLock.notifyAll() } - - try { - while (!stopBroadcast) { - var clientSocket: Socket = null - try { - serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout) - clientSocket = serverSocket.accept - } catch { - case e: Exception => { } - } - - if (clientSocket != null) { - logDebug("Serve: Accepted new client connection: " + clientSocket) - try { - threadPool.execute(new ServeSingleRequest(clientSocket)) - } catch { - // In failure, close socket here; else, the thread will close it - case ioe: IOException => clientSocket.close() - } - } - } - } finally { - if (serverSocket != null) { - logInfo("ServeMultipleRequests now stopping...") - serverSocket.close() - } - } - // Shutdown the thread pool - threadPool.shutdown() - } - - class ServeSingleRequest(val clientSocket: Socket) - extends Thread with Logging { - private val oos = new ObjectOutputStream(clientSocket.getOutputStream) - oos.flush() - private val ois = new ObjectInputStream(clientSocket.getInputStream) - - private var sendFrom = 0 - private var sendUntil = totalBlocks - 
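
`sendObject` (just below) pipelines the tree broadcast: a worker may relay block `i` to a child as soon as that block has arrived locally, blocking on `hasBlocksLock` only while block `i` itself is still in flight, rather than waiting for the full set. A sketch of that producer-consumer idiom; the names are illustrative, not from the patch:

```scala
object RelayPipeline {
  // Blocks arrive in order for the tree broadcast, so a counter suffices.
  class BlockPipeline {
    private val lock = new Object
    private var received = 0 // blocks available locally, guarded by lock

    def blockArrived(): Unit = lock.synchronized {
      received += 1
      lock.notifyAll()
    }

    // Returns once block i is locally available and may be relayed onward.
    def awaitBlock(i: Int): Unit = lock.synchronized {
      while (i >= received) lock.wait()
    }
  }
}
```
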
- override def run() { - try { - logInfo("new ServeSingleRequest is running") - - // Receive range to send - var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)] - sendFrom = rangeToSend._1 - sendUntil = rangeToSend._2 - - // If not a valid range, stop broadcast - if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) { - stopBroadcast = true - } else { - sendObject - } - } catch { - case e: Exception => logError("ServeSingleRequest had a " + e) - } finally { - logInfo("ServeSingleRequest is closing streams and sockets") - ois.close() - oos.close() - clientSocket.close() - } - } - - private def sendObject() { - // Wait till receiving the SourceInfo from Driver - while (totalBlocks == -1) { - totalBlocksLock.synchronized { totalBlocksLock.wait() } - } - - for (i <- sendFrom until sendUntil) { - while (i == hasBlocks) { - hasBlocksLock.synchronized { hasBlocksLock.wait() } - } - try { - oos.writeObject(arrayOfBlocks(i)) - oos.flush() - } catch { - case e: Exception => logError("sendObject had a " + e) - } - logDebug("Sent block: " + i + " to " + clientSocket) - } - } - } - } -} - -private[spark] class TreeBroadcastFactory -extends BroadcastFactory { - def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TreeBroadcast[T](value_, isLocal, id) - - def stop() { MultiTracker.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 993ba6bd3d..83cd3df5fa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,26 +17,31 @@ package org.apache.spark.deploy -import com.google.common.collect.MapMaker - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf +import com.google.common.collect.MapMaker + /** - * Contains util methods to interact with Hadoop from spark. + * Contains util methods to interact with Hadoop from Spark. */ +private[spark] class SparkHadoopUtil { // A general, soft-reference map for metadata needed during HadoopRDD split computation // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() - // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop - // subsystems + /** + * Return an appropriate (subclass) of Configuration. Creating a config can initialize some + * Hadoop subsystems. + */ def newConfiguration(): Configuration = new Configuration() - // Add any user credentials to the job conf which are necessary for running on a secure Hadoop - // cluster + /** + * Add any user credentials to the job conf which are necessary for running on a secure Hadoop + * cluster.
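+ * (A no-op in this base implementation.)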
+ */ def addCredentials(conf: JobConf) {} def isYarnMode(): Boolean = { false } diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 7839023868..52b1c492b2 100644 --- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -24,11 +24,11 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClie import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.TaskState.TaskState -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} -private[spark] class StandaloneExecutorBackend( +private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, hostPort: String, @@ -63,12 +63,20 @@ private[spark] class StandaloneExecutorBackend( case LaunchTask(taskDesc) => logInfo("Got assigned task " + taskDesc.taskId) if (executor == null) { - logError("Received launchTask but executor was null") + logError("Received LaunchTask command but executor was null") System.exit(1) } else { executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask) } + case KillTask(taskId, _) => + if (executor == null) { + logError("Received KillTask command but executor was null") + System.exit(1) + } else { + executor.killTask(taskId) + } + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => logError("Driver terminated or disconnected! Shutting down.") System.exit(1) @@ -79,7 +87,7 @@ private[spark] class StandaloneExecutorBackend( } } -private[spark] object StandaloneExecutorBackend { +private[spark] object CoarseGrainedExecutorBackend { def run(driverUrl: String, executorId: String, hostname: String, cores: Int) { // Debug code Utils.checkHost(hostname) @@ -91,7 +99,7 @@ private[spark] object StandaloneExecutorBackend { val sparkHostPort = hostname + ":" + boundPort System.setProperty("spark.hostPort", sparkHostPort) val actor = actorSystem.actorOf( - Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), + Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), name = "Executor") actorSystem.awaitTermination() } @@ -99,7 +107,9 @@ private[spark] object StandaloneExecutorBackend { def main(args: Array[String]) { if (args.length < 4) { //the reason we allow the last frameworkId argument is to make it easy to kill rogue executors - System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]") + System.err.println( + "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " + + "[<appid>]") System.exit(1) } run(args(0), args(1), args(2), args(3).toInt) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index acdb8d0343..b773346df3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.scheduler._ import org.apache.spark._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util.Utils /** @@ -36,7 +36,8 
@@ import org.apache.spark.util.Utils private[spark] class Executor( executorId: String, slaveHostname: String, - properties: Seq[(String, String)]) + properties: Seq[(String, String)], + isLocal: Boolean = false) extends Logging { // Application dependencies (added through SparkContext) that we've fetched so far on this node. @@ -73,46 +74,77 @@ private[spark] class Executor( private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) Thread.currentThread.setContextClassLoader(replClassLoader) - // Make any thread terminations due to uncaught exceptions kill the entire - // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler( - new Thread.UncaughtExceptionHandler { - override def uncaughtException(thread: Thread, exception: Throwable) { - try { - logError("Uncaught exception in thread " + thread, exception) - - // We may have been called from a shutdown hook. If so, we must not call System.exit(). - // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(ExecutorExitCode.OOM) - } else { - System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + if (!isLocal) { + // Set up an uncaught exception handler for non-local mode. + // Make any thread terminations due to uncaught exceptions kill the entire + // executor process to avoid surprising stalls. + Thread.setDefaultUncaughtExceptionHandler( + new Thread.UncaughtExceptionHandler { + override def uncaughtException(thread: Thread, exception: Throwable) { + try { + logError("Uncaught exception in thread " + thread, exception) + + // We may have been called from a shutdown hook. If so, we must not call System.exit(). + // (If we do, we will deadlock.) + if (!Utils.inShutdown()) { + if (exception.isInstanceOf[OutOfMemoryError]) { + System.exit(ExecutorExitCode.OOM) + } else { + System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) + } } + } catch { + case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) + case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) } - } catch { - case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) - case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) } } - } - ) + ) + } val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false) - SparkEnv.set(env) - env.metricsSystem.registerSource(executorSource) + private val env = { + if (!isLocal) { + val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, + isDriver = false, isLocal = false) + SparkEnv.set(_env) + _env.metricsSystem.registerSource(executorSource) + _env + } else { + SparkEnv.get + } + } - private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + // Akka's message frame size. If a task result is bigger than this, we use the block manager + // to send the result back. + private val akkaFrameSize = { + env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + } // Start worker thread pool - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") + + // Maintains the list of running tasks.
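+ // Keyed by task ID so that killTask() can locate the corresponding runner.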
+ private val runningTasks = new ConcurrentHashMap[Long, TaskRunner] def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) { - threadPool.execute(new TaskRunner(context, taskId, serializedTask)) + val tr = new TaskRunner(context, taskId, serializedTask) + runningTasks.put(taskId, tr) + threadPool.execute(tr) + } + + def killTask(taskId: Long) { + val tr = runningTasks.get(taskId) + if (tr != null) { + tr.kill() + // We also remove the task in the finally block in TaskRunner.run. + // We need to remove it here as well because killTask might be called before the task + // is even launched, in which case that finally block is never reached. ConcurrentHashMap's + // remove is idempotent. + runningTasks.remove(taskId) + } } /** Get the Yarn approved local directories. */ @@ -124,56 +156,87 @@ private[spark] class Executor( .getOrElse(Option(System.getenv("LOCAL_DIRS")) .getOrElse("")) - if (localDirs.isEmpty()) { + if (localDirs.isEmpty) { throw new Exception("Yarn Local dirs can't be empty") } - return localDirs + localDirs } - class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) + class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) extends Runnable { + @volatile private var killed = false + @volatile private var task: Task[Any] = _ + + def kill() { + logInfo("Executor is trying to kill task " + taskId) + killed = true + if (task != null) { + task.kill() + } + } + override def run() { val startTime = System.currentTimeMillis() SparkEnv.set(env) Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo("Running task ID " + taskId) - context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) + execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) var attemptedTask: Option[Task[Any]] = None var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime + def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum + val startGCTime = gcTime try { SparkEnv.set(env) Accumulators.clear() val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) - val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + + // If this task has been killed before we deserialized it, let's quit now. Otherwise, + // continue executing the task. + if (killed) { + logInfo("Executor killed task " + taskId) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + return + } + attemptedTask = Some(task) - logInfo("Its epoch is " + task.epoch) + logDebug("Task " + taskId + "'s epoch is " + task.epoch) env.mapOutputTracker.updateEpoch(task.epoch) + + // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() val value = task.run(taskId.toInt) val taskFinish = System.currentTimeMillis() + + // If the task has been killed, let's fail it.
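+ // (task.run() may return normally even after kill() was called, so the flag is re-checked.)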
+ if (task.killed) { + logInfo("Executor killed task " + taskId) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + return + } + for (m <- task.metrics) { - m.hostname = Utils.localHostName + m.hostname = Utils.localHostName() m.executorDeserializeTime = (taskStart - startTime).toInt m.executorRunTime = (taskFinish - taskStart).toInt - m.jvmGCTime = getTotalGCTime - startGCTime + m.jvmGCTime = gcTime - startGCTime } - //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c - // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could - // just change the relevants bytes in the byte buffer + // TODO I'd also like to track the time it takes to serialize the task results, but that is + // a huge headache, because we need to serialize the task metrics first. If TaskMetrics had a + // custom serialized format, we could just change the relevant bytes in the byte buffer val accumUpdates = Accumulators.values + val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null)) val serializedDirectResult = ser.serialize(directResult) logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit) val serializedResult = { if (serializedDirectResult.limit >= akkaFrameSize - 1024) { logInfo("Storing result for " + taskId + " in local BlockManager") - val blockId = "taskresult_" + taskId + val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) ser.serialize(new IndirectTaskResult[Any](blockId)) @@ -182,12 +245,13 @@ private[spark] class Executor( serializedDirectResult } } - context.statusUpdate(taskId, TaskState.FINISHED, serializedResult) + + execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) logInfo("Finished task ID " + taskId) } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) } case t: Throwable => { @@ -195,10 +259,10 @@ private[spark] class Executor( val metrics = attemptedTask.flatMap(t => t.metrics) for (m <- metrics) { m.executorRunTime = serviceTime - m.jvmGCTime = getTotalGCTime - startGCTime + m.jvmGCTime = gcTime - startGCTime } val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // TODO: Should we exit the whole executor here? On the one hand, the failed task may // have left some weird state around depending on when the exception was thrown, but on @@ -206,6 +270,8 @@ private[spark] class Executor( logError("Exception in task ID " + taskId, t) //System.exit(1) } + } finally { + runningTasks.remove(taskId) } } } @@ -215,7 +281,7 @@ private[spark] class Executor( * created by the interpreter to the search path */ private def createClassLoader(): ExecutorURLClassLoader = { - var loader = this.getClass.getClassLoader + val loader = this.getClass.getClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched.
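Note that the kill path in TaskRunner above is cooperative rather than preemptive: kill() only records the request (and forwards it to the running Task), while run() re-checks the flag at safe points before and after the actual work, and the runner registry is cleaned up on both paths. A minimal standalone sketch of this pattern follows; the names (CancellableRunner, Runners) are illustrative and not part of the Spark API:

    import java.util.concurrent.{ConcurrentHashMap, Executors}

    class CancellableRunner(val id: Long, work: () => Unit) extends Runnable {
      @volatile private var killed = false

      def kill() { killed = true }

      override def run() {
        if (killed) return   // killed before any real work started
        work()               // the actual task body
        if (killed) return   // killed mid-run: report KILLED rather than FINISHED
        // report success here
      }
    }

    object Runners {
      private val running = new ConcurrentHashMap[Long, CancellableRunner]
      private val pool = Executors.newCachedThreadPool()

      def launch(r: CancellableRunner) {
        running.put(r.id, r)
        pool.execute(new Runnable {
          // Mirror the finally-block cleanup in TaskRunner.run above.
          override def run() { try r.run() finally running.remove(r.id) }
        })
      }

      // Safe to call before launch or after completion; remove() is idempotent.
      def kill(id: Long) {
        val r = running.get(id)
        if (r != null) {
          r.kill()
          running.remove(id)
        }
      }
    }

Removing the runner in both kill() and the launch-side finally block mirrors the Executor code: the map stays consistent whether a task is killed before launch, killed during execution, or finishes normally.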
@@ -237,7 +303,7 @@ private[spark] class Executor( val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - return constructor.newInstance(classUri, parent) + constructor.newInstance(classUri, parent) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") @@ -245,7 +311,7 @@ private[spark] class Executor( null } } else { - return parent + parent } } diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index da62091980..b56d8c9912 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -18,14 +18,18 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} -import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _} -import org.apache.spark.TaskState.TaskState + import com.google.protobuf.ByteString -import org.apache.spark.{Logging} + +import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver} +import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} + +import org.apache.spark.Logging import org.apache.spark.TaskState +import org.apache.spark.TaskState.TaskState import org.apache.spark.util.Utils + private[spark] class MesosExecutorBackend extends MesosExecutor with ExecutorBackend @@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend } override def killTask(d: ExecutorDriver, t: TaskID) { - logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)") + if (executor == null) { + logError("Received KillTask but executor was null") + } else { + executor.killTask(t.getValue.toLong) + } } override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {} diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index f311141148..0b4892f98f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable { * Number of bytes written for a shuffle */ var shuffleBytesWritten: Long = _ + + /** + * Time spent blocking on writes to disk or buffer cache, in nanoseconds. 
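+ * Note that other timing fields in these metrics are in milliseconds.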
+ */ + var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index e15a839c4e..9c2fee4023 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging { private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] - implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool()) + implicit val futureExecContext = ExecutionContext.fromExecutor( + Utils.newDaemonCachedThreadPool("Connection manager future execution context")) private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 3c29700920..1b9fa1e53a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -20,17 +20,18 @@ package org.apache.spark.network.netty import io.netty.buffer._ import org.apache.spark.Logging +import org.apache.spark.storage.{TestBlockId, BlockId} private[spark] class FileHeader ( val fileLen: Int, - val blockId: String) extends Logging { + val blockId: BlockId) extends Logging { lazy val buffer = { val buf = Unpooled.buffer() buf.capacity(FileHeader.HEADER_SIZE) buf.writeInt(fileLen) - buf.writeInt(blockId.length) - blockId.foreach((x: Char) => buf.writeByte(x)) + buf.writeInt(blockId.name.length) + blockId.name.foreach((x: Char) => buf.writeByte(x)) //padding the rest of header if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) @@ -57,18 +58,15 @@ private[spark] object FileHeader { for (i <- 1 to idLength) { idBuilder += buf.readByte().asInstanceOf[Char] } - val blockId = idBuilder.toString() + val blockId = BlockId(idBuilder.toString()) new FileHeader(length, blockId) } - - def main (args:Array[String]){ - - val header = new FileHeader(25,"block_0"); - val buf = header.buffer; - val newheader = FileHeader.create(buf); - System.out.println("id="+newheader.blockId+",size="+newheader.fileLen) - + def main(args: Array[String]) { + val header = new FileHeader(25, TestBlockId("my_block")) + val buf = header.buffer + val newHeader = FileHeader.create(buf) + System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 9493ccffd9..481ff8c3e0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -27,12 +27,13 @@ import org.apache.spark.Logging import org.apache.spark.network.ConnectionManagerId import scala.collection.JavaConverters._ +import org.apache.spark.storage.BlockId private[spark] class ShuffleCopier extends Logging { - def getBlock(host: String, port: Int, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + def getBlock(host: String, port: Int, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) =>
Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt @@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging { try { fc.init() fc.connect(host, port) - fc.sendRequest(blockId) + fc.sendRequest(blockId.name) fc.waitForClose() fc.close() } catch { @@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging { } } - def getBlock(cmId: ConnectionManagerId, blockId: String, - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + def getBlock(cmId: ConnectionManagerId, blockId: BlockId, + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { getBlock(cmId.host, cmId.port, blockId, resultCollectCallback) } def getBlocks(cmId: ConnectionManagerId, - blocks: Seq[(String, Long)], - resultCollectCallback: (String, Long, ByteBuf) => Unit) { + blocks: Seq[(BlockId, Long)], + resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { for ((blockId, size) <- blocks) { getBlock(cmId, blockId, resultCollectCallback) @@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging { private[spark] object ShuffleCopier extends Logging { - private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit) + private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit) extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { @@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging { resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } - override def handleError(blockId: String) { + override def handleError(blockId: BlockId) { if (!isComplete) { resultCollectCallBack(blockId, -1, null) } } } - def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) { + def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) { if (size != -1) { logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"") } @@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging { } val host = args(0) val port = args(1).toInt - val file = args(2) + val blockId = BlockId(args(2)) val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) @@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging { Executors.callable(new Runnable() { def run() { val copier = new ShuffleCopier() - copier.getBlock(host, port, file, echoResultCollectCallBack) + copier.getBlock(host, port, blockId, echoResultCollectCallBack) } }) }).asJava copiers.invokeAll(tasks) - copiers.shutdown + copiers.shutdown() System.exit(0) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala index 0c5ded3145..546d921067 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala @@ -21,7 +21,7 @@ import java.io.File import org.apache.spark.Logging import org.apache.spark.util.Utils -import org.apache.spark.storage.ShuffleBlockManager +import org.apache.spark.storage.{BlockId, FileSegment} private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging { @@ -54,8 +54,8 @@ private[spark] object ShuffleSender { val localDirs = args.drop(2).map(new File(_)) val pResovler = new 
PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!ShuffleBlockManager.isShuffle(blockId)) { + override def getBlockLocation(blockId: BlockId): FileSegment = { + if (!blockId.isShuffle) { throw new Exception("Block " + blockId + " is not a shuffle block") } // Figure out which local directory it hashes to, and which subdirectory in that @@ -63,8 +63,8 @@ private[spark] object ShuffleSender { val dirId = hash % localDirs.length val subDirId = (hash / localDirs.length) % subDirsPerLocalDir val subDir = new File(localDirs(dirId), "%02x".format(subDirId)) - val file = new File(subDir, blockId) - return file.getAbsolutePath + val file = new File(subDir, blockId.name) + return new FileSegment(file, 0, file.length()) } } val sender = new ShuffleSender(port, pResovler) diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index f132e2b735..70a5a8caff 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -15,6 +15,8 @@ * limitations under the License. */ +package org.apache + /** * Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to * Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection, diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala new file mode 100644 index 0000000000..faaf837be0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.ExecutionContext.Implicits.global + +import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} + +/** + * A set of asynchronous RDD actions available through an implicit conversion. + * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. + */ +class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging { + + /** + * Returns a future for counting the number of elements in the RDD. + */ + def countAsync(): FutureAction[Long] = { + val totalCount = new AtomicLong + self.context.submitJob( + self, + (iter: Iterator[T]) => { + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + }, + Range(0, self.partitions.size), + (index: Int, data: Long) => totalCount.addAndGet(data), + totalCount.get()) + } + + /** + * Returns a future for retrieving all elements of this RDD. 
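+ * Partition results are gathered into one array slot per partition and flattened in order.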
+ */ + def collectAsync(): FutureAction[Seq[T]] = { + val results = new Array[Array[T]](self.partitions.size) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + (index, data) => results(index) = data, results.flatten.toSeq) + } + + /** + * Returns a future for retrieving the first num elements of the RDD. + */ + def takeAsync(num: Int): FutureAction[Seq[T]] = { + val f = new ComplexFutureAction[Seq[T]] + + f.run { + val results = new ArrayBuffer[T](num) + val totalParts = self.partitions.length + var partsScanned = 0 + while (results.size < num && partsScanned < totalParts) { + // The number of partitions to try in this iteration. It is ok for this number to be + // greater than totalParts because we actually cap it at totalParts in runJob. + var numPartsToTry = 1 + if (partsScanned > 0) { + // If we didn't find any rows after the first iteration, just try all partitions next. + // Otherwise, interpolate the number of partitions we need to try, but overestimate it + // by 50%. + if (results.size == 0) { + numPartsToTry = totalParts - 1 + } else { + numPartsToTry = (1.5 * num * partsScanned / results.size).toInt + } + } + numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions + + val left = num - results.size + val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + + val buf = new Array[Array[T]](p.size) + f.runJob(self, + (it: Iterator[T]) => it.take(left).toArray, + p, + (index: Int, data: Array[T]) => buf(index) = data, + Unit) + + buf.foreach(results ++= _.take(num - results.size)) + partsScanned += numPartsToTry + } + results.toSeq + } + + f + } + + /** + * Applies a function f to all elements of this RDD. + */ + def foreachAsync(f: T => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size), + (index, data) => Unit, Unit) + } + + /** + * Applies a function f to each partition of this RDD. 
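+ * The returned future completes once f has been applied to every partition.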
+ */ + def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + (index, data) => Unit, Unit) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index bca6956a18..44ea573a7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -18,14 +18,14 @@ package org.apache.spark.rdd import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.{BlockId, BlockManager} -private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition { +private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition { val index = idx } private[spark] -class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String]) +class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 3311757189..ccaaecb85b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -85,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging { val outputDir = new Path(path) val fs = outputDir.getFileSystem(env.hadoop.newConfiguration()) - val finalOutputName = splitIdToFile(ctx.splitId) + val finalOutputName = splitIdToFile(ctx.partitionId) val finalOutputPath = new Path(outputDir, finalOutputName) val tempOutputPath = new Path(outputDir, "." 
+ finalOutputName + "-attempt-" + ctx.attemptId) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index d237797aa6..911a002884 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -21,7 +21,7 @@ import java.io.{ObjectOutputStream, IOException} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.util.AppendOnlyMap @@ -125,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: case ShuffleCoGroupSplitDep(shuffleId) => { // Read map outputs of shuffle val fetcher = SparkEnv.get.shuffleFetcher - fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach { + fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach { kv => getSeq(kv._1)(depNum) += kv._2 } } } - map.iterator + new InterruptibleIterator(context, map.iterator) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index d3b3fffd40..fad042c7ae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -27,54 +27,18 @@ import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.util.ReflectionUtils -import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv, - TaskContext} +import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.util.NextIterator import org.apache.hadoop.conf.{Configuration, Configurable} -/** - * An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file - * system, or S3). - * This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same - * across multiple reads; the 'path' is the only variable that is different across new JobConfs - * created from the Configuration. - */ -class HadoopFileRDD[K, V]( - sc: SparkContext, - path: String, - broadcastedConf: Broadcast[SerializableWritable[Configuration]], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V], - minSplits: Int) - extends HadoopRDD[K, V](sc, broadcastedConf, inputFormatClass, keyClass, valueClass, minSplits) { - - override def getJobConf(): JobConf = { - if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - // getJobConf() has been called previously, so there is already a local cache of the JobConf - // needed by this RDD. - return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] - } else { - // Create a new JobConf, set the input file/directory paths to read from, and cache the - // JobConf (i.e., in a shared hash map in the slave's JVM process that's accessible through - // HadoopRDD.putCachedMetadata()), so that we only create one copy across multiple - // getJobConf() calls for this RDD in the local process. - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. 
- val newJobConf = new JobConf(broadcastedConf.value.value) - FileInputFormat.setInputPaths(newJobConf, path) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - return newJobConf - } - } -} /** * A Spark split class that wraps around a Hadoop InputSplit. */ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit) extends Partition { - + val inputSplit = new SerializableWritable[InputSplit](s) override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt @@ -83,11 +47,24 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp } /** - * A base class that provides core functionality for reading data partitions stored in Hadoop. + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3). + * + * @param sc The SparkContext to associate the RDD with. + * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed + * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. + * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD + * creates. + * @param inputFormatClass Storage format of the data to be read. + * @param keyClass Class of the key associated with the inputFormatClass. + * @param valueClass Class of the value associated with the inputFormatClass. + * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate. */ class HadoopRDD[K, V]( sc: SparkContext, broadcastedConf: Broadcast[SerializableWritable[Configuration]], + initLocalJobConfFuncOpt: Option[JobConf => Unit], inputFormatClass: Class[_ <: InputFormat[K, V]], keyClass: Class[K], valueClass: Class[V], @@ -105,6 +82,7 @@ class HadoopRDD[K, V]( sc, sc.broadcast(new SerializableWritable(conf)) .asInstanceOf[Broadcast[SerializableWritable[Configuration]]], + None /* initLocalJobConfFuncOpt */, inputFormatClass, keyClass, valueClass, @@ -130,6 +108,7 @@ class HadoopRDD[K, V]( // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. val newJobConf = new JobConf(broadcastedConf.value.value) + initLocalJobConfFuncOpt.map(f => f(newJobConf)) HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) return newJobConf } @@ -164,38 +143,41 @@ class HadoopRDD[K, V]( array } - override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] - logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null - - val jobConf = getJobConf() - val inputFormat = getInputFormat(jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) - - // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => closeIfNeeded() } - - val key: K = reader.createKey() - val value: V = reader.createValue() - - override def getNext() = { - try { - finished = !reader.next(key, value) - } catch { - case eof: EOFException => - finished = true + override def compute(theSplit: Partition, context: TaskContext) = { + val iter = new NextIterator[(K, V)] { + val split = theSplit.asInstanceOf[HadoopPartition] + logInfo("Input split: " + split.inputSplit) + var reader: RecordReader[K, V] = null + + val jobConf = getJobConf() + val inputFormat = getInputFormat(jobConf) + reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + + // Register an on-task-completion callback to close the input stream. + context.addOnCompleteCallback{ () => closeIfNeeded() } + + val key: K = reader.createKey() + val value: V = reader.createValue() + + override def getNext() = { + try { + finished = !reader.next(key, value) + } catch { + case eof: EOFException => + finished = true + } + (key, value) } - (key, value) - } - override def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + override def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } } } + new InterruptibleIterator[(K, V)](context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala index 3ed8339010..aea08ff81b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala @@ -21,14 +21,14 @@ import org.apache.spark.{Partition, TaskContext} /** - * A variant of the MapPartitionsRDD that passes the partition index into the - * closure. This can be used to generate or collect partition specific - * information such as the number of tuples in a partition. + * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the + * TaskContext, the closure can either get access to the interruptible flag or get the index + * of the partition in the RDD. 
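+ * Index-based operations such as mapPartitionsWithIndex are expressed on top of this class.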
*/ private[spark] -class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( +class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: (Int, Iterator[T]) => Iterator[U], + f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean ) extends RDD[U](prev) { @@ -37,5 +37,5 @@ class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) prev.partitioner else None override def compute(split: Partition, context: TaskContext) = - f(split.index, firstParent[T].iterator(split, context)) + f(context, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b3a89f7e0..2662d48c84 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext} +import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext} private[spark] @@ -71,49 +71,52 @@ class NewHadoopRDD[K, V]( result } - override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - if (format.isInstanceOf[Configurable]) { - format.asInstanceOf[Configurable].setConf(conf) - } - val reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addOnCompleteCallback(() => close()) - - var havePair = false - var finished = false - - override def hasNext: Boolean = { - if (!finished && !havePair) { - finished = !reader.nextKeyValue - havePair = !finished + override def compute(theSplit: Partition, context: TaskContext) = { + val iter = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[NewHadoopPartition] + logInfo("Input split: " + split.serializableHadoopSplit) + val conf = confBroadcast.value.value + val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + if (format.isInstanceOf[Configurable]) { + format.asInstanceOf[Configurable].setConf(conf) + } + val reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. 
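+ // (This releases the reader even when the task stops before consuming the whole split.)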
+ context.addOnCompleteCallback(() => close()) + + var havePair = false + var finished = false + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !reader.nextKeyValue + havePair = !finished + } + !finished } - !finished - } - override def next: (K, V) = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") + override def next(): (K, V) = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + (reader.getCurrentKey, reader.getCurrentValue) } - havePair = false - return (reader.getCurrentKey, reader.getCurrentValue) - } - private def close() { - try { - reader.close() - } catch { - case e: Exception => logWarning("Exception in RecordReader.close()", e) + private def close() { + try { + reader.close() + } catch { + case e: Exception => logWarning("Exception in RecordReader.close()", e) + } } } + new InterruptibleIterator(context, iter) } override def getPreferredLocations(split: Partition): Seq[String] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index a47c512275..93b78e1232 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -84,18 +84,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { - self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + self.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } else if (mapSideCombine) { val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner) .setSerializer(serializerClass) - partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true) + partitioned.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter)) + }, preservesPartitioning = true) } else { // Don't apply map-side combiner. // A sanity check to make sure mergeCombiners is not defined. assert(mergeCombiners == null) val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass) - values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true) + values.mapPartitionsWithContext((context, iter) => { + new InterruptibleIterator(context, aggregator.combineValuesByKey(iter)) + }, preservesPartitioning = true) } } @@ -564,7 +570,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. 
val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" <split #> <attempt # = spark task #> */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber) + val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = outputFormatClass.newInstance val committer = format.getOutputCommitter(hadoopContext) @@ -663,7 +669,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt - writer.setup(context.stageId, context.splitId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() var count = 0 diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 6dbd4309aa..cd96250389 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -94,8 +94,9 @@ private[spark] class ParallelCollectionRDD[T: ClassManifest]( slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray } - override def compute(s: Partition, context: TaskContext) = - s.asInstanceOf[ParallelCollectionPartition[T]].iterator + override def compute(s: Partition, context: TaskContext) = { + new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator) + } override def getPreferredLocations(s: Partition): Seq[String] = { locationPrefs.getOrElse(s.index, Nil) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1893627ee2..0355618e43 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -418,26 +418,39 @@ abstract class RDD[T: ClassManifest]( command: Seq[String], env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = + printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, if (printRDDElement ne null) sc.clean(printRDDElement) else null) + } /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = + def mapPartitions[U: ClassManifest]( + f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. 
*/ def mapPartitionsWithIndex[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) + new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + } + + /** + * Return a new RDD by applying a function to each partition of this RDD. This is a variant of + * mapPartitions that also passes the TaskContext into the closure. + */ + def mapPartitionsWithContext[U: ClassManifest]( + f: (TaskContext, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index @@ -445,22 +458,23 @@ abstract class RDD[T: ClassManifest]( */ @deprecated("use mapPartitionsWithIndex", "0.7.0") def mapPartitionsWithSplit[U: ClassManifest]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = - new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning) + f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { + mapPartitionsWithIndex(f, preservesPartitioning) + } /** * Maps f over this RDD, where f takes an additional parameter of type A. This * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => U): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.map(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def mapWith[A: ClassManifest, U: ClassManifest] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => U): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.map(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -468,13 +482,14 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false) - (f:(T, A) => Seq[U]): RDD[U] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[U] = { - val a = constructA(index) - iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning) + def flatMapWith[A: ClassManifest, U: ClassManifest] + (constructA: Int => A, preservesPartitioning: Boolean = false) + (f: (T, A) => Seq[U]): RDD[U] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { + val a = constructA(context.partitionId) + iter.flatMap(t => f(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) } /** @@ -482,13 +497,12 @@ abstract class RDD[T: ClassManifest]( * This additional parameter is produced by constructA, which is called in each * partition with the index of that partition. 
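+ * The value produced by constructA is created once per partition and reused for every element in it.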
*/ - def foreachWith[A: ClassManifest](constructA: Int => A) - (f:(T, A) => Unit) { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.map(t => {f(t, a); t}) - } - (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {}) + def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.map(t => {f(t, a); t}) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) } /** @@ -496,13 +510,12 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def filterWith[A: ClassManifest](constructA: Int => A) - (p:(T, A) => Boolean): RDD[T] = { - def iterF(index: Int, iter: Iterator[T]): Iterator[T] = { - val a = constructA(index) - iter.filter(t => p(t, a)) - } - new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true) + def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { + val a = constructA(context.partitionId) + iter.filter(t => p(t, a)) + } + new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) } /** @@ -541,16 +554,14 @@ abstract class RDD[T: ClassManifest]( * Applies a function f to all elements of this RDD. */ def foreach(f: T => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF)) + sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f)) } /** * Applies a function f to each partition of this RDD. */ def foreachPartition(f: Iterator[T] => Unit) { - val cleanF = sc.clean(f) - sc.runJob(this, (iter: Iterator[T]) => cleanF(iter)) + sc.runJob(this, (iter: Iterator[T]) => f(iter)) } /** @@ -675,6 +686,8 @@ abstract class RDD[T: ClassManifest]( */ def count(): Long = { sc.runJob(this, (iter: Iterator[T]) => { + // Use a while loop to count the number of elements rather than iter.size because + // iter.size uses a for loop, which is slightly slower in the current version of Scala.
var result = 0L while (iter.hasNext) { result += 1L diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 9537152335..a5d751a7bd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -56,7 +56,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest]( override def compute(split: Partition, context: TaskContext): Iterator[P] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId - SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics, + SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, SparkEnv.get.serializerManager.get(serializerClass)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 8c1a29dfff..7af4d803e7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -108,7 +108,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM } case ShuffleCoGroupSplitDep(shuffleId) => { val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index, - context.taskMetrics, serializer) + context, serializer) iter.foreach(op) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4226617cfb..d84f5968df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -28,8 +28,8 @@ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.storage.{BlockManager, BlockManagerMaster} -import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} +import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -41,11 +41,11 @@ import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedH * locations to run each task on, based on the current cache status, and passes these to the * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are - * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task + * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task * a small number of times before cancelling the whole stage. * * THREADING: This class runs all its logic in a single thread executing the run() method, to which - * events are submitted using a synchonized queue (eventQueue). The public API methods, such as + * events are submitted using a synchronized queue (eventQueue). The public API methods, such as * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods * should be private. 
*/ @@ -55,20 +55,20 @@ class DAGScheduler( mapOutputTracker: MapOutputTracker, blockManagerMaster: BlockManagerMaster, env: SparkEnv) - extends TaskSchedulerListener with Logging { + extends Logging { def this(taskSched: TaskScheduler) { this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) } - taskSched.setListener(this) + taskSched.setDAGScheduler(this) // Called by TaskScheduler to report task's starting. - override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventQueue.put(BeginEvent(task, taskInfo)) } // Called by TaskScheduler to report task completions or failures. - override def taskEnded( + def taskEnded( task: Task[_], reason: TaskEndReason, result: Any, @@ -79,17 +79,18 @@ class DAGScheduler( } // Called by TaskScheduler when an executor fails. - override def executorLost(execId: String) { + def executorLost(execId: String) { eventQueue.put(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - override def executorGained(execId: String, host: String) { + def executorGained(execId: String, host: String) { eventQueue.put(ExecutorGained(execId, host)) } - // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures. - override def taskSetFailed(taskSet: TaskSet, reason: String) { + // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or + // cancellation of the job itself. + def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } @@ -104,13 +105,15 @@ class DAGScheduler( private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] - val nextJobId = new AtomicInteger(0) + private[scheduler] val nextJobId = new AtomicInteger(0) - val nextStageId = new AtomicInteger(0) + def numTotalJobs: Int = nextJobId.get() - val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private val nextStageId = new AtomicInteger(0) - val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -127,6 +130,7 @@ class DAGScheduler( // stray messages to detect. val failedEpoch = new HashMap[String, Long] + // stage id to the active job val idToActiveJob = new HashMap[Int, ActiveJob] val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done @@ -156,7 +160,7 @@ class DAGScheduler( private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = { if (!cacheLocs.contains(rdd.id)) { - val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray + val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId] val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster) cacheLocs(rdd.id) = blockIds.map { id => locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId)) @@ -261,32 +265,41 @@ class DAGScheduler( } /** - * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a - * JobWaiter whose getResult() method will return the result of the job when it is complete. - * - * The job is assumed to have at least one partition; zero partition jobs should be handled - * without a JobSubmitted event. + * Submit a job to the job scheduler and get a JobWaiter object back. 
The JobWaiter object + * can be used to block until the the job finishes executing or can be used to cancel the job. */ - private[scheduler] def prepareJob[T, U: ClassManifest]( - finalRdd: RDD[T], + def submitJob[T, U]( + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null) - : (JobSubmitted, JobWaiter[U]) = + properties: Properties = null): JobWaiter[U] = { + // Check to make sure we are not launching a task on a partition that does not exist. + val maxPartitions = rdd.partitions.length + partitions.find(p => p >= maxPartitions).foreach { p => + throw new IllegalArgumentException( + "Attempting to access a non-existent partition: " + p + ". " + + "Total number of partitions: " + maxPartitions) + } + + val jobId = nextJobId.getAndIncrement() + if (partitions.size == 0) { + return new JobWaiter[U](this, jobId, 0, resultHandler) + } + assert(partitions.size > 0) - val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, - properties) - (toSubmit, waiter) + val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, + waiter, properties)) + waiter } def runJob[T, U: ClassManifest]( - finalRdd: RDD[T], + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: String, @@ -294,21 +307,7 @@ class DAGScheduler( resultHandler: (Int, U) => Unit, properties: Properties = null) { - if (partitions.size == 0) { - return - } - - // Check to make sure we are not launching a task on a partition that does not exist. - val maxPartitions = finalRdd.partitions.length - partitions.find(p => p >= maxPartitions).foreach { p => - throw new IllegalArgumentException( - "Attempting to access a non-existent partition: " + p + ". " + - "Total number of partitions: " + maxPartitions) - } - - val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) - eventQueue.put(toSubmit) + val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => {} case JobFailed(exception: Exception, _) => @@ -329,19 +328,40 @@ class DAGScheduler( val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties)) + val jobId = nextJobId.getAndIncrement() + eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, + listener, properties)) listener.awaitResult() // Will throw an exception if the job fails } /** + * Cancel a job that is running or waiting in the queue. + */ + def cancelJob(jobId: Int) { + logInfo("Asked to cancel job " + jobId) + eventQueue.put(JobCancelled(jobId)) + } + + def cancelJobGroup(groupId: String) { + logInfo("Asked to cancel job group " + groupId) + eventQueue.put(JobGroupCancelled(groupId)) + } + + /** + * Cancel all jobs that are running or waiting in the queue. 
+ */ + def cancelAllJobs() { + eventQueue.put(AllJobsCancelled) + } + + /** * Process one event retrieved from the event queue. * Returns true if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => - val jobId = nextJobId.getAndIncrement() - val finalStage = newStage(finalRDD, None, jobId, Some(callSite)) + case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => + val finalStage = newStage(rdd, None, jobId, Some(callSite)) val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + @@ -360,6 +380,29 @@ class DAGScheduler( submitStage(finalStage) } + case JobCancelled(jobId) => + // Cancel a job: find all the running stages that are linked to this job, and cancel them. + running.filter(_.jobId == jobId).foreach { stage => + taskSched.cancelTasks(stage.id) + } + + case JobGroupCancelled(groupId) => + // Cancel all jobs belonging to this job group. + // First finds all active jobs with this group id, and then kill stages for them. + val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + .map(_.jobId) + if (!jobIds.isEmpty) { + running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage => + taskSched.cancelTasks(stage.id) + } + } + + case AllJobsCancelled => + // Cancel all running jobs. + running.foreach { stage => + taskSched.cancelTasks(stage.id) + } + case ExecutorGained(execId, host) => handleExecutorGained(execId, host) @@ -578,6 +621,11 @@ class DAGScheduler( */ private def handleTaskCompletion(event: CompletionEvent) { val task = event.task + + if (!stageIdToStage.contains(task.stageId)) { + // Skip all the actions if the stage has been cancelled. + return + } val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage) = { @@ -626,7 +674,7 @@ class DAGScheduler( if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { - stage.addOutputLoc(smt.partition, status) + stage.addOutputLoc(smt.partitionId, status) } if (running.contains(stage) && pendingTasks(stage).isEmpty) { markStageAsFinished(stage) @@ -752,14 +800,14 @@ class DAGScheduler( /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set - * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. + * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. 
*/ private def abortStage(failedStage: Stage, reason: String) { val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq failedStage.completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - val error = new SparkException("Job failed: " + reason) + val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) idToActiveJob -= resultStage.jobId diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 10ff1b4376..a5769c6041 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics * submitted) but there is a single "logic" thread that reads these events and takes decisions. * This greatly simplifies synchronization. */ -private[spark] sealed trait DAGSchedulerEvent +private[scheduler] sealed trait DAGSchedulerEvent -private[spark] case class JobSubmitted( +private[scheduler] case class JobSubmitted( + jobId: Int, finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], @@ -43,9 +44,16 @@ private[spark] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent -private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent -private[spark] case class CompletionEvent( +private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent + +private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent + +private[scheduler] +case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent + +private[scheduler] case class CompletionEvent( task: Task[_], reason: TaskEndReason, result: Any, @@ -54,10 +62,12 @@ private[spark] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] +case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent -private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent -private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] +case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent -private[spark] case object StopDAGScheduler extends DAGSchedulerEvent +private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 151514896f..7b5c0e29ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar }) metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] { - override def getValue: Int = dagScheduler.nextJobId.get() + override def 
getValue: Int = dagScheduler.numTotalJobs }) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3628b1b078..19c0251690 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -24,56 +24,54 @@ import java.text.SimpleDateFormat import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
+import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * per Spark job with information such as the RDD graph, task start/stop events, and shuffle
+ * information.
+ *
+ * @param logDirName The name of this logger's folder under the Spark log directory
+ *                   (SPARK_LOG_DIR, or /tmp/spark by default); one log file per job is
+ *                   written inside it.
+ */
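For orientation, a JobLogger is registered like any other SparkListener. A minimal sketch, assuming a SparkContext named sc already exists (the registration call itself is not part of this patch):

    // Writes one log file per job under $SPARK_LOG_DIR (or /tmp/spark),
    // in a folder named after this logger's creation time in milliseconds.
    val jobLogger = new JobLogger()
    sc.addSparkListener(jobLogger)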
class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
+
+ private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
+
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
+
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
// Create a folder for log files; the folder's name is the creation time of the JobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
+ if (!dir.exists() && !dir.mkdirs()) {
+ logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job; the file name is the jobID
protected def createLogWriter(jobID: Int) {
- try{
+ try {
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
}
// Close the job's log file, and clean up the stage-to-job mapping in stageIDToJobID
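Taken together, createLogDir and createLogWriter give one folder per JobLogger instance and one file per job. With the default constructor, the resulting layout would look something like this (timestamp hypothetical):

    /tmp/spark/1381334400000/0    <- log file for job 0
    /tmp/spark/1381334400000/1    <- log file for job 1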
@@ -118,10 +116,9 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
  def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
}
rddList
}
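The rewrite above replaces the redundant foreach { dep => dep match {...} } with a partial function passed straight to foreach, and the empty ShuffleDependency case is deliberate: a shuffle marks a stage boundary, so the traversal must not cross it. A self-contained sketch of the same idiom, using simplified stand-in types rather than this patch's classes:

    sealed trait Dep
    case object ShuffleDep extends Dep                    // stage boundary
    case class NarrowDep(children: Seq[Dep]) extends Dep  // same-stage edge

    def countSameStage(deps: Seq[Dep]): Int = {
      var n = 0
      deps.foreach {
        case ShuffleDep          => // intentionally empty: stop at the boundary
        case NarrowDep(children) => n += 1 + countSameStage(children)
      }
      n
    }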
@@ -161,29 +158,27 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
  protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
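recordRddInStageGraph prints one line per RDD reachable within the stage, indenting as it descends and emitting the shuffle id where it stops. A hypothetical fragment of the resulting log output, for illustration only (RDD names and ids invented):

    RDD_ID=4(myPairs,spark)
      SHUFFLE_ID=0
      RDD_ID=3(map at <console>:12,spark)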
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
+ } else {
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
}
// Record task metrics into job log files
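recordTaskMetrics, updated in the next hunk, flattens each task's timing and shuffle counters into a single KEY=VALUE record. An illustrative, hypothetical fragment showing only the fields visible in this hunk (a real record is one line; wrapped here for readability):

    ... START_TIME=1381334400000 FINISH_TIME=1381334400123 EXECUTOR_ID=2 HOST=worker-7
        EXECUTOR_RUN_TIME=98 SHUFFLE_FINISH_TIME=1381334400120 BLOCK_FETCHED_TOTAL=16
        REMOTE_BYTES_READ=1048576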
@@ -193,39 +188,32 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
      " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
+ stageLogInfo(stageSubmitted.stage.id, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.id, stageSubmitted.taskSize))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
+ stageLogInfo(stageCompleted.stageInfo.stage.id,
"STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
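Elsewhere in this patch, cancellation becomes user-visible: JobWaiter gains a cancel() method that forwards to DAGScheduler.cancelJob, and whole job groups can be cancelled at once. A hedged sketch of driver-side usage, assuming the SparkContext-level wrappers (setJobGroup, cancelJobGroup) introduced alongside these scheduler changes; treat the exact wrapper names as illustrative:

    val worker = new Thread(new Runnable {
      def run() {
        // The job group is a thread-local property, so it must be set on the
        // thread that triggers the action.
        sc.setJobGroup("nightly-etl", "nightly ETL jobs")
        sc.parallelize(1 to 10000000).map(_ * 2).count()
      }
    })
    worker.start()
    // Later, from another thread: kill the running tasks of every job tagged above.
    sc.cancelJobGroup("nightly-etl")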
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 200d881799..58f238d8cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -17,48 +17,58 @@ package org.apache.spark.scheduler -import scala.collection.mutable.ArrayBuffer - /** * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their * results to the given handler function. */ -private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit) +private[spark] class JobWaiter[T]( + dagScheduler: DAGScheduler, + jobId: Int, + totalTasks: Int, + resultHandler: (Int, T) => Unit) extends JobListener { private var finishedTasks = 0 - private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? - private var jobResult: JobResult = null // If the job is finished, this will be its result + // Is the job as a whole finished (succeeded or failed)? + private var _jobFinished = totalTasks == 0 - override def taskSucceeded(index: Int, result: Any) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") - } - resultHandler(index, result.asInstanceOf[T]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - jobFinished = true - jobResult = JobSucceeded - this.notifyAll() - } - } + def jobFinished = _jobFinished + + // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero + // partition RDDs), we set the jobResult directly to JobSucceeded. + private var jobResult: JobResult = if (jobFinished) JobSucceeded else null + + /** + * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled + * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it + * will fail this job with a SparkException. + */ + def cancel() { + dagScheduler.cancelJob(jobId) } - override def jobFailed(exception: Exception) { - synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") - } - jobFinished = true - jobResult = JobFailed(exception, None) + override def taskSucceeded(index: Int, result: Any): Unit = synchronized { + if (_jobFinished) { + throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + } + resultHandler(index, result.asInstanceOf[T]) + finishedTasks += 1 + if (finishedTasks == totalTasks) { + _jobFinished = true + jobResult = JobSucceeded this.notifyAll() } } + override def jobFailed(exception: Exception): Unit = synchronized { + _jobFinished = true + jobResult = JobFailed(exception, None) + this.notifyAll() + } + def awaitResult(): JobResult = synchronized { - while (!jobFinished) { + while (!_jobFinished) { this.wait() } return jobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 9eb8d48501..596f9adde9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -43,7 +43,10 @@ private[spark] class Pool( var runningTasks = 0 var priority = 0 - var stageId = 0 + + // A pool's stage id is used to break the tie in scheduling. 
+ var stageId = -1 + var name = poolName var parent: Pool = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 6dd422bbf6..310ec62ca8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -38,17 +38,17 @@ private[spark] object ResultTask { synchronized { val old = serializedInfoCache.get(stageId).orNull if (old != null) { - return old + old } else { val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objOut = ser.serializeStream(new GZIPOutputStream(out)) objOut.writeObject(rdd) objOut.writeObject(func) objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -56,11 +56,11 @@ private[spark] object ResultTask { def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance + val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - return (rdd, func) + (rdd, func) } def clearCache() { @@ -71,29 +71,37 @@ private[spark] object ResultTask { } +/** + * A task that sends back the output to the driver application. + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd input to func + * @param func a function to apply on a partition of the RDD + * @param _partitionId index of the number in the RDD + * @param locs preferred task execution locations for locality scheduling + * @param outputId index of the task in this job (a job can launch tasks on only a subset of the + * input RDD's partitions). 
+ */ private[spark] class ResultTask[T, U]( stageId: Int, var rdd: RDD[T], var func: (TaskContext, Iterator[T]) => U, - var partition: Int, + _partitionId: Int, @transient locs: Seq[TaskLocation], var outputId: Int) - extends Task[U](stageId) with Externalizable { + extends Task[U](stageId, _partitionId) with Externalizable { def this() = this(0, null, null, 0, null, 0) - var split = if (rdd == null) { - null - } else { - rdd.partitions(partition) - } + var split = if (rdd == null) null else rdd.partitions(partitionId) @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - override def run(attemptId: Long): U = { - val context = new TaskContext(stageId, partition, attemptId, runningLocally = false) + override def runTask(context: TaskContext): U = { metrics = Some(context.taskMetrics) try { func(context, rdd.iterator(split, context)) @@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U]( override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ResultTask(" + stageId + ", " + partition + ")" + override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ResultTask.serializeInfo( stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeInt(outputId) out.writeLong(epoch) out.writeObject(split) @@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U]( val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) rdd = rdd_.asInstanceOf[RDD[T]] func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partition = in.readInt() + partitionId = in.readInt() outputId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 4e25086ec9..356fe56bf3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -30,7 +30,10 @@ import scala.xml.XML * addTaskSetManager: build the leaf nodes(TaskSetManagers) */ private[spark] trait SchedulableBuilder { + def rootPool: Pool + def buildPools() + def addTaskSetManager(manager: Schedulable, properties: Properties) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 3b9d5679fb..24d97da6eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask { objOut.close() val bytes = out.toByteArray serializedInfoCache.put(stageId, bytes) - return bytes + bytes } } } @@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask { val objIn = ser.deserializeStream(in) val rdd = objIn.readObject().asInstanceOf[RDD[_]] val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]] - return (rdd, dep) + (rdd, dep) } } @@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask { val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val objIn = new ObjectInputStream(in) val set = 
objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - return (HashMap(set.toSeq: _*)) + HashMap(set.toSeq: _*) } def clearCache() { @@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask { } } +/** + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * + * @param stageId id of the stage this task belongs to + * @param rdd the final RDD in this stage + * @param dep the ShuffleDependency + * @param _partitionId index of the number in the RDD + * @param locs preferred task execution locations for locality scheduling + */ private[spark] class ShuffleMapTask( stageId: Int, var rdd: RDD[_], var dep: ShuffleDependency[_,_], - var partition: Int, + _partitionId: Int, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId) + extends Task[MapStatus](stageId, _partitionId) with Externalizable with Logging { @@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask( if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partition) + var split = if (rdd == null) null else rdd.partitions(partitionId) override def writeExternal(out: ObjectOutput) { RDDCheckpointData.synchronized { - split = rdd.partitions(partition) + split = rdd.partitions(partitionId) out.writeInt(stageId) val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) out.writeInt(bytes.length) out.write(bytes) - out.writeInt(partition) + out.writeInt(partitionId) out.writeLong(epoch) out.writeObject(split) } @@ -124,16 +136,14 @@ private[spark] class ShuffleMapTask( val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) rdd = rdd_ dep = dep_ - partition = in.readInt() + partitionId = in.readInt() epoch = in.readLong() split = in.readObject().asInstanceOf[Partition] } - override def run(attemptId: Long): MapStatus = { + override def runTask(context: TaskContext): MapStatus = { val numOutputSplits = dep.partitioner.numPartitions - - val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false) - metrics = Some(taskContext.taskMetrics) + metrics = Some(context.taskMetrics) val blockManager = SparkEnv.get.blockManager var shuffle: ShuffleBlocks = null @@ -143,10 +153,10 @@ private[spark] class ShuffleMapTask( // Obtain all the block writers for shuffle blocks. val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser) - buckets = shuffle.acquireWriters(partition) + buckets = shuffle.acquireWriters(partitionId) // Write the map output to its associated buckets. - for (elem <- rdd.iterator(split, taskContext)) { + for (elem <- rdd.iterator(split, context)) { val pair = elem.asInstanceOf[Product2[Any, Any]] val bucketId = dep.partitioner.getPartition(pair._1) buckets.writers(bucketId).write(pair) @@ -154,20 +164,22 @@ private[spark] class ShuffleMapTask( // Commit the writes. Get the size of each bucket block (total block size). var totalBytes = 0L + var totalTime = 0L val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter => writer.commit() - writer.close() - val size = writer.size() + val size = writer.fileSegment().length totalBytes += size + totalTime += writer.timeWriting() MapOutputTracker.compressSize(size) } // Update shuffle metrics. 
val shuffleMetrics = new ShuffleWriteMetrics shuffleMetrics.shuffleBytesWritten = totalBytes + shuffleMetrics.shuffleWriteTime = totalTime metrics.get.shuffleWriteMetrics = Some(shuffleMetrics) - return new MapStatus(blockManager.blockManagerId, compressedSizes) + new MapStatus(blockManager.blockManagerId, compressedSizes) } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. @@ -178,14 +190,15 @@ private[spark] class ShuffleMapTask( } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && buckets != null) { + buckets.writers.foreach(_.close()) shuffle.releaseWriters(buckets) } // Execute the callbacks on task completion. - taskContext.executeOnCompleteCallbacks() + context.executeOnCompleteCallbacks() } } override def preferredLocations: Seq[TaskLocation] = preferredLocs - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) + override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 62b521ad45..466baf9913 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -54,7 +54,7 @@ trait SparkListener { /** * Called when a task starts */ - def onTaskStart(taskEnd: SparkListenerTaskStart) { } + def onTaskStart(taskStart: SparkListenerTaskStart) { } /** * Called when a task ends diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 598d91752a..69b42e86ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,25 +17,74 @@ package org.apache.spark.scheduler -import org.apache.spark.serializer.SerializerInstance import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream -import org.apache.spark.util.ByteBufferInputStream + import scala.collection.mutable.HashMap + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + +import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.ByteBufferInputStream + /** - * A task to execute on a worker node. + * A unit of execution. We have two kinds of Task's in Spark: + * - [[org.apache.spark.scheduler.ShuffleMapTask]] + * - [[org.apache.spark.scheduler.ResultTask]] + * + * A Spark job consists of one or more stages. The very last stage in a job consists of multiple + * ResultTask's, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task + * and sends the task output back to the driver application. A ShuffleMapTask executes the task + * and divides the task output to multiple buckets (based on the task's partitioner). 
+ * + * @param stageId id of the stage this task belongs to + * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Long): T +private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { + + final def run(attemptId: Long): T = { + context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + if (_killed) { + kill() + } + runTask(context) + } + + def runTask(context: TaskContext): T + def preferredLocations: Seq[TaskLocation] = Nil - var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskScheduler. + var epoch: Long = -1 var metrics: Option[TaskMetrics] = None + // Task context, to be initialized in run(). + @transient protected var context: TaskContext = _ + + // A flag to indicate whether the task is killed. This is used in case context is not yet + // initialized when kill() is invoked. + @volatile @transient private var _killed = false + + /** + * Whether the task has been killed. + */ + def killed: Boolean = _killed + + /** + * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark + * code and user code to properly handle the flag. This function should be idempotent so it can + * be called multiple times. + */ + def kill() { + _killed = true + if (context != null) { + context.interrupted = true + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index db3954a9d3..7e468d0d67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -24,13 +24,14 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.{SparkEnv} import java.nio.ByteBuffer import org.apache.spark.util.Utils +import org.apache.spark.storage.BlockId // Task result. Also contains updates to accumulator variables. private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ private[spark] -case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable +case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. */ private[spark] diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 7c2a9f03d7..10e0478108 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * Each TaskScheduler schedulers task for a single SparkContext. * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler through - * the TaskSchedulerListener interface. + * are failures, and mitigating stragglers. They return events to the DAGScheduler. */ private[spark] trait TaskScheduler { @@ -45,8 +44,11 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. 
def submitTasks(taskSet: TaskSet): Unit - // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. - def setListener(listener: TaskSchedulerListener): Unit + // Cancel a stage. + def cancelTasks(stageId: Int) + + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. + def setDAGScheduler(dagScheduler: DAGScheduler): Unit // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala deleted file mode 100644 index 593fa9fb93..0000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import scala.collection.mutable.Map - -import org.apache.spark.TaskEndReason -import org.apache.spark.executor.TaskMetrics - -/** - * Interface for getting events back from the TaskScheduler. - */ -private[spark] trait TaskSchedulerListener { - // A task has started. - def taskStarted(task: Task[_], taskInfo: TaskInfo) - - // A task has finished or failed. - def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit - - // A node was added to the cluster. - def executorGained(execId: String, host: String): Unit - - // A node was lost from the cluster. - def executorLost(execId: String): Unit - - // The TaskScheduler wants to abort an entire task set. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index c3ad325156..03bf760837 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -31,5 +31,9 @@ private[spark] class TaskSet( val properties: Properties) { val id: String = stageId + "." 
+ attempt + def kill() { + tasks.foreach(_.kill()) + } + override def toString: String = "TaskSet " + id } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 1a844b7e7e..4ea8bf8853 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler.cluster -import java.lang.{Boolean => JBoolean} import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong import java.util.{TimerTask, Timer} @@ -79,14 +78,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) private val executorIdToHost = new HashMap[String, String] - // JAR server, if any JARs were added by the user to the SparkContext - var jarServer: HttpServer = null - - // URIs of JARs to pass to executor - var jarUris: String = "" - // Listener object to pass upcalls into - var listener: TaskSchedulerListener = null + var dagScheduler: DAGScheduler = null var backend: SchedulerBackend = null @@ -101,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } def initialize(context: SchedulerBackend) { @@ -171,8 +164,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } - def taskSetFinished(manager: TaskSetManager) { - this.synchronized { + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the stage. + val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } + + def taskSetFinished(manager: TaskSetManager): Unit = synchronized { + // Check to see if the given task set has been removed. This is possible in the case of + // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has + // more than one running tasks). 
+ if (activeTaskSets.contains(manager.taskSet.id)) { activeTaskSets -= manager.taskSet.id manager.parent.removeSchedulable(manager) logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) @@ -281,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Update the DAGScheduler without holding a lock on this, since that can deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -334,9 +350,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (backend != null) { backend.stop() } - if (jarServer != null) { - jarServer.stop() - } if (taskResultGetter != null) { taskResultGetter.stop() } @@ -384,9 +397,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) logError("Lost an executor " + executorId + " (already removed): " + reason) } } - // Call listener.executorLost without holding the lock on this to prevent deadlock + // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } } @@ -405,7 +418,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } def executorGained(execId: String, host: String) { - listener.executorGained(execId, host) + dagScheduler.executorGained(execId, host) } def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 936167c13f..29093e3b4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -17,18 +17,16 @@ package org.apache.spark.scheduler.cluster -import java.nio.ByteBuffer -import java.util.{Arrays, NoSuchElementException} +import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import scala.Some import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - SparkException, Success, TaskEndReason, TaskResultLost, TaskState} + Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler._ import org.apache.spark.util.{SystemClock, Clock} @@ -417,11 +415,11 @@ private[spark] class ClusterTaskSetManager( } private def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) } /** - * Marks the task as successful and notifies the listener that a task has ended. + * Marks the task as successful and notifies the DAGScheduler that a task has ended. 
*/ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { val info = taskInfos(tid) @@ -431,7 +429,7 @@ private[spark] class ClusterTaskSetManager( if (!successful(index)) { logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.listener.taskEnded( + sched.dagScheduler.taskEnded( tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) // Mark successful and stop if all the tasks have succeeded. @@ -447,7 +445,8 @@ private[spark] class ClusterTaskSetManager( } /** - * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener. + * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the + * DAG Scheduler. */ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { val info = taskInfos(tid) @@ -458,54 +457,57 @@ private[spark] class ClusterTaskSetManager( val index = info.index info.markFailed() if (!successful(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) copiesRunning(index) -= 1 // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. reason.foreach { - _ match { - case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) - successful(index) = true - tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case ef: ExceptionFailure => - sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - val key = ef.description - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) + successful(index) = true + tasksSuccessful += 1 + sched.taskSetFinished(this) + removeAllRunningTasks() + return + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) + return + + case ef: ExceptionFailure => + sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + val key = ef.description + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { recentExceptions(key) = (0, now) (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + recentExceptions(key) = (0, now) + (true, 0) } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat 
%s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } - case TaskResultLost => - logInfo("Lost result for TID %s on host %s".format(tid, info.host)) - sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + case TaskResultLost => + logWarning("Lost result for TID %s on host %s".format(tid, info.host)) + sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) - case _ => {} - } + case _ => {} } // On non-fetch failures, re-enqueue the task as pending for a max number of retries addPendingTask(index) @@ -532,7 +534,7 @@ private[spark] class ClusterTaskSetManager( failed = true causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message) removeAllRunningTasks() sched.taskSetFinished(this) } @@ -605,7 +607,7 @@ private[spark] class ClusterTaskSetManager( addPendingTask(index) // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index c0b836bf1a..a8230ec6bc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -24,26 +24,28 @@ import org.apache.spark.scheduler.TaskDescription import org.apache.spark.util.{Utils, SerializableBuffer} -private[spark] sealed trait StandaloneClusterMessage extends Serializable +private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable -private[spark] object StandaloneClusterMessages { +private[spark] object CoarseGrainedClusterMessages { // Driver to executors - case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage + case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage + + case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage case class RegisteredExecutor(sparkProperties: Seq[(String, String)]) - extends StandaloneClusterMessage + extends CoarseGrainedClusterMessage - case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage + case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) - extends StandaloneClusterMessage { + extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, - data: SerializableBuffer) extends StandaloneClusterMessage + data: SerializableBuffer) extends CoarseGrainedClusterMessage object StatusUpdate { /** Alternate factory method that takes a ByteBuffer directly for the data field */ @@ -54,10 +56,10 @@ private[spark] object StandaloneClusterMessages { } // Internal messages in driver - case object ReviveOffers extends StandaloneClusterMessage + 
case object ReviveOffers extends CoarseGrainedClusterMessage - case object StopDriver extends StandaloneClusterMessage + case object StopDriver extends CoarseGrainedClusterMessage - case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage + case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f3aeea43d5..c0f1c6dbad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -30,16 +30,19 @@ import akka.util.duration._ import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.scheduler.TaskDescription -import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._ +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.Utils /** - * A standalone scheduler backend, which waits for standalone executors to connect to it through - * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained - * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*). + * A scheduler backend that waits for coarse grained executors to connect to it through Akka. + * This backend holds onto each executor for the duration of the Spark job rather than relinquishing + * executors whenever a task is done and asking the scheduler to launch a new executor for + * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the + * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode + * (spark.deploy.*). */ private[spark] -class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -91,6 +94,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor case ReviveOffers => makeOffers() + case KillTask(taskId, executorId) => + executorActor(executorId) ! KillTask(taskId, executorId) + case StopDriver => sender ! true context.stop(self) @@ -159,7 +165,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME) + Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) } private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") @@ -180,6 +186,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor driverActor ! ReviveOffers } + override def killTask(taskId: Long, executorId: String) { + driverActor ! 
KillTask(taskId, executorId) + } + override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism")) .map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2)) @@ -195,6 +205,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor } } -private[spark] object StandaloneSchedulerBackend { - val ACTOR_NAME = "StandaloneScheduler" +private[spark] object CoarseGrainedSchedulerBackend { + val ACTOR_NAME = "CoarseGrainedScheduler" } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index d57eb3276f..5367218faa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.{SparkContext} +import org.apache.spark.SparkContext /** * A backend interface for cluster scheduling systems that allows plugging in different ones under @@ -30,8 +30,8 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int + def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException + // Memory used by each executor (in megabytes) protected val executorMemory: Int = SparkContext.executorMemoryRequested - - // TODO: Probably want to add a killTask too } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index cb88159b8d..cefa970bb9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -28,7 +28,7 @@ private[spark] class SparkDeploySchedulerBackend( sc: SparkContext, masters: Array[String], appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with ClientListener with Logging { @@ -44,10 +44,10 @@ private[spark] class SparkDeploySchedulerBackend( // The endpoint for executors to talk to us val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") val command = Command( - "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs) + "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(null) val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, "http://" + sc.ui.appUIAddress) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index feec8ecfe4..4312c46cc1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -24,33 +24,16 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} import 
org.apache.spark.serializer.SerializerInstance +import org.apache.spark.util.Utils /** * Runs a thread pool that deserializes and remotely fetches (if necessary) task results. */ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) extends Logging { - private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt - private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt - private val getTaskResultExecutor = new ThreadPoolExecutor( - MIN_THREADS, - MAX_THREADS, - 0L, - TimeUnit.SECONDS, - new LinkedBlockingDeque[Runnable], - new ResultResolverThreadFactory) - - class ResultResolverThreadFactory extends ThreadFactory { - private var counter = 0 - private var PREFIX = "Result resolver thread" - - override def newThread(r: Runnable): Thread = { - val thread = new Thread(r, "%s-%s".format(PREFIX, counter)) - counter += 1 - thread.setDaemon(true) - return thread - } - } + private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt + private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( + THREADS, "Result resolver thread") protected val serializer = new ThreadLocal[SerializerInstance] { override def initialValue(): SerializerInstance = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 8f2eef9a53..300fe693f1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -30,13 +30,14 @@ import org.apache.mesos._ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _} import org.apache.spark.{SparkException, Logging, SparkContext, TaskState} -import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend} +import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds * onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever * a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the - * StandaloneBackend mechanism. This class is useful for lower and more predictable latency. + * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable + * latency. * * Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to * remove this. 
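
Note: the TaskResultGetter hunk above collapses a hand-rolled ThreadPoolExecutor plus ThreadFactory into one call to a shared Utils helper. As a minimal sketch of what such a daemon fixed-size pool helper encapsulates (mirroring the deleted factory logic; the actual Utils.newDaemonFixedThreadPool may differ in detail):

    import java.util.concurrent.{LinkedBlockingQueue, ThreadFactory, ThreadPoolExecutor, TimeUnit}
    import java.util.concurrent.atomic.AtomicInteger

    // Sketch only: a fixed-size pool whose threads are daemons, so pending
    // result fetches never keep the JVM alive, with readable "prefix-N" names.
    def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
      val counter = new AtomicInteger(0)
      val factory = new ThreadFactory {
        override def newThread(r: Runnable): Thread = {
          val thread = new Thread(r, "%s-%d".format(prefix, counter.getAndIncrement))
          thread.setDaemon(true)
          thread
        }
      }
      new ThreadPoolExecutor(nThreads, nThreads, 0L, TimeUnit.SECONDS,
        new LinkedBlockingQueue[Runnable], factory)
    }

Centralizing this keeps the thread-naming and daemon-flag conventions consistent wherever the code base needs such a pool.
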
@@ -46,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( sc: SparkContext, master: String, appName: String) - extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) with MScheduler with Logging { @@ -122,20 +123,20 @@ private[spark] class CoarseMesosSchedulerBackend( val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ACTOR_NAME) val uri = System.getProperty("spark.executor.uri") if (uri == null) { val runScript = new File(sparkHome, "spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( + "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". val basename = uri.split('/').last.split('.').head command.setValue( - "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format( - basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d" + .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } return command.build() diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 4d1bb1c639..2699f0b33e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -17,23 +17,19 @@ package org.apache.spark.scheduler.local -import java.io.File -import java.lang.management.ManagementFactory -import java.util.concurrent.atomic.AtomicInteger import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import akka.actor._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.executor.ExecutorURLClassLoader +import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode -import akka.actor._ -import org.apache.spark.util.Utils + /** * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally @@ -41,52 +37,57 @@ import org.apache.spark.util.Utils * testing fault recovery. 
*/ -private[spark] +private[local] case class LocalReviveOffers() -private[spark] +private[local] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) +private[local] +case class KillTask(taskId: Long) + private[spark] -class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging { +class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) + extends Actor with Logging { + + val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true) def receive = { case LocalReviveOffers => launchTask(localScheduler.resourceOffer(freeCores)) + case LocalStatusUpdate(taskId, state, serializeData) => - freeCores += 1 - localScheduler.statusUpdate(taskId, state, serializeData) - launchTask(localScheduler.resourceOffer(freeCores)) + if (TaskState.isFinished(state)) { + freeCores += 1 + launchTask(localScheduler.resourceOffer(freeCores)) + } + + case KillTask(taskId) => + executor.killTask(taskId) } - def launchTask(tasks : Seq[TaskDescription]) { + private def launchTask(tasks: Seq[TaskDescription]) { for (task <- tasks) { freeCores -= 1 - localScheduler.threadPool.submit(new Runnable { - def run() { - localScheduler.runTask(task.taskId, task.serializedTask) - } - }) + executor.launchTask(localScheduler, task.taskId, task.serializedTask) } } } private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler + with ExecutorBackend with Logging { - var attemptId = new AtomicInteger(0) - var threadPool = Utils.newDaemonFixedThreadPool(threads) val env = SparkEnv.get - var listener: TaskSchedulerListener = null + val attemptId = new AtomicInteger + var dagScheduler: DAGScheduler = null // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. val currentFiles: HashMap[String, Long] = new HashMap[String, Long]() val currentJars: HashMap[String, Long] = new HashMap[String, Long]() - val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader) - var schedulableBuilder: SchedulableBuilder = null var rootPool: Pool = null val schedulingMode: SchedulingMode = SchedulingMode.withName( @@ -113,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") } - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } override def submitTasks(taskSet: TaskSet) { @@ -127,6 +128,26 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId)) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task and then abort + // the stage. + // 2. The task set manager has been created but no tasks have been scheduled. In this case, + // simply abort the stage. + val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + localActor !
KillTask(tid) + } + } + tsm.error("Stage %d was cancelled".format(stageId)) + } + } + def resourceOffer(freeCores: Int): Seq[TaskDescription] = { synchronized { var freeCpuCores = freeCores @@ -166,107 +187,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - def runTask(taskId: Long, bytes: ByteBuffer) { - logInfo("Running " + taskId) - val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL) - // Set the Spark execution environment for the worker thread - SparkEnv.set(env) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objectSer = SparkEnv.get.serializer.newInstance() - var attemptedTask: Option[Task[_]] = None - val start = System.currentTimeMillis() - var taskStart: Long = 0 - def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum - val startGCTime = getTotalGCTime - - try { - Accumulators.clear() - Thread.currentThread().setContextClassLoader(classLoader) - - // Serialize and deserialize the task so that accumulators are changed to thread-local ones; - // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes) - updateDependencies(taskFiles, taskJars) // Download any files added with addFile - val deserializedTask = ser.deserialize[Task[_]]( - taskBytes, Thread.currentThread.getContextClassLoader) - attemptedTask = Some(deserializedTask) - val deserTime = System.currentTimeMillis() - start - taskStart = System.currentTimeMillis() - - // Run it - val result: Any = deserializedTask.run(taskId) - - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val serResult = objectSer.serialize(result) - deserializedTask.metrics.get.resultSize = serResult.limit() - val resultToReturn = objectSer.deserialize[Any](serResult) - val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]]( - ser.serialize(Accumulators.values)) - val serviceTime = System.currentTimeMillis() - taskStart - logInfo("Finished " + taskId) - deserializedTask.metrics.get.executorRunTime = serviceTime.toInt - deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime - deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt - val taskResult = new DirectTaskResult( - result, accumUpdates, deserializedTask.metrics.getOrElse(null)) - val serializedResult = ser.serialize(taskResult) - localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult) - } catch { - case t: Throwable => { - val serviceTime = System.currentTimeMillis() - taskStart - val metrics = attemptedTask.flatMap(t => t.metrics) - for (m <- metrics) { - m.executorRunTime = serviceTime.toInt - m.jvmGCTime = getTotalGCTime - startGCTime - } - val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics) - localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure)) - } - } - } - - /** - * Download any missing dependencies if we receive a new set of files and JARs from the - * SparkContext. Also adds any new JARs we fetched to the class loader. 
- */ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - synchronized { - // Fetch missing dependencies - for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentFiles(name) = timestamp - } - - for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { - logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) - currentJars(name) = timestamp - // Add it to our class loader - val localName = name.split("/").last - val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL - if (!classLoader.getURLs.contains(url)) { - logInfo("Adding " + url + " to class loader") - classLoader.addURL(url) + override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { + if (TaskState.isFinished(state)) { + synchronized { + taskIdToTaskSetId.get(taskId) match { + case Some(taskSetId) => + val taskSetManager = activeTaskSets(taskSetId) + taskSetTaskIds(taskSetId) -= taskId + + state match { + case TaskState.FINISHED => + taskSetManager.taskEnded(taskId, state, serializedData) + case TaskState.FAILED => + taskSetManager.taskFailed(taskId, state, serializedData) + case TaskState.KILLED => + taskSetManager.error("Task %d was killed".format(taskId)) + case _ => {} + } + case None => + logInfo("Ignoring update from TID " + taskId + " because its task set is gone") } } + localActor ! LocalStatusUpdate(taskId, state, serializedData) } } - def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) { - synchronized { - val taskSetId = taskIdToTaskSetId(taskId) - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - taskSetManager.statusUpdate(taskId, state, serializedData) - } - } - - override def stop() { - threadPool.shutdownNow() + override def stop() { } override def defaultParallelism() = threads diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala index c2e2399ccb..55f8313e87 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala @@ -132,19 +132,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas return None } - def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { - SparkEnv.set(env) - state match { - case TaskState.FINISHED => - taskEnded(tid, state, serializedData) - case TaskState.FAILED => - taskFailed(tid, state, serializedData) - case _ => {} - } - } - def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) } def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { @@ -159,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } } result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info, + result.metrics) numFinished += 1 decreaseRunningTasks(1) finished(index) = true @@ -176,7 +166,7 @@ private[spark] class 
LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas decreaseRunningTasks(1) val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( serializedData, getClass.getClassLoader) - sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) + sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) if (!finished(index)) { copiesRunning(index) -= 1 numFailures(index) += 1 @@ -187,7 +177,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( taskSet.id, index, 4, reason.description) decreaseRunningTasks(runningTasks) - sched.listener.taskSetFailed(taskSet, errorMessage) + sched.dagScheduler.taskSetFailed(taskSet, errorMessage) // need to delete failed Taskset from schedule queue sched.taskSetFinished(this) } @@ -195,5 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } override def error(message: String) { + sched.dagScheduler.taskSetFailed(taskSet, message) + sched.taskSetFinished(this) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index e936b1cfed..55b25f145a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -26,9 +26,8 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} import org.apache.spark.{SerializableWritable, Logging} -import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel} - import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel, TestBlockId} /** * A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
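
Note on the registration hunk below: Kryo writes a compact integer ID for every registered class instead of its fully qualified name, so each concrete message type that crosses the wire is registered up front; the hunk swaps the raw "1" string arguments for a shared TestBlockId("1") so the registered runtime classes match the new BlockId-typed messages. A minimal sketch of the pattern, assuming code inside the spark package (the storage types are private[spark]) and eliding the surrounding KryoSerializer plumbing (a bare Kryo() stands in for the chill instantiator used in the real class):

    import java.nio.ByteBuffer
    import com.esotericsoftware.kryo.Kryo
    import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel, TestBlockId}

    // Any representative instance works: registration only needs the runtime
    // class of each object, and TestBlockId exists for placeholder uses like this.
    val kryo = new Kryo()
    val blockId = TestBlockId("1")
    val toRegister: Seq[AnyRef] = Seq(
      GetBlock(blockId),
      GotBlock(blockId, ByteBuffer.allocate(1)),
      PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY))
    toRegister.foreach(obj => kryo.register(obj.getClass))
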
@@ -43,13 +42,14 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging val kryo = instantiator.newKryo() val classLoader = Thread.currentThread.getContextClassLoader + val blockId = TestBlockId("1") // Register some commonly used classes val toRegister: Seq[AnyRef] = Seq( ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), - GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1"), + PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), + GotBlock(blockId, ByteBuffer.allocate(1)), + GetBlock(blockId), 1 to 10, 1 until 10, 1L to 10L, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala index 290dbce4f5..0d0a2dadc7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala @@ -18,5 +18,5 @@ package org.apache.spark.storage private[spark] -case class BlockException(blockId: String, message: String) extends Exception(message) +case class BlockException(blockId: BlockId, message: String) extends Exception(message) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 3aeda3879d..e51c5b30a3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -47,7 +47,7 @@ import org.apache.spark.util.Utils */ private[storage] -trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])] +trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging with BlockFetchTracker { def initialize() } @@ -57,20 +57,20 @@ private[storage] object BlockFetcherIterator { // A request to fetch one or more blocks, complete with their sizes - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) { + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize // the block (since we want all deserialization to happen in the calling thread); can also // represent a fetch failure if size == -1. - class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) { + class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { def failed: Boolean = size == -1 } class BasicBlockFetcherIterator( private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) extends BlockFetcherIterator { @@ -92,12 +92,12 @@ object BlockFetcherIterator { // This represents the number of local blocks, also counting zero-sized blocks private var numLocal = 0 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[String]() + protected val localBlocksToFetch = new ArrayBuffer[BlockId]() // This represents the number of remote blocks, also counting zero-sized blocks private var numRemote = 0 // BlockIds for remote blocks that need to be fetched.
Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[String]() + protected val remoteBlocksToFetch = new HashSet[BlockId]() // A queue to hold our results. protected val results = new LinkedBlockingQueue[FetchResult] @@ -167,7 +167,7 @@ object BlockFetcherIterator { logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize) val iterator = blockInfos.iterator var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(String, Long)] + var curBlocks = new ArrayBuffer[(BlockId, Long)] while (iterator.hasNext) { val (blockId, size) = iterator.next() // Skip empty blocks @@ -183,7 +183,7 @@ object BlockFetcherIterator { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) curRequestSize = 0 - curBlocks = new ArrayBuffer[(String, Long)] + curBlocks = new ArrayBuffer[(BlockId, Long)] } } // Add in the final request @@ -241,7 +241,7 @@ object BlockFetcherIterator { override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - override def next(): (String, Option[Iterator[Any]]) = { + override def next(): (BlockId, Option[Iterator[Any]]) = { resultsGotten += 1 val startFetchWait = System.currentTimeMillis() val result = results.take() @@ -267,7 +267,7 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { @@ -303,7 +303,7 @@ object BlockFetcherIterator { override protected def sendRequest(req: FetchRequest) { - def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) { + def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) { val fetchResult = new FetchResult(blockId, blockSize, () => dataDeserialize(blockId, blockData.nioBuffer, serializer)) results.put(fetchResult) @@ -337,7 +337,7 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def next(): (String, Option[Iterator[Any]]) = { + override def next(): (BlockId, Option[Iterator[Any]]) = { resultsGotten += 1 val result = results.take() // If all the results have been retrieved, copiers will exit automatically diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala new file mode 100644 index 0000000000..7156d855d8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +/** + * Identifies a particular Block of data, usually associated with a single file.
+ * A Block can be uniquely identified by its filename, but each type of Block has a different + * set of keys which produce its unique name. + * + * If your BlockId should be serializable, be sure to add it to the BlockId.apply() method. + */ +private[spark] sealed abstract class BlockId { + /** A globally unique identifier for this Block. Can be used for ser/de. */ + def name: String + + // convenience methods + def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None + def isRDD = isInstanceOf[RDDBlockId] + def isShuffle = isInstanceOf[ShuffleBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] + + override def toString = name + override def hashCode = name.hashCode + override def equals(other: Any): Boolean = other match { + case o: BlockId => getClass == o.getClass && name.equals(o.name) + case _ => false + } +} + +private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { + def name = "rdd_" + rddId + "_" + splitIndex +} + +private[spark] +case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { + def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +} + +private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { + def name = "broadcast_" + broadcastId +} + +private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { + def name = broadcastId.name + "_" + hType +} + +private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { + def name = "taskresult_" + taskId +} + +private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { + def name = "input-" + streamId + "-" + uniqueId +} + +// Intended only for testing purposes +private[spark] case class TestBlockId(id: String) extends BlockId { + def name = "test_" + id +} + +private[spark] object BlockId { + val RDD = "rdd_([0-9]+)_([0-9]+)".r + val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)".r + val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val TASKRESULT = "taskresult_([0-9]+)".r + val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEST = "test_(.*)".r + + /** Converts a BlockId "name" String back into a BlockId.
*/ + def apply(id: String) = id match { + case RDD(rddId, splitIndex) => + RDDBlockId(rddId.toInt, splitIndex.toInt) + case SHUFFLE(shuffleId, mapId, reduceId) => + ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + case BROADCAST(broadcastId) => + BroadcastBlockId(broadcastId.toLong) + case BROADCAST_HELPER(broadcastId, hType) => + BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) + case TASKRESULT(taskId) => + TaskResultBlockId(taskId.toLong) + case STREAM(streamId, uniqueId) => + StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEST(value) => + TestBlockId(value) + case _ => + throw new IllegalStateException("Unrecognized BlockId: " + id) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 2322922f75..e6329cbd47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,14 +20,15 @@ package org.apache.spark.storage import java.io.{InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet} +import scala.collection.mutable.{HashMap, ArrayBuffer} +import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} import akka.dispatch.{Await, Future} import akka.util.Duration import akka.util.duration._ -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec @@ -37,7 +38,6 @@ import org.apache.spark.util._ import sun.nio.ch.DirectBuffer - private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -102,18 +102,19 @@ private[spark] class BlockManager( } val shuffleBlockManager = new ShuffleBlockManager(this) + val diskBlockManager = new DiskBlockManager( + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) - private val blockInfo = new TimeStampedHashMap[String, BlockInfo] + private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private[storage] val diskStore: DiskStore = - new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. private val nettyPort: Int = { val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt - if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0 + if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } val connectionManager = new ConnectionManager(0) @@ -249,7 +250,7 @@ private[spark] class BlockManager( /** * Get storage level of local block. If no info exists for the block, then returns null. */ - def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull /** * Tell the master about the current storage status of a block. 
This will send a block update @@ -259,7 +260,7 @@ private[spark] class BlockManager( * droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid). * This ensures that update in master will compensate for the increase in memory on slave. */ - def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) { + def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) { val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize) if (needReregister) { logInfo("Got told to reregister updating block " + blockId) @@ -270,11 +271,11 @@ private[spark] class BlockManager( } /** - * Actually send a UpdateBlockInfo message. Returns the mater's response, + * Actually send an UpdateBlockInfo message. Returns the master's response, * which will be true if the block was successfully recorded and false if * the slave needs to re-register. */ - private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { + private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = { val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized { info.level match { case null => @@ -299,7 +300,7 @@ private[spark] class BlockManager( /** * Get locations of an array of blocks. */ - def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = { + def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = { val startTimeMs = System.currentTimeMillis val locations = master.getLocations(blockIds).toArray logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) @@ -311,7 +312,7 @@ private[spark] class BlockManager( * shuffle blocks. It is safe to do so without a lock on block info since disk store * never deletes (recent) items. */ - def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { diskStore.getValues(blockId, serializer).orElse( sys.error("Block " + blockId + " not found on disk, though it should be")) } @@ -319,94 +320,19 @@ private[spark] class BlockManager( /** * Get block from local block manager. */ - def getLocal(blockId: String): Option[Iterator[Any]] = { + def getLocal(blockId: BlockId): Option[Iterator[Any]] = { logDebug("Getting local block " + blockId) - val info = blockInfo.get(blockId).orNull - if (info != null) { - info.synchronized { - - // In the another thread is writing the block, wait for it to become ready. - if (!info.waitForReady()) { - // If we get here, the block write failed.
- logWarning("Block " + blockId + " was marked as failure.") - return None - } - - val level = info.level - logDebug("Level for block " + blockId + " is " + level) - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - logDebug("Block " + blockId + " not found in memory") - } - } - - // Look for block on disk, potentially loading it back into memory if required - if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") - if (level.useMemory && level.deserialized) { - diskStore.getValues(blockId) match { - case Some(iterator) => - // Put the block back in memory before returning it - // TODO: Consider creating a putValues that also takes in a iterator ? - val elements = new ArrayBuffer[Any] - elements ++= iterator - memoryStore.putValues(blockId, elements, level, true).data match { - case Left(iterator2) => - return Some(iterator2) - case _ => - throw new Exception("Memory store did not return back an iterator") - } - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else if (level.useMemory && !level.deserialized) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - // Put a copy of the block back in memory before returning it. Note that we can't - // put the ByteBuffer returned by the disk store as that's a memory-mapped file. - // The use of rewind assumes this. - assert (0 == bytes.position()) - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - bytes.rewind() - return Some(dataDeserialize(blockId, bytes)) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } else { - diskStore.getValues(blockId) match { - case Some(iterator) => - return Some(iterator) - case None => - throw new Exception("Block " + blockId + " not found on disk, though it should be") - } - } - } - } - } else { - logDebug("Block " + blockId + " not registered locally") - } - return None + doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] } /** * Get block from the local block manager as serialized bytes. */ - def getLocalBytes(blockId: String): Option[ByteBuffer] = { - // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow + def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { logDebug("Getting local block " + blockId + " as bytes") - // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work - if (ShuffleBlockManager.isShuffle(blockId)) { + if (blockId.isShuffle) { return diskStore.getBytes(blockId) match { case Some(bytes) => Some(bytes) @@ -414,12 +340,15 @@ private[spark] class BlockManager( throw new Exception("Block " + blockId + " not found on disk, though it should be") } } + doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] + } + private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = { val info = blockInfo.get(blockId).orNull if (info != null) { info.synchronized { - // In the another thread is writing the block, wait for it to become ready. + // If another thread is writing the block, wait for it to become ready. 
if (!info.waitForReady()) { // If we get here, the block write failed. logWarning("Block " + blockId + " was marked as failure.") @@ -432,62 +361,104 @@ private[spark] class BlockManager( // Look for the block in memory if (level.useMemory) { logDebug("Getting block " + blockId + " from memory") - memoryStore.getBytes(blockId) match { - case Some(bytes) => - return Some(bytes) + val result = if (asValues) { + memoryStore.getValues(blockId) + } else { + memoryStore.getBytes(blockId) + } + result match { + case Some(values) => + return Some(values) case None => logDebug("Block " + blockId + " not found in memory") } } - // Look for block on disk + // Look for block on disk, potentially storing it back into memory if required: if (level.useDisk) { - // Read it as a byte buffer into memory first, then return it - diskStore.getBytes(blockId) match { - case Some(bytes) => - assert (0 == bytes.position()) - if (level.useMemory) { - if (level.deserialized) { - memoryStore.putBytes(blockId, bytes, level) - } else { - // The memory store will hang onto the ByteBuffer, so give it a copy instead of - // the memory-mapped file buffer we got from the disk store - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) - memoryStore.putBytes(blockId, copyForMemory, level) - } - } - bytes.rewind() - return Some(bytes) + logDebug("Getting block " + blockId + " from disk") + val bytes: ByteBuffer = diskStore.getBytes(blockId) match { + case Some(bytes) => bytes case None => throw new Exception("Block " + blockId + " not found on disk, though it should be") } + assert (0 == bytes.position()) + + if (!level.useMemory) { + // If the block shouldn't be stored in memory, we can just return it: + if (asValues) { + return Some(dataDeserialize(blockId, bytes)) + } else { + return Some(bytes) + } + } else { + // Otherwise, we also have to store something in the memory store: + if (!level.deserialized || !asValues) { + // We'll store the bytes in memory if the block's storage level includes + // "memory serialized", or if it should be cached as objects in memory + // but we only requested its serialized bytes: + val copyForMemory = ByteBuffer.allocate(bytes.limit) + copyForMemory.put(bytes) + memoryStore.putBytes(blockId, copyForMemory, level) + bytes.rewind() + } + if (!asValues) { + return Some(bytes) + } else { + val values = dataDeserialize(blockId, bytes) + if (level.deserialized) { + // Cache the values before returning them: + // TODO: Consider creating a putValues that also takes in an iterator? + val valuesBuffer = new ArrayBuffer[Any] + valuesBuffer ++= values + memoryStore.putValues(blockId, valuesBuffer, level, true).data match { + case Left(values2) => + return Some(values2) + case _ => + throw new Exception("Memory store did not return back an iterator") + } + } else { + return Some(values) + } + } + } } } } else { logDebug("Block " + blockId + " not registered locally") } - return None + None } /** * Get block from remote block managers. */ - def getRemote(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } + def getRemote(blockId: BlockId): Option[Iterator[Any]] = { logDebug("Getting remote block " + blockId) - // Get locations of block - val locations = master.getLocations(blockId) + doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]] + } - // Get block from remote locations + /** + * Get block from remote block managers as serialized bytes.
+ */ + def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { + logDebug("Getting remote block " + blockId + " as bytes") + doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]] + } + + private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = { + require(blockId != null, "BlockId is null") + val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug("Getting remote block " + blockId + " from " + loc) val data = BlockManagerWorker.syncGetBlock( GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) if (data != null) { - return Some(dataDeserialize(blockId, data)) + if (asValues) { + return Some(dataDeserialize(blockId, data)) + } else { + return Some(data) + } } logDebug("The value of block " + blockId + " is null") } @@ -496,34 +467,9 @@ private[spark] class BlockManager( } /** - * Get block from remote block managers as serialized bytes. - */ - def getRemoteBytes(blockId: String): Option[ByteBuffer] = { - // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be - // refactored. - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - logDebug("Getting remote block " + blockId + " as bytes") - - val locations = master.getLocations(blockId) - for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.host, loc.port)) - if (data != null) { - return Some(data) - } - logDebug("The value of block " + blockId + " is null") - } - logDebug("Block " + blockId + " not found") - return None - } - - /** * Get a block from the block manager (either local or remote). */ - def get(blockId: String): Option[Iterator[Any]] = { + def get(blockId: BlockId): Option[Iterator[Any]] = { val local = getLocal(blockId) if (local.isDefined) { logInfo("Found block %s locally".format(blockId)) @@ -544,7 +490,7 @@ private[spark] class BlockManager( * so that we can control the maxMegabytesInFlight for the fetch. */ def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer) + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer) : BlockFetcherIterator = { val iter = @@ -558,7 +504,7 @@ private[spark] class BlockManager( iter } - def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) + def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean) : Long = { val elements = new ArrayBuffer[Any] elements ++= values @@ -567,16 +513,20 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. + * The Block will be appended to the File specified by filename. * This is currently used for writing shuffle files out. Callers should handle error * cases. 
*/ - def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) + def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { - val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize) + val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) + val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true) + val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream) writer.registerCloseEventHandler(() => { + diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment()) val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false) blockInfo.put(blockId, myInfo) - myInfo.markReady(writer.size()) + myInfo.markReady(writer.fileSegment().length) }) writer } @@ -584,18 +534,25 @@ private[spark] class BlockManager( /** * Put a new block of values to the block manager. Returns its (estimated) size in bytes. */ - def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, - tellMaster: Boolean = true) : Long = { + def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, + tellMaster: Boolean = true) : Long = { + require(values != null, "Values is null") + doPut(blockId, Left(values), level, tellMaster) + } - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (values == null) { - throw new IllegalArgumentException("Values is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } + /** + * Put a new block of serialized bytes to the block manager. + */ + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, + tellMaster: Boolean = true) { + require(bytes != null, "Bytes is null") + doPut(blockId, Right(bytes), level, tellMaster) + } + + private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer], + level: StorageLevel, tellMaster: Boolean = true): Long = { + require(blockId != null, "BlockId is null") + require(level != null && level.isValid, "StorageLevel is null or invalid") // Remember the block's storage level so that we can correctly drop it to disk if it needs // to be dropped right after it got put into memory. Note, however, that other threads will @@ -611,7 +568,8 @@ private[spark] class BlockManager( return oldBlockOpt.get.size } - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? + // TODO: So the block info exists - but previous attempt to load it (?) failed. + // What do we do now ? Retry on it ? oldBlockOpt.get } else { tinfo @@ -620,10 +578,10 @@ private[spark] class BlockManager( val startTimeMs = System.currentTimeMillis - // If we need to replicate the data, we'll want access to the values, but because our - // put will read the whole iterator, there will be no values left. For the case where - // the put serializes data, we'll remember the bytes, above; but for the case where it - // doesn't, such as deserialized storage, let's rely on the put returning an Iterator. + // If we're storing values and we need to replicate the data, we'll want access to the values, + // but because our put will read the whole iterator, there will be no values left. For the + // case where the put serializes data, we'll remember the bytes, above; but for the case where + // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator. 
var valuesAfterPut: Iterator[Any] = null // Ditto for the bytes after the put @@ -632,30 +590,51 @@ private[spark] class BlockManager( // Size of the block in bytes (to return to caller) var size = 0L + // If we're storing bytes, then initiate the replication before storing them locally. + // This is faster as data is already serialized and ready to send. + val replicationFuture = if (data.isRight && level.replication > 1) { + val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper + Future { + replicate(blockId, bufferView, level) + } + } else { + null + } + myInfo.synchronized { logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) + " to get into synchronized block") var marked = false try { - if (level.useMemory) { - // Save it just to memory first, even if it also has useDisk set to true; we will later - // drop it to disk if the memory store can't hold it. - val res = memoryStore.putValues(blockId, values, level, true) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case Left(newIterator) => valuesAfterPut = newIterator + data match { + case Left(values) => { + if (level.useMemory) { + // Save it just to memory first, even if it also has useDisk set to true; we will + // drop it to disk later if the memory store can't hold it. + val res = memoryStore.putValues(blockId, values, level, true) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case Left(newIterator) => valuesAfterPut = newIterator + } + } else { + // Save directly to disk. + // Don't get back the bytes unless we replicate them. + val askForBytes = level.replication > 1 + val res = diskStore.putValues(blockId, values, level, askForBytes) + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => + } + } } - } else { - // Save directly to disk. - // Don't get back the bytes unless we replicate them. - val askForBytes = level.replication > 1 - val res = diskStore.putValues(blockId, values, level, askForBytes) - size = res.size - res.data match { - case Right(newBytes) => bytesAfterPut = newBytes - case _ => + case Right(bytes) => { + bytes.rewind() + // Store it only in memory at first, even if useDisk is also set to true + (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level) + size = bytes.limit } } @@ -680,132 +659,46 @@ private[spark] class BlockManager( } logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - // Replicate block if required + // Either we're storing bytes and we asynchronously started replication, or we're storing + // values and need to serialize and replicate them now: if (level.replication > 1) { - val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") - } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) - } - replicate(blockId, bytesAfterPut, level) - logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime)) - } - BlockManager.dispose(bytesAfterPut) - - return size - } - - - /** - * Put a new block of serialized bytes to the block manager. 
- */ - def putBytes( - blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { - - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (bytes == null) { - throw new IllegalArgumentException("Bytes is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - // Remember the block's storage level so that we can correctly drop it to disk if it needs - // to be dropped right after it got put into memory. Note, however, that other threads will - // not be able to get() this block until we call markReady on its BlockInfo. - val myInfo = { - val tinfo = new BlockInfo(level, tellMaster) - // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) - - if (oldBlockOpt.isDefined) { - if (oldBlockOpt.get.waitForReady()) { - logWarning("Block " + blockId + " already exists on this machine; not re-adding it") - return - } - - // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ? - oldBlockOpt.get - } else { - tinfo - } - } - - val startTimeMs = System.currentTimeMillis - - // Initiate the replication before storing it locally. This is faster as - // data is already serialized and ready for sending - val replicationFuture = if (level.replication > 1) { - val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper - Future { - replicate(blockId, bufferView, level) - } - } else { - null - } - - myInfo.synchronized { - logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - var marked = false - try { - if (level.useMemory) { - // Store it only in memory at first, even if useDisk is also set to true - bytes.rewind() - memoryStore.putBytes(blockId, bytes, level) - } else { - bytes.rewind() - diskStore.putBytes(blockId, bytes, level) - } - - // assert (0 == bytes.position(), "" + bytes) - - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. - marked = true - myInfo.markReady(bytes.limit) - if (tellMaster) { - reportBlockStatus(blockId, myInfo) - } - } finally { - // If we failed at putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (! marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. - blockInfo.remove(blockId) - myInfo.markFailure() - logWarning("Putting block " + blockId + " failed") + data match { + case Right(bytes) => Await.ready(replicationFuture, Duration.Inf) + case Left(values) => { + val remoteStartTime = System.currentTimeMillis + // Serialize the block if not already done + if (bytesAfterPut == null) { + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! 
This shouldn't happen.") + } + bytesAfterPut = dataSerialize(blockId, valuesAfterPut) + } + replicate(blockId, bytesAfterPut, level) + logDebug("Put block " + blockId + " remotely took " + + Utils.getUsedTimeMs(remoteStartTime)) } } } - // If replication had started, then wait for it to finish - if (level.replication > 1) { - Await.ready(replicationFuture, Duration.Inf) - } + BlockManager.dispose(bytesAfterPut) if (level.replication > 1) { - logDebug("PutBytes for block " + blockId + " with replication took " + + logDebug("Put for block " + blockId + " with replication took " + Utils.getUsedTimeMs(startTimeMs)) } else { - logDebug("PutBytes for block " + blockId + " without replication took " + + logDebug("Put for block " + blockId + " without replication took " + Utils.getUsedTimeMs(startTimeMs)) } + + size } /** * Replicate block to another node. */ var cachedPeers: Seq[BlockManagerId] = null - private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { + private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) { val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) @@ -828,14 +721,14 @@ private[spark] class BlockManager( /** * Read a block consisting of a single object. */ - def getSingle(blockId: String): Option[Any] = { + def getSingle(blockId: BlockId): Option[Any] = { get(blockId).map(_.next()) } /** * Write a block consisting of a single object. */ - def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) { + def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) { put(blockId, Iterator(value), level, tellMaster) } @@ -843,7 +736,7 @@ private[spark] class BlockManager( * Drop a block from memory, possibly putting it on disk if applicable. Called when the memory * store reaches its limit and needs to free up space. */ - def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) { + def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) { logInfo("Dropping block " + blockId + " from memory") val info = blockInfo.get(blockId).orNull if (info != null) { @@ -892,16 +785,15 @@ private[spark] class BlockManager( // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps // from RDD.id to blocks. logInfo("Removing RDD " + rddId) - val rddPrefix = "rdd_" + rddId + "_" - val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1) - blocksToRemove.foreach(blockId => removeBlock(blockId, false)) + val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) blocksToRemove.size } /** * Remove a block from both memory and disk. 
*/ - def removeBlock(blockId: String, tellMaster: Boolean = true) { + def removeBlock(blockId: BlockId, tellMaster: Boolean = true) { logInfo("Removing block " + blockId) val info = blockInfo.get(blockId).orNull if (info != null) info.synchronized { @@ -924,34 +816,20 @@ private[spark] class BlockManager( private def dropOldNonBroadcastBlocks(cleanupTime: Long) { logInfo("Dropping non broadcast blocks older than " + cleanupTime) - val iterator = blockInfo.internalMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime && ! BlockManager.isBroadcastBlock(id) ) { - info.synchronized { - val level = info.level - if (level.useMemory) { - memoryStore.remove(id) - } - if (level.useDisk) { - diskStore.remove(id) - } - iterator.remove() - logInfo("Dropped block " + id) - } - reportBlockStatus(id, info) - } - } + dropOldBlocks(cleanupTime, !_.isBroadcast) } private def dropOldBroadcastBlocks(cleanupTime: Long) { logInfo("Dropping broadcast blocks older than " + cleanupTime) + dropOldBlocks(cleanupTime, _.isBroadcast) + } + + private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { val iterator = blockInfo.internalMap.entrySet().iterator() while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) - if (time < cleanupTime && BlockManager.isBroadcastBlock(id) ) { + if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level if (level.useMemory) { @@ -968,39 +846,45 @@ private[spark] class BlockManager( } } - def shouldCompress(blockId: String): Boolean = { - if (ShuffleBlockManager.isShuffle(blockId)) { - compressShuffle - } else if (BlockManager.isBroadcastBlock(blockId)) { - compressBroadcast - } else if (blockId.startsWith("rdd_")) { - compressRdds - } else { - false // Won't happen in a real cluster, but it can in tests - } + def shouldCompress(blockId: BlockId): Boolean = blockId match { + case ShuffleBlockId(_, _, _) => compressShuffle + case BroadcastBlockId(_) => compressBroadcast + case RDDBlockId(_, _) => compressRdds + case _ => false } /** * Wrap an output stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: String, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - def wrapForCompression(blockId: String, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } + /** Serializes into a stream. */ + def dataSerializeStream( + blockId: BlockId, + outputStream: OutputStream, + values: Iterator[Any], + serializer: Serializer = defaultSerializer) { + val byteStream = new FastBufferedOutputStream(outputStream) + val ser = serializer.newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + } + + /** Serializes into a byte buffer. 
*/ def dataSerialize( - blockId: String, + blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { val byteStream = new FastByteArrayOutputStream(4096) - val ser = serializer.newInstance() - ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() + dataSerializeStream(blockId, byteStream, values, serializer) byteStream.trim() ByteBuffer.wrap(byteStream.array) } @@ -1010,7 +894,7 @@ private[spark] class BlockManager( * the iterator is reached. */ def dataDeserialize( - blockId: String, + blockId: BlockId, bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() @@ -1065,10 +949,10 @@ private[spark] object BlockManager extends Logging { } def blockIdsToBlockManagers( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[BlockManagerId]] = + : Map[BlockId, Seq[BlockManagerId]] = { // env == null and blockManagerMaster != null is used in tests assert (env != null || blockManagerMaster != null) @@ -1078,7 +962,7 @@ private[spark] object BlockManager extends Logging { blockManagerMaster.getLocations(blockIds) } - val blockManagers = new HashMap[String, Seq[BlockManagerId]] + val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]] for (i <- 0 until blockIds.length) { blockManagers(blockIds(i)) = blockLocations(i) } @@ -1086,25 +970,21 @@ private[spark] object BlockManager extends Logging { } def blockIdsToExecutorIds( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = + : Map[BlockId, Seq[String]] = { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId)) } def blockIdsToHosts( - blockIds: Array[String], + blockIds: Array[BlockId], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null) - : Map[String, Seq[String]] = + : Map[BlockId, Seq[String]] = { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) } - - def isBroadcastBlock(blockId: String): Boolean = null != blockId && blockId.startsWith("broadcast_") - - def toBroadcastId(id: Long): String = "broadcast_" + id } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index cf463d6ffc..94038649b3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -60,7 +60,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi def updateBlockInfo( blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { @@ -71,12 +71,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi } /** Get locations of the blockId from the driver */ - def getLocations(blockId: String): Seq[BlockManagerId] = { + def getLocations(blockId: BlockId): Seq[BlockManagerId] = { askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ - def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } @@ -94,7 +94,7 @@ private[spark] 
class BlockManagerMaster(var driverActor: ActorRef) extends Loggi * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. */ - def removeBlock(blockId: String) { + def removeBlock(blockId: BlockId) { askDriverWithReply(RemoveBlock(blockId)) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index c7b23ab094..f8cf14b503 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] // Mapping from block id to the set of block managers that have the block. - private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]] + private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] val akkaTimeout = Duration.create( System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") @@ -129,10 +129,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // First remove the metadata for the given RDD, and then asynchronously remove the blocks // from the slaves. - val prefix = "rdd_" + rddId + "_" // Find all blocks for the given RDD, remove the block from both blockLocations and // the blockManagerInfo that is tracking the blocks. - val blocks = blockLocations.keySet().filter(_.startsWith(prefix)) + val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocks.foreach { blockId => val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) @@ -198,7 +197,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - private def removeBlockFromWorkers(blockId: String) { + private def removeBlockFromWorkers(blockId: BlockId) { val locations = blockLocations.get(blockId) if (locations != null) { locations.foreach { blockManagerId: BlockManagerId => @@ -228,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { - if (id.executorId == "<driver>" && !isLocal) { - // Got a register message from the master node; don't register it - } else if (!blockManagerInfo.contains(id)) { + if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => // A block manager of the same executor already exists. @@ -247,7 +244,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { private def updateBlockInfo( blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) { @@ -292,11 +289,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { sender ! 
true } - private def getLocations(blockId: String): Seq[BlockManagerId] = { + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } - private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { + private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) } @@ -330,7 +327,7 @@ object BlockManagerMasterActor { private var _remainingMem: Long = maxMem // Mapping from block id to its status. - private val _blocks = new JHashMap[String, BlockStatus] + private val _blocks = new JHashMap[BlockId, BlockStatus] logInfo("Registering block manager %s with %s RAM".format( blockManagerId.hostPort, Utils.bytesToString(maxMem))) @@ -339,7 +336,7 @@ object BlockManagerMasterActor { _lastSeenMs = System.currentTimeMillis() } - def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long, + def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long) { updateLastSeenMs() @@ -383,7 +380,7 @@ object BlockManagerMasterActor { } } - def removeBlock(blockId: String) { + def removeBlock(blockId: BlockId) { if (_blocks.containsKey(blockId)) { _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) @@ -394,7 +391,7 @@ object BlockManagerMasterActor { def lastSeenMs: Long = _lastSeenMs - def blocks: JHashMap[String, BlockStatus] = _blocks + def blocks: JHashMap[BlockId, BlockStatus] = _blocks override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 24333a179c..45f51da288 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages { // Remove a block from the slaves that have it. This can only be used to remove // blocks that the master knows about. - case class RemoveBlock(blockId: String) extends ToBlockManagerSlave + case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave // Remove all blocks belonging to a specific RDD. 
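The removeRdd scans above (in both BlockManager and BlockManagerMasterActor) swap the old "rdd_" + rddId + "_" prefix match for typed extraction with asRDDId. A sketch of that pattern using cut-down stand-ins for the BlockId hierarchy (fields simplified to what the example needs):

    sealed trait BlockId { def asRDDId: Option[RDDBlockId] = None }
    case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
      override def asRDDId = Some(this)
    }
    case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId

    object RddFilterSketch {
      def main(args: Array[String]) {
        val blocks = Seq(RDDBlockId(1, 0), RDDBlockId(2, 0), ShuffleBlockId(0, 0, 0))
        // Typed extraction: no string parsing, and the compiler checks the field.
        val toRemove = blocks.flatMap(_.asRDDId).filter(_.rddId == 1)
        println(toRemove)  // List(RDDBlockId(1,0))
      }
    }

Compared with startsWith, this cannot mis-match a block whose name merely shares the prefix, and rddId stays an Int rather than a parsed substring.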
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave @@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages { class UpdateBlockInfo( var blockManagerId: BlockManagerId, - var blockId: String, + var blockId: BlockId, var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long) @@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages { override def writeExternal(out: ObjectOutput) { blockManagerId.writeExternal(out) - out.writeUTF(blockId) + out.writeUTF(blockId.name) storageLevel.writeExternal(out) out.writeLong(memSize) out.writeLong(diskSize) @@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages { override def readExternal(in: ObjectInput) { blockManagerId = BlockManagerId(in) - blockId = in.readUTF() + blockId = BlockId(in.readUTF()) storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() @@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages { object UpdateBlockInfo { def apply(blockManagerId: BlockManagerId, - blockId: String, + blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diskSize: Long): UpdateBlockInfo = { @@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages { } // For pattern-matching - def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = { Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) } } - case class GetLocations(blockId: String) extends ToBlockManagerMaster + case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster - case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster + case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 951503019f..3a65e55733 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._ * An actor to take commands from the master to execute options. For example, * this is used to remove blocks from the slave's BlockManager. 
*/ +private[storage] class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { override def receive = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala index 678c38203c..0c66addf9d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala @@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends } } - private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) blockManager.putBytes(id, bytes, level) @@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends + " with data size: " + bytes.limit) } - private def getBlock(id: String): ByteBuffer = { + private def getBlock(id: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + id + " started from " + startTimeMs) val buffer = blockManager.getLocalBytes(id) match { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index d8fa6a91d1..80dcb5a207 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.network._ -private[spark] case class GetBlock(id: String) -private[spark] case class GotBlock(id: String, data: ByteBuffer) -private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) +private[spark] case class GetBlock(id: BlockId) +private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) +private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) private[spark] class BlockMessage() { // Un-initialized: typ = 0 @@ -34,7 +34,7 @@ private[spark] class BlockMessage() { // GotBlock: typ = 2 // PutBlock: typ = 3 private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: String = null + private var id: BlockId = null private var data: ByteBuffer = null private var level: StorageLevel = null @@ -74,7 +74,7 @@ private[spark] class BlockMessage() { for (i <- 1 to idLength) { idBuilder += buffer.getChar() } - id = idBuilder.toString() + id = BlockId(idBuilder.toString) if (typ == BlockMessage.TYPE_PUT_BLOCK) { @@ -109,28 +109,17 @@ private[spark] class BlockMessage() { set(buffer) } - def getType: Int = { - return typ - } - - def getId: String = { - return id - } - - def getData: ByteBuffer = { - return data - } - - def getLevel: StorageLevel = { - return level - } + def getType: Int = typ + def getId: BlockId = id + def getData: ByteBuffer = data + def getLevel: StorageLevel = level def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) - buffer.putInt(typ).putInt(id.length()) - id.foreach((x: Char) => buffer.putChar(x)) + var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2) + buffer.putInt(typ).putInt(id.name.length) + id.name.foreach((x: Char) => buffer.putChar(x)) buffer.flip() buffers += buffer @@ -212,7 +201,8 @@ 
private[spark] object BlockMessage { def main(args: Array[String]) { val B = new BlockMessage() - B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) + val blockId = TestBlockId("ABC") + B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index 0aaf846b5b..6ce9127c74 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -111,14 +111,15 @@ private[spark] object BlockMessageArray { } def main(args: Array[String]) { - val blockMessages = + val blockMessages = (0 until 10).map { i => if (i % 2 == 0) { val buffer = ByteBuffer.allocate(100) buffer.clear - BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER)) + BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer, + StorageLevel.MEMORY_ONLY_SER)) } else { - BlockMessage.fromGetBlock(GetBlock(i.toString)) + BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString))) } } val blockMessageArray = new BlockMessageArray(blockMessages) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 39f103297f..32d2dd0694 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -17,6 +17,13 @@ package org.apache.spark.storage +import java.io.{FileOutputStream, File, OutputStream} +import java.nio.channels.FileChannel + +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +import org.apache.spark.Logging +import org.apache.spark.serializer.{SerializationStream, Serializer} /** * An interface for writing JVM objects to some underlying storage. This interface allows @@ -25,7 +32,7 @@ package org.apache.spark.storage * * This interface does not support concurrent writes. */ -abstract class BlockObjectWriter(val blockId: String) { +abstract class BlockObjectWriter(val blockId: BlockId) { var closeEventHandler: () => Unit = _ @@ -59,7 +66,129 @@ abstract class BlockObjectWriter(val blockId: String) { def write(value: Any) /** - * Size of the valid writes, in bytes. + * Returns the file segment of committed data that this Writer has written. + */ + def fileSegment(): FileSegment + + /** + * Cumulative time spent performing blocking writes, in ns. */ - def size(): Long + def timeWriting(): Long +} + +/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */ +class DiskBlockObjectWriter( + blockId: BlockId, + file: File, + serializer: Serializer, + bufferSize: Int, + compressStream: OutputStream => OutputStream) + extends BlockObjectWriter(blockId) + with Logging +{ + + /** Intercepts write calls and tracks total time spent writing. Not thread safe. 
*/ + private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream { + def timeWriting = _timeWriting + private var _timeWriting = 0L + + private def callWithTiming(f: => Unit) = { + val start = System.nanoTime() + f + _timeWriting += (System.nanoTime() - start) + } + + def write(i: Int): Unit = callWithTiming(out.write(i)) + override def write(b: Array[Byte]) = callWithTiming(out.write(b)) + override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) + } + + private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean + + /** The file channel, used for repositioning / truncating the file. */ + private var channel: FileChannel = null + private var bs: OutputStream = null + private var fos: FileOutputStream = null + private var ts: TimeTrackingOutputStream = null + private var objOut: SerializationStream = null + private var initialPosition = 0L + private var lastValidPosition = 0L + private var initialized = false + private var _timeWriting = 0L + + override def open(): BlockObjectWriter = { + fos = new FileOutputStream(file, true) + ts = new TimeTrackingOutputStream(fos) + channel = fos.getChannel() + initialPosition = channel.position + lastValidPosition = initialPosition + bs = compressStream(new FastBufferedOutputStream(ts, bufferSize)) + objOut = serializer.newInstance().serializeStream(bs) + initialized = true + this + } + + override def close() { + if (initialized) { + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + objOut.flush() + val start = System.nanoTime() + fos.getFD.sync() + _timeWriting += System.nanoTime() - start + } + objOut.close() + + _timeWriting += ts.timeWriting + + channel = null + bs = null + fos = null + ts = null + objOut = null + } + // Invoke the close callback handler. + super.close() + } + + override def isOpen: Boolean = objOut != null + + override def commit(): Long = { + if (initialized) { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() + bs.flush() + val prevPos = lastValidPosition + lastValidPosition = channel.position() + lastValidPosition - prevPos + } else { + // lastValidPosition is zero if stream is uninitialized + lastValidPosition + } + } + + override def revertPartialWrites() { + if (initialized) { + // Discard current writes. We do this by flushing the outstanding writes and + // truncate the file to the last valid position. 
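TimeTrackingOutputStream above is self-contained enough to demonstrate on its own. A runnable sketch of the same wrapper, assuming a ByteArrayOutputStream sink in place of the shuffle file's FileOutputStream:

    import java.io.{ByteArrayOutputStream, OutputStream}

    class TimedOutputStream(out: OutputStream) extends OutputStream {
      private var _timeWriting = 0L
      def timeWriting: Long = _timeWriting

      // Route every write through a timing block, as the writer above does.
      private def timed(f: => Unit) {
        val start = System.nanoTime()
        f
        _timeWriting += System.nanoTime() - start
      }

      def write(i: Int) { timed(out.write(i)) }
      override def write(b: Array[Byte], off: Int, len: Int) { timed(out.write(b, off, len)) }
    }

    object TimedOutputStreamDemo {
      def main(args: Array[String]) {
        val ts = new TimedOutputStream(new ByteArrayOutputStream())
        ts.write(Array.fill[Byte](1024)(1), 0, 1024)
        println("ns spent in write(): " + ts.timeWriting)  // nonzero, machine-dependent
      }
    }

As in the diff, the counter only covers blocking write time; the close() path above separately adds fsync time when spark.shuffle.sync is set.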
+ objOut.flush() + bs.flush() + channel.truncate(lastValidPosition) + } + } + + override def write(value: Any) { + if (!initialized) { + open() + } + objOut.writeObject(value) + } + + override def fileSegment(): FileSegment = { + val bytesWritten = lastValidPosition - initialPosition + new FileSegment(file, initialPosition, bytesWritten) + } + + // Only valid if called after close() + override def timeWriting() = _timeWriting } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index fa834371f4..ea42656240 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -27,7 +27,7 @@ import org.apache.spark.Logging */ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) + def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) /** * Put in a block and, possibly, also return its content as either bytes or another Iterator. @@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging { * @return a PutResult that contains the size of the data, as well as the values put if * returnValues is true (if not, the result's data field can be null) */ - def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel, + def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) : PutResult /** * Return the size of a block in bytes. */ - def getSize(blockId: String): Long + def getSize(blockId: BlockId): Long - def getBytes(blockId: String): Option[ByteBuffer] + def getBytes(blockId: BlockId): Option[ByteBuffer] - def getValues(blockId: String): Option[Iterator[Any]] + def getValues(blockId: BlockId): Option[Iterator[Any]] /** * Remove a block, if it exists. * @param blockId the block to remove. * @return True if the block was found and removed, False otherwise. */ - def remove(blockId: String): Boolean + def remove(blockId: BlockId): Boolean - def contains(blockId: String): Boolean + def contains(blockId: BlockId): Boolean def clear() { } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala new file mode 100644 index 0000000000..bcb58ad946 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+
+  private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+  private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+  // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+  // directory, create multiple subdirectories that we will hash files into, in order to avoid
+  // having really large inodes at the top level.
+  private val localDirs: Array[File] = createLocalDirs()
+  private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+  private var shuffleSender : ShuffleSender = null
+
+  // Stores only Blocks which have been specifically mapped to segments of files
+  // (rather than the default, which maps a Block to a whole file).
+  // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
+  private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
+
+  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
+
+  addShutdownHook()
+
+  /**
+   * Creates a logical mapping from the given BlockId to a segment of a file.
+   * This will cause any accesses of the logical BlockId to be directed to the specified
+   * physical location.
+   */
+  def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
+    blockToFileSegmentMap.put(blockId, fileSegment)
+  }
+
+  /**
+   * Returns the physical file segment in which the given BlockId is located.
+   * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+   * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+   */
+  def getBlockLocation(blockId: BlockId): FileSegment = {
+    if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
+      blockToFileSegmentMap.get(blockId).get
+    } else {
+      val file = getFile(blockId.name)
+      new FileSegment(file, 0, file.length())
+    }
+  }
+
+  /**
+   * Simply returns a File to place the given Block into. This does not physically create the file.
+   * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
+   * a unique filename.

+ */ + def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = { + val actualFilename = if (filename == "") blockId.name else filename + val file = getFile(actualFilename) + if (!allowAppending && file.exists()) { + throw new IllegalStateException( + "Attempted to create file that already exists: " + actualFilename) + } + file + } + + private def getFile(filename: String): File = { + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = Utils.nonNegativeHash(filename) + val dirId = hash % localDirs.length + val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + + // Create the subdirectory if it doesn't already exist + var subDir = subDirs(dirId)(subDirId) + if (subDir == null) { + subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) + newDir.mkdir() + subDirs(dirId)(subDirId) = newDir + newDir + } + } + } + + new File(subDir, filename) + } + + private def createLocalDirs(): Array[File] = { + logDebug("Creating local directories at root dirs '" + rootDirs + "'") + val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + rootDirs.split(",").map { rootDir => + var foundLocalDir = false + var localDir: File = null + var localDirId: String = null + var tries = 0 + val rand = new Random() + while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { + tries += 1 + try { + localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) + localDir = new File(rootDir, "spark-local-" + localDirId) + if (!localDir.exists) { + foundLocalDir = localDir.mkdirs() + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) + } + } + if (!foundLocalDir) { + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + + " attempts to create local dir in " + rootDir) + System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) + } + logInfo("Created local directory at " + localDir) + localDir + } + } + + private def cleanup(cleanupTime: Long) { + blockToFileSegmentMap.clearOldValues(cleanupTime) + } + + private def addShutdownHook() { + localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { + override def run() { + logDebug("Shutdown hook called") + localDirs.foreach { localDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case t: Throwable => + logError("Exception while deleting local spark dir: " + localDir, t) + } + } + + if (shuffleSender != null) { + shuffleSender.stop() + } + } + }) + } + + private[storage] def startShuffleBlockSender(port: Int): Int = { + shuffleSender = new ShuffleSender(port, this) + logInfo("Created ShuffleSender binding to port : " + shuffleSender.port) + shuffleSender.port + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 63447baf8c..a3c496f9e0 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,153 +17,46 @@ package org.apache.spark.storage -import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile} +import java.io.{FileOutputStream, RandomAccessFile} import java.nio.ByteBuffer -import java.nio.channels.FileChannel import 
java.nio.channels.FileChannel.MapMode -import java.util.{Random, Date} -import java.text.SimpleDateFormat import scala.collection.mutable.ArrayBuffer -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.serializer.{Serializer, SerializationStream} import org.apache.spark.Logging -import org.apache.spark.network.netty.ShuffleSender -import org.apache.spark.network.netty.PathResolver +import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils /** * Stores BlockManager blocks on disk. */ -private class DiskStore(blockManager: BlockManager, rootDirs: String) +private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager) extends BlockStore(blockManager) with Logging { - class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int) - extends BlockObjectWriter(blockId) { - - private val f: File = createFile(blockId /*, allowAppendExisting */) - - // The file channel, used for repositioning / truncating the file. - private var channel: FileChannel = null - private var bs: OutputStream = null - private var objOut: SerializationStream = null - private var lastValidPosition = 0L - private var initialized = false - - override def open(): DiskBlockObjectWriter = { - val fos = new FileOutputStream(f, true) - channel = fos.getChannel() - bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize)) - objOut = serializer.newInstance().serializeStream(bs) - initialized = true - this - } - - override def close() { - if (initialized) { - objOut.close() - channel = null - bs = null - objOut = null - } - // Invoke the close callback handler. - super.close() - } - - override def isOpen: Boolean = objOut != null - - // Flush the partial writes, and set valid length to be the length of the entire file. - // Return the number of bytes written for this commit. - override def commit(): Long = { - if (initialized) { - // NOTE: Flush the serializer first and then the compressed/buffered output stream - objOut.flush() - bs.flush() - val prevPos = lastValidPosition - lastValidPosition = channel.position() - lastValidPosition - prevPos - } else { - // lastValidPosition is zero if stream is uninitialized - lastValidPosition - } - } - - override def revertPartialWrites() { - if (initialized) { - // Discard current writes. We do this by flushing the outstanding writes and - // truncate the file to the last valid position. - objOut.flush() - bs.flush() - channel.truncate(lastValidPosition) - } - } - - override def write(value: Any) { - if (!initialized) { - open() - } - objOut.writeObject(value) - } - - override def size(): Long = lastValidPosition - } - - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt - - private var shuffleSender : ShuffleSender = null - // Create one local directory for each path mentioned in spark.local.dir; then, inside this - // directory, create multiple subdirectories that we will hash files into, in order to avoid - // having really large inodes at the top level. 
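Both the DiskStore code being removed here and the new DiskBlockManager.getFile place a file by a two-level hash: filename to local dir, then to one of spark.diskStore.subDirectories subdirectories. A sketch of just that arithmetic, with math.abs(hashCode) standing in for Utils.nonNegativeHash (the real helper also handles the Int.MinValue edge case):

    object HashPlacementSketch {
      val numLocalDirs = 2          // e.g. two paths listed in spark.local.dir
      val subDirsPerLocalDir = 64   // default for spark.diskStore.subDirectories

      // Maps a filename to (local dir index, subdirectory index).
      def place(filename: String): (Int, Int) = {
        val hash = math.abs(filename.hashCode)  // stand-in for nonNegativeHash
        val dirId = hash % numLocalDirs
        val subDirId = (hash / numLocalDirs) % subDirsPerLocalDir
        (dirId, subDirId)
      }

      def main(args: Array[String]) {
        // Deterministic: the same name always resolves to the same subdir,
        // so no index of file locations needs to be kept anywhere.
        println(place("merged_shuffle_1_2_3"))
        println(place("rdd_3_1"))
      }
    }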
- private val localDirs: Array[File] = createLocalDirs() - private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) - - addShutdownHook() - - def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int) - : BlockObjectWriter = { - new DiskBlockObjectWriter(blockId, serializer, bufferSize) + override def getSize(blockId: BlockId): Long = { + diskManager.getBlockLocation(blockId).length } - override def getSize(blockId: String): Long = { - getFile(blockId).length() - } - - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { // So that we do not modify the input offsets ! // duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() logDebug("Attempting to put block " + blockId) val startTime = System.currentTimeMillis - val file = createFile(blockId) - val channel = new RandomAccessFile(file, "rw").getChannel() + val file = diskManager.createBlockFile(blockId, allowAppending = false) + val channel = new FileOutputStream(file).getChannel() while (bytes.remaining > 0) { channel.write(bytes) } channel.close() val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime))) - } - - private def getFileBytes(file: File): ByteBuffer = { - val length = file.length() - val channel = new RandomAccessFile(file, "r").getChannel() - val buffer = try { - channel.map(MapMode.READ_ONLY, 0, length) - } finally { - channel.close() - } - - buffer + file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime))) } override def putValues( - blockId: String, + blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) @@ -171,159 +64,62 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) logDebug("Attempting to write values for block " + blockId) val startTime = System.currentTimeMillis - val file = createFile(blockId) - val fileOut = blockManager.wrapForCompression(blockId, - new FastBufferedOutputStream(new FileOutputStream(file))) - val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut) - objOut.writeAll(values.iterator) - objOut.close() - val length = file.length() + val file = diskManager.createBlockFile(blockId, allowAppending = false) + val outputStream = new FileOutputStream(file) + blockManager.dataSerializeStream(blockId, outputStream, values.iterator) + val length = file.length val timeTaken = System.currentTimeMillis - startTime logDebug("Block %s stored as %s file on disk in %d ms".format( - blockId, Utils.bytesToString(length), timeTaken)) + file.getName, Utils.bytesToString(length), timeTaken)) if (returnValues) { // Return a byte buffer for the contents of the file - val buffer = getFileBytes(file) + val buffer = getBytes(blockId).get PutResult(length, Right(buffer)) } else { PutResult(length, null) } } - override def getBytes(blockId: String): Option[ByteBuffer] = { - val file = getFile(blockId) - val bytes = getFileBytes(file) - Some(bytes) + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + val segment = diskManager.getBlockLocation(blockId) + val channel = new RandomAccessFile(segment.file, "r").getChannel() + val buffer = try { + channel.map(MapMode.READ_ONLY, segment.offset, segment.length) + } finally { + channel.close() + } + Some(buffer) } - override def getValues(blockId: String): 
Option[Iterator[Any]] = { - getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes)) + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) } /** * A version of getValues that allows a custom serializer. This is used as part of the * shuffle short-circuit code. */ - def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = { + def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer)) } - override def remove(blockId: String): Boolean = { - val file = getFile(blockId) - if (file.exists()) { + override def remove(blockId: BlockId): Boolean = { + val fileSegment = diskManager.getBlockLocation(blockId) + val file = fileSegment.file + if (file.exists() && file.length() == fileSegment.length) { file.delete() } else { - false - } - } - - override def contains(blockId: String): Boolean = { - getFile(blockId).exists() - } - - private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = { - val file = getFile(blockId) - if (!allowAppendExisting && file.exists()) { - // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task - // was rescheduled on the same machine as the old task. - logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting") - file.delete() - } - file - } - - private def getFile(blockId: String): File = { - logDebug("Getting file for block " + blockId) - - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(blockId) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir - - // Create the subdirectory if it doesn't already exist - var subDir = subDirs(dirId)(subDirId) - if (subDir == null) { - subDir = subDirs(dirId).synchronized { - val old = subDirs(dirId)(subDirId) - if (old != null) { - old - } else { - val newDir = new File(localDirs(dirId), "%02x".format(subDirId)) - newDir.mkdir() - subDirs(dirId)(subDirId) = newDir - newDir - } - } - } - - new File(subDir, blockId) - } - - private def createLocalDirs(): Array[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") - val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - rootDirs.split(",").map { rootDir => - var foundLocalDir = false - var localDir: File = null - var localDirId: String = null - var tries = 0 - val rand = new Random() - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) - localDir = new File(rootDir, "spark-local-" + localDirId) - if (!localDir.exists) { - foundLocalDir = localDir.mkdirs() - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e) - } + if (fileSegment.length < file.length()) { + logWarning("Could not delete block associated with only a part of a file: " + blockId) } - if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) - System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) - } - logInfo("Created local directory at " + localDir) - localDir + false } } - private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) - 
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { - override def run() { - logDebug("Shutdown hook called") - localDirs.foreach { localDir => - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case t: Throwable => - logError("Exception while deleting local spark dir: " + localDir, t) - } - } - if (shuffleSender != null) { - shuffleSender.stop - } - } - }) - } - - private[storage] def startShuffleBlockSender(port: Int): Int = { - val pResolver = new PathResolver { - override def getAbsolutePath(blockId: String): String = { - if (!blockId.startsWith("shuffle_")) { - return null - } - DiskStore.this.getFile(blockId).getAbsolutePath() - } - } - shuffleSender = new ShuffleSender(port, pResolver) - logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port) - shuffleSender.port + override def contains(blockId: BlockId): Boolean = { + val file = diskManager.getBlockLocation(blockId).file + file.exists() } } diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala new file mode 100644 index 0000000000..555486830a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.File + +/** + * References a particular segment of a file (potentially the entire file), + * based off an offset and a length. + */ +private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) { + override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) +} diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 77a39c71ed..05f676c6e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -32,7 +32,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) case class Entry(value: Any, size: Long, deserialized: Boolean) - private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true) + private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true) @volatile private var currentMemory = 0L // Object used to ensure that only one thread is putting blocks and if necessary, dropping // blocks from the memory store. 
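The new getBytes above resolves a block to a FileSegment and memory-maps only that (offset, length) window, rather than the whole file. A runnable sketch of that read path against a throwaway temp file (file name and contents are illustrative):

    import java.io.{File, FileOutputStream, RandomAccessFile}
    import java.nio.channels.FileChannel.MapMode

    object FileSegmentReadSketch {
      def main(args: Array[String]) {
        val file = File.createTempFile("segment-demo", ".bin")
        val out = new FileOutputStream(file)
        out.write("0123456789".getBytes("UTF-8"))
        out.close()

        // Map only bytes [3, 3 + 4): the segment, not the whole file.
        val (offset, length) = (3L, 4L)
        val channel = new RandomAccessFile(file, "r").getChannel()
        val buffer = try {
          channel.map(MapMode.READ_ONLY, offset, length)
        } finally {
          channel.close()
        }
        val bytes = new Array[Byte](length.toInt)
        buffer.get(bytes)
        println(new String(bytes, "UTF-8"))  // prints 3456
        file.delete()
      }
    }

This is what lets many shuffle blocks share one consolidated file: each block only needs its own (file, offset, length) triple.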
@@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) def freeMemory: Long = maxMemory - currentMemory - override def getSize(blockId: String): Long = { + override def getSize(blockId: BlockId): Long = { entries.synchronized { entries.get(blockId).size } } - override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) { + override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) { // Work on a duplicate - since the original input might be used elsewhere. val bytes = _bytes.duplicate() bytes.rewind() @@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } override def putValues( - blockId: String, + blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel, returnValues: Boolean) @@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def getBytes(blockId: String): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val entry = entries.synchronized { entries.get(blockId) } @@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def getValues(blockId: String): Option[Iterator[Any]] = { + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { val entry = entries.synchronized { entries.get(blockId) } @@ -108,7 +108,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def remove(blockId: String): Boolean = { + override def remove(blockId: BlockId): Boolean = { entries.synchronized { val entry = entries.remove(blockId) if (entry != null) { @@ -131,14 +131,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Return the RDD ID that a given block ID is from, or null if it is not an RDD block. + * Return the RDD ID that a given block ID is from, or None if it is not an RDD block. */ - private def getRddId(blockId: String): String = { - if (blockId.startsWith("rdd_")) { - blockId.split('_')(1) - } else { - null - } + private def getRddId(blockId: BlockId): Option[Int] = { + blockId.asRDDId.map(_.rddId) } /** @@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * blocks to free memory for one block, another thread may use up the freed space for * another block. */ - private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = { + private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = { // TODO: Its possible to optimize the locking by locking entries only when selecting blocks // to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been // released, it must be ensured that those to-be-dropped blocks are not double counted for @@ -195,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Assumes that a lock is held by the caller to ensure only one thread is dropping blocks. * Otherwise, the freed space may fill up before the caller puts in their new value. 
*/ - private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = { + private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = { logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( space, currentMemory, maxMemory)) @@ -207,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - currentMemory < space) { val rddToAdd = getRddId(blockIdToAdd) - val selectedBlocks = new ArrayBuffer[String]() + val selectedBlocks = new ArrayBuffer[BlockId]() var selectedMemory = 0L // This is synchronized to ensure that the set of entries is not changed @@ -218,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (rddToAdd != null && rddToAdd == getRddId(blockId)) { + if (rddToAdd != None && rddToAdd == getRddId(blockId)) { logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " + "block from the same RDD") return false @@ -252,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) return true } - override def contains(blockId: String): Boolean = { + override def contains(blockId: BlockId): Boolean = { entries.synchronized { entries.containsKey(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 9da11efb57..229178c095 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -17,12 +17,13 @@ package org.apache.spark.storage -import org.apache.spark.serializer.Serializer +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger +import org.apache.spark.serializer.Serializer private[spark] -class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter]) - +class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter]) private[spark] trait ShuffleBlocks { @@ -30,38 +31,61 @@ trait ShuffleBlocks { def releaseWriters(group: ShuffleWriterGroup) } - +/** + * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer + * per reducer. + * + * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle + * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer + * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files, + * it releases them for another task. + * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple: + * - shuffleId: The unique id given to the entire shuffle stage. + * - bucketId: The id of the output partition (i.e., reducer id) + * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a + * time owns a particular fileId, and this id is returned to a pool when the task finishes. + */ private[spark] class ShuffleBlockManager(blockManager: BlockManager) { + // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. + // TODO: Remove this once the shuffle file consolidation feature is stable. 
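The fileId pool described in the ShuffleBlockManager comment above is a fresh-or-recycled counter. A minimal sketch of that mechanism with the writer wiring omitted (names illustrative):

    import java.util.concurrent.ConcurrentLinkedQueue
    import java.util.concurrent.atomic.AtomicInteger

    object FileIdPoolSketch {
      private val nextFileId = new AtomicInteger(0)
      private val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()

      // Prefer a recycled id; mint a new one only when the pool is empty.
      def acquire(): Int = {
        val recycled = unusedFileIds.poll()
        if (recycled == null) nextFileId.getAndIncrement() else recycled.intValue
      }

      def release(fileId: Int) { unusedFileIds.add(fileId) }

      def main(args: Array[String]) {
        val a = acquire()           // 0
        val b = acquire()           // 1
        release(a)
        println((a, b, acquire())) // (0,1,0): the released id is reused
      }
    }

Because at most one task holds a given fileId at a time, writes to a combined "merged_shuffle_%d_%d_%d" file never interleave across tasks.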
+ val consolidateShuffleFiles = + System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean - def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = { + val nextFileId = new AtomicInteger(0) + val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]() + + def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleBlocks { // Get a group of writers for a map task. override def acquireWriters(mapId: Int): ShuffleWriterGroup = { val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + val fileId = getUnusedFileId() val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId => - val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId) - blockManager.getDiskBlockWriter(blockId, serializer, bufferSize) + val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) + val filename = physicalFileName(shuffleId, bucketId, fileId) + blockManager.getDiskWriter(blockId, filename, serializer, bufferSize) } - new ShuffleWriterGroup(mapId, writers) + new ShuffleWriterGroup(mapId, fileId, writers) } - override def releaseWriters(group: ShuffleWriterGroup) = { - // Nothing really to release here. + override def releaseWriters(group: ShuffleWriterGroup) { + recycleFileId(group.fileId) } } } -} - -private[spark] -object ShuffleBlockManager { + private def getUnusedFileId(): Int = { + val fileId = unusedFileIds.poll() + if (fileId == null) nextFileId.getAndIncrement() else fileId + } - // Returns the block id for a given shuffle block. - def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = { - "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId + private def recycleFileId(fileId: Int) { + if (!consolidateShuffleFiles) { return } // ensure we always generate a new file id + unusedFileIds.add(fileId) } - // Returns true if the block is a shuffle block. - def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_") + private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { + "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala new file mode 100644 index 0000000000..1b074e5ec7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -0,0 +1,84 @@ +package org.apache.spark.storage + +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{CountDownLatch, Executors} + +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.SparkContext +import org.apache.spark.util.Utils + +/** Utility for micro-benchmarking shuffle write performance. + * + * Writes simulated shuffle output from several threads and records the observed throughput. */ +object StoragePerfTester { + def main(args: Array[String]) = { + /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */ + val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g")) + + /** Number of map tasks. All tasks execute concurrently. */ + val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8) + + /** Number of reduce splits for each map task.
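+ Each map task writes one shuffle block per reduce split.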
*/ + val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500) + + val recordLength = 1000 // ~1KB records + val totalRecords = dataSizeMb * 1000 + val recordsPerMap = totalRecords / numMaps + + val writeData = "1" * recordLength + val executor = Executors.newFixedThreadPool(numMaps) + + System.setProperty("spark.shuffle.compress", "false") + System.setProperty("spark.shuffle.sync", "true") + + // This is only used to instantiate a BlockManager. All thread scheduling is done manually. + val sc = new SparkContext("local[4]", "Write Tester") + val blockManager = sc.env.blockManager + + def writeOutputBytes(mapId: Int, total: AtomicLong) = { + val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits, + new KryoSerializer()) + val buckets = shuffle.acquireWriters(mapId) + for (i <- 1 to recordsPerMap) { + buckets.writers(i % numOutputSplits).write(writeData) + } + buckets.writers.map {w => + w.commit() + total.addAndGet(w.fileSegment().length) + w.close() + } + + shuffle.releaseWriters(buckets) + } + + val start = System.currentTimeMillis() + val latch = new CountDownLatch(numMaps) + val totalBytes = new AtomicLong() + for (task <- 1 to numMaps) { + executor.submit(new Runnable() { + override def run() = { + try { + writeOutputBytes(task, totalBytes) + latch.countDown() + } catch { + case e: Exception => + println("Exception in child thread: " + e + " " + e.getMessage) + System.exit(1) + } + } + }) + } + latch.await() + val end = System.currentTimeMillis() + val time = (end - start) / 1000.0 + val bytesPerSecond = totalBytes.get() / time + val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong + + System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits)) + System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile))) + System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong))) + + executor.shutdown() + sc.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 2bb7715696..1720007e4e 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -23,20 +23,24 @@ import org.apache.spark.util.Utils private[spark] case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long, - blocks: Map[String, BlockStatus]) { + blocks: Map[BlockId, BlockStatus]) { - def memUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize). - reduceOption(_+_).getOrElse(0l) - } + def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L) - def diskUsed(blockPrefix: String = "") = { - blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize). 
- reduceOption(_+_).getOrElse(0l) - } + def memUsedByRDD(rddId: Int) = + rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L) + + def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L) + + def diskUsedByRDD(rddId: Int) = + rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L) def memRemaining : Long = maxMem - memUsed() + def rddBlocks = blocks.flatMap { + case (rdd: RDDBlockId, status) => Some(rdd, status) + case _ => None + } } case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel, @@ -60,7 +64,7 @@ object StorageUtils { /* Returns RDD-level information, compiled from a list of StorageStatus objects */ def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus], sc: SparkContext) : Array[RDDInfo] = { - rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc) + rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc) } /* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */ @@ -71,26 +75,21 @@ object StorageUtils { } /* Given a list of BlockStatus objects, returns information for each RDD */ - def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus], + def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus], sc: SparkContext) : Array[RDDInfo] = { // Group by rddId, ignore the partition name - val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) => - k.substring(0,k.lastIndexOf('_')) - }.mapValues(_.values.toArray) + val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray) // For each RDD, generate an RDDInfo object - val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) => + val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) => // Add up memory and disk sizes val memSize = rddBlocks.map(_.memSize).reduce(_ + _) val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _) - // Find the id of the RDD, e.g. rdd_1 => 1 - val rddId = rddKey.split("_").last.toInt - // Get the friendly name and storage level for the RDD, if available sc.persistentRdds.get(rddId).map { r => - val rddName = Option(r.name).getOrElse(rddKey) + val rddName = Option(r.name).getOrElse(rddId.toString) val rddStorageLevel = r.getStorageLevel RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize) } @@ -101,16 +100,14 @@ object StorageUtils { rddInfos } - /* Removes all BlockStatus object that are not part of a block prefix */ - def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus], - prefix: String) : Array[StorageStatus] = { + /* Filters storage status by a given RDD id.
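+ Blocks belonging to other RDDs are dropped from each returned StorageStatus.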
*/ + def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int) + : Array[StorageStatus] = { storageStatusList.map { status => - val newBlocks = status.blocks.filterKeys(_.startsWith(prefix)) + val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus] //val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _) StorageStatus(status.blockManagerId, status.maxMem, newBlocks) } - } - } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index f2ae8dd97d..860e680576 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -36,11 +36,11 @@ private[spark] object ThreadingTest { val numBlocksPerProducer = 20000 private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { - val queue = new ArrayBlockingQueue[(String, Seq[Int])](100) + val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100) override def run() { for (i <- 1 to numBlocksPerProducer) { - val blockId = "b-" + id + "-" + i + val blockId = TestBlockId("b-" + id + "-" + i) val blockSize = Random.nextInt(1000) val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() @@ -64,7 +64,7 @@ private[spark] object ThreadingTest { private[spark] class ConsumerThread( manager: BlockManager, - queue: ArrayBlockingQueue[(String, Seq[Int])] + queue: ArrayBlockingQueue[(BlockId, Seq[Int])] ) extends Thread { var numBlockConsumed = 0 diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 163a3746ea..b7c81d091c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -86,7 +86,7 @@ private[spark] class StagePage(parent: JobProgressUI) { Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++ Seq("GC Time") ++ {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++ + {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ Seq("Errors") val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks) @@ -169,6 +169,8 @@ private[spark] class StagePage(parent: JobProgressUI) { Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td> }} {if (shuffleWrite) { + <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => + parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td> <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td> }} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 43c1257677..b83cd54f3c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.storage.{StorageStatus, StorageUtils} +import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils} import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui.Page._ @@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) { 
val sc = parent.sc def render(request: HttpServletRequest): Seq[Node] = { - val id = request.getParameter("id") - val prefix = "rdd_" + id.toString + val id = request.getParameter("id").toInt val storageStatusList = sc.getExecutorStorageStatus - val filteredStorageStatusList = StorageUtils. - filterStorageStatusByPrefix(storageStatusList, prefix) + val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id) val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage") - val workers = filteredStorageStatusList.map((prefix, _)) + val workers = filteredStorageStatusList.map((id, _)) val workerTable = listingTable(workerHeaders, workerRow, workers) val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk", "Executors") - val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1) + val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray. + sortWith(_._1.name < _._1.name) val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList) val blocks = blockStatuses.map { case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN"))) @@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) { headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage) } - def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = { + def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = { val (id, block, locations) = row <tr> <td>{id}</td> @@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) { </tr> } - def workerRow(worker: (String, StorageStatus)): Seq[Node] = { - val (prefix, status) = worker + def workerRow(worker: (Int, StorageStatus)): Seq[Node] = { + val (rddId, status) = worker <tr> <td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td> <td> - {Utils.bytesToString(status.memUsed(prefix))} + {Utils.bytesToString(status.memUsedByRDD(rddId))} ({Utils.bytesToString(status.memRemaining)} Remaining) </td> - <td>{Utils.bytesToString(status.diskUsed(prefix))}</td> + <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td> </tr> } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 0ce1394c77..3f963727d9 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea } object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask", - "ShuffleMapTask", "BlockManager", "BroadcastVars") { + "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") { - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, + SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value type MetadataCleanerType = Value diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f384875cc9..a3b3968c5e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ 
b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -447,14 +447,17 @@ private[spark] object Utils extends Logging { hostPortParseResults.get(hostPort) } - private[spark] val daemonThreadFactory: ThreadFactory = - new ThreadFactoryBuilder().setDaemon(true).build() + private val daemonThreadFactoryBuilder: ThreadFactoryBuilder = + new ThreadFactoryBuilder().setDaemon(true) /** - * Wrapper over newCachedThreadPool. + * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. */ - def newDaemonCachedThreadPool(): ThreadPoolExecutor = - Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = { + val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } /** * Return a string describing how much time has passed, in seconds. The input parameter should be in @@ -465,10 +468,13 @@ private[spark] object Utils extends Logging { } /** - * Wrapper over newFixedThreadPool. + * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a + * unique, sequentially assigned integer. */ - def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor = - Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor] + def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = { + val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build() + Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor] + } private def listFilesSafely(file: File): Seq[File] = { val files = file.listFiles() diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index b3a53d928b..e022accee6 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -20,8 +20,42 @@ package org.apache.spark import org.scalatest.FunSuite class BroadcastSuite extends FunSuite with LocalSparkContext { - - test("basic broadcast") { + + override def afterEach() { + super.afterEach() + System.clearProperty("spark.broadcast.factory") + } + + test("Using HttpBroadcast locally") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + sc = new SparkContext("local", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === Set((1, 10), (2, 10))) + } + + test("Accessing HttpBroadcast variables from multiple threads") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + sc = new SparkContext("local[10]", "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + } + + test("Accessing HttpBroadcast variables in a local cluster") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + val numSlaves = 4 + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val 
results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Using TorrentBroadcast locally") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") sc = new SparkContext("local", "test") val list = List(1, 2, 3, 4) val listBroadcast = sc.broadcast(list) @@ -29,11 +63,23 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect.toSet === Set((1, 10), (2, 10))) } - test("broadcast variables accessed in multiple threads") { + test("Accessing TorrentBroadcast variables from multiple threads") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") sc = new SparkContext("local[10]", "test") val list = List(1, 2, 3, 4) val listBroadcast = sc.broadcast(list) val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) } + + test("Accessing TorrentBroadcast variables in a local cluster") { + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") + val numSlaves = 4 + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") + val list = List(1, 2, 3, 4) + val listBroadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 3a7171c488..ea936e815b 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.mock.EasyMockSugar import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{BlockManager, StorageLevel} +import org.apache.spark.storage.{BlockManager, RDDBlockId, StorageLevel} // TODO: Test the CacheManager's thread-safety aspects class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar { @@ -52,13 +52,14 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("get uncached rdd") { expecting { - blockManager.get("rdd_0_0").andReturn(None) - blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true). 
- andReturn(0) + blockManager.get(RDDBlockId(0, 0)).andReturn(None) + blockManager.put(RDDBlockId(0, 0), ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, + true).andReturn(0) } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = false, null) + val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } @@ -66,11 +67,12 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("get cached rdd") { expecting { - blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator)) + blockManager.get(RDDBlockId(0, 0)).andReturn(Some(ArrayBuffer(5, 6, 7).iterator)) } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = false, null) + val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -79,11 +81,12 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar test("get uncached local rdd") { expecting { // Local computation should not persist the resulting value, so don't expect a put(). - blockManager.get("rdd_0_0").andReturn(None) + blockManager.get(RDDBlockId(0, 0)).andReturn(None) } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true, null) + val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false, + taskMetrics = null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d9103aebb7..f26c44d3e7 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.FunSuite import java.io.File import org.apache.spark.rdd._ import org.apache.spark.SparkContext._ -import storage.StorageLevel +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { @@ -62,8 +62,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithIndexRDD(r, - (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false )) + testCheckpointing(r => new MapPartitionsWithContextRDD(r, + (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) @@ -83,7 +83,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { } test("BlockRDD") { - val blockId = "id" + val blockId = TestBlockId("id") val blockManager = SparkEnv.get.blockManager blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY) val blockRDD = new BlockRDD[String](sc, Array(blockId)) @@ -191,7 +191,7 @@ class CheckpointSuite extends 
FunSuite with LocalSparkContext with Logging { } test("CheckpointRDD with zero partitions") { - val rdd = new BlockRDD[Int](sc, Array[String]()) + val rdd = new BlockRDD[Int](sc, Array[BlockId]()) assert(rdd.partitions.size === 0) assert(rdd.isCheckpointed === false) rdd.checkpoint() diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index cd2bf9a8ff..480bac84f3 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -18,24 +18,14 @@ package org.apache.spark import network.ConnectionManagerId -import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Timeouts._ +import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.scalatest.prop.Checkers import org.scalatest.time.{Span, Millis} -import org.scalacheck.Arbitrary._ -import org.scalacheck.Gen -import org.scalacheck.Prop._ -import org.eclipse.jetty.server.{Server, Request, Handler} - -import com.google.common.io.Files - -import scala.collection.mutable.ArrayBuffer import SparkContext._ -import storage.{GetBlock, BlockManagerWorker, StorageLevel} -import ui.JettyUtils +import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel} class NotSerializableClass @@ -193,7 +183,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter // Get all the locations of the first partition and try to fetch the partitions // from those locations. - val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray + val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager blockManager.master.getLocations(blockId).foreach(id => { diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 591c1d498d..7b0bb89ab2 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable { @Test public void iterator() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, null); + TaskContext context = new TaskContext(0, 0, 0, false, false, null); Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala new file mode 100644 index 0000000000..d8a0e983b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.concurrent.Semaphore + +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.SparkContext._ +import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListener} + + +/** + * Test suite for cancelling running jobs. We test cancellation of single-job actions + * (e.g. count) as well as multi-job actions (e.g. take). We test the local and cluster schedulers + * in both FIFO and fair scheduling modes. + */ +class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter + with LocalSparkContext { + + override def afterEach() { + super.afterEach() + resetSparkContext() + System.clearProperty("spark.scheduler.mode") + } + + test("local mode, FIFO scheduler") { + System.setProperty("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local[2]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("local mode, fair scheduler") { + System.setProperty("spark.scheduler.mode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local[2]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("cluster mode, FIFO scheduler") { + System.setProperty("spark.scheduler.mode", "FIFO") + sc = new SparkContext("local-cluster[2,1,512]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("cluster mode, fair scheduler") { + System.setProperty("spark.scheduler.mode", "FAIR") + val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() + System.setProperty("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local-cluster[2,1,512]", "test") + testCount() + testTake() + // Make sure we can still launch tasks. + assert(sc.parallelize(1 to 10, 2).count === 10) + } + + test("job group") { + sc = new SparkContext("local[2]", "test") + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. + val jobA = future { + sc.setJobGroup("jobA", "this is a job to be cancelled") + sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count() + } + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() + + // Block until both tasks of job A have started and cancel job A. 
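+ // (The listener above releases one permit per launched task, so two permits mean both of jobA's tasks are running.)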
+ sem.acquire(2) + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) } + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + assert(jobB.get() === 100) + } + + test("two jobs sharing the same stage") { + // sem1: make sure cancel is issued after some tasks are launched + // sem2: make sure the first stage is not finished until cancel is issued + val sem1 = new Semaphore(0) + val sem2 = new Semaphore(0) + + sc = new SparkContext("local[2]", "test") + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem1.release() + } + }) + + // Create two actions that would share the same stages. + val rdd = sc.parallelize(1 to 10, 2).map { i => + sem2.acquire() + (i, i) + }.reduceByKey(_+_) + val f1 = rdd.collectAsync() + val f2 = rdd.countAsync() + + // Kill one of the actions. + future { + sem1.acquire() + f1.cancel() + sem2.release(10) + } + + // Expect both to fail now. + // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2. + intercept[SparkException] { f1.get() } + intercept[SparkException] { f2.get() } + } + + def testCount() { + // Cancel before launching any tasks + { + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() + future { f.cancel() } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + + // Cancel after some tasks have been launched + { + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() + future { + // Wait until some tasks have been launched before we cancel the job. + sem.acquire() + f.cancel() + } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + } + + def testTake() { + // Cancel before launching any tasks + { + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) + future { f.cancel() } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + + // Cancel after some tasks have been launched + { + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) + future { + sem.acquire() + f.cancel() + } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala new file mode 100644 index 0000000000..da032b17d9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import java.util.concurrent.Semaphore + +import scala.concurrent.ExecutionContext.Implicits.global + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext._ +import org.apache.spark.{SparkContext, SparkException, LocalSparkContext} + + +class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts { + + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local[2]", "test") + } + + override def afterAll() { + LocalSparkContext.stop(sc) + sc = null + } + + lazy val zeroPartRdd = new EmptyRDD[Int](sc) + + test("countAsync") { + assert(zeroPartRdd.countAsync().get() === 0) + assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000) + } + + test("collectAsync") { + assert(zeroPartRdd.collectAsync().get() === Seq.empty) + + val collected = sc.parallelize(1 to 1000, 3).collectAsync().get() + assert(collected === (1 to 1000)) + } + + test("foreachAsync") { + zeroPartRdd.foreachAsync(i => Unit).get() + + val accum = sc.accumulator(0) + sc.parallelize(1 to 1000, 3).foreachAsync { i => + accum += 1 + }.get() + assert(accum.value === 1000) + } + + test("foreachPartitionAsync") { + zeroPartRdd.foreachPartitionAsync(iter => Unit).get() + + val accum = sc.accumulator(0) + sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter => + accum += 1 + }.get() + assert(accum.value === 9) + } + + test("takeAsync") { + def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) { + val expected = input.take(num) + val saw = rdd.takeAsync(num).get() + assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)" + .format(rdd.partitions.size, expected, saw)) + } + val input = Range(1, 1000) + + var rdd = sc.parallelize(input, 1) + for (num <- Seq(0, 1, 999, 1000)) { + testTake(rdd, input, num) + } + + rdd = sc.parallelize(input, 2) + for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) { + testTake(rdd, input, num) + } + + rdd = sc.parallelize(input, 100) + for (num <- Seq(0, 1, 500, 501, 999, 1000)) { + testTake(rdd, input, num) + } + + rdd = sc.parallelize(input, 1000) + for (num <- Seq(0, 1, 3, 999, 1000)) { + testTake(rdd, input, num) + } + } + + /** + * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case + * of a successful job execution. + */ + test("async success handling") { + val f = sc.parallelize(1 to 10, 2).countAsync() + + // Use a semaphore to make sure onSuccess and onComplete's success path will be called. + // If not, the test will hang. 
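+ // (Each of the two success callbacks releases one permit, matching the sem.acquire(2) at the end of the test.)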
+ val sem = new Semaphore(0) + + f.onComplete { + case scala.util.Success(res) => + sem.release() + case scala.util.Failure(e) => + info("Should not have reached this code path (onComplete matching Failure)") + throw new Exception("Task should succeed") + } + f.onSuccess { case a: Any => + sem.release() + } + f.onFailure { case t => + info("Should not have reached this code path (onFailure)") + throw new Exception("Task should succeed") + } + assert(f.get() === 10) + + failAfter(10 seconds) { + sem.acquire(2) + } + } + + /** + * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case + * of a failed job execution. + */ + test("async failure handling") { + val f = sc.parallelize(1 to 10, 2).map { i => + throw new Exception("intentional"); i + }.countAsync() + + // Use a semaphore to make sure onFailure and onComplete's failure path will be called. + // If not, the test will hang. + val sem = new Semaphore(0) + + f.onComplete { + case scala.util.Success(res) => + info("Should not have reached this code path (onComplete matching Success)") + throw new Exception("Task should fail") + case scala.util.Failure(e) => + sem.release() + } + f.onSuccess { case a: Any => + info("Should not have reached this code path (onSuccess)") + throw new Exception("Task should fail") + } + f.onFailure { case t => + sem.release() + } + intercept[SparkException] { + f.get() + } + + failAfter(10 seconds) { + sem.acquire(2) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 31f97fc139..57d3382ed0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { } } visit(sums) - assert(deps.size === 2) // ShuffledRDD, ParallelCollection + assert(deps.size === 2) // ShuffledRDD, ParallelCollection. } test("join") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2f933246b0..2a2f828be6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -24,15 +24,14 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.LocalSparkContext import org.apache.spark.MapOutputTracker -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext import org.apache.spark.Partition import org.apache.spark.TaskContext import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency} import org.apache.spark.{FetchFailed, Success, TaskEndReason} -import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster} - +import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} /** * Tests for DAGScheduler. 
These tests directly call the event processing functions in DAGScheduler @@ -60,7 +59,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) taskSets += taskSet } - override def setListener(listener: TaskSchedulerListener) = {} + override def cancelTasks(stageId: Int) {} + override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 } @@ -75,15 +75,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations val blockManagerMaster = new BlockManagerMaster(null) { - override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = { - blockIds.map { name => - val pieces = name.split("_") - if (pieces(0) == "rdd") { - val key = pieces(1).toInt -> pieces(2).toInt - cacheLocations.getOrElse(key, Seq()) - } else { - Seq() - } + override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { + blockIds.map { + _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). + getOrElse(Seq()) }.toSeq } override def removeExecutor(execId: String) { @@ -186,7 +181,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, allowLocal: Boolean = false, listener: JobListener = listener) { - runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener)) + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) } /** Sends TaskSetFailed to the scheduler. 
*/ @@ -220,7 +216,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def getPreferredLocations(split: Partition) = Nil override def toString = "DAGSchedulerSuite Local RDD" } - runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener)) + val jobId = scheduler.nextJobId.getAndIncrement() + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) } @@ -247,7 +244,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") - assert(failure.getMessage === "Job failed: some failure") + assert(failure.getMessage === "Job aborted: some failure") } test("run trivial shuffle") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala index 80d0c5a5e9..b97f2b19b5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -28,6 +28,30 @@ import org.apache.spark.executor.TaskMetrics import java.nio.ByteBuffer import org.apache.spark.util.{Utils, FakeClock} +class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + taskScheduler.startedTasks += taskInfo.index + } + + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: mutable.Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) { + taskScheduler.endedTasks(taskInfo.index) = reason + } + + override def executorGained(execId: String, host: String) {} + + override def executorLost(execId: String) {} + + override def taskSetFailed(taskSet: TaskSet, reason: String) { + taskScheduler.taskSetsFailed += taskSet.id + } +} + /** * A mock ClusterScheduler implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. 
Note that it's important to initialize this with @@ -44,30 +68,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* val executors = new mutable.HashMap[String, String] ++ liveExecutors - listener = new TaskSchedulerListener { - def taskStarted(task: Task[_], taskInfo: TaskInfo) { - startedTasks += taskInfo.index - } - - def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: mutable.Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) - { - endedTasks(taskInfo.index) = reason - } - - def executorGained(execId: String, host: String) {} - - def executorLost(execId: String) {} - - def taskSetFailed(taskSet: TaskSet, reason: String) { - taskSetsFailed += taskSet.id - } - } + dagScheduler = new FakeDAGScheduler(this) def removeExecutor(execId: String): Unit = executors -= execId diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala index 2f12aaed18..0f01515179 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala @@ -17,10 +17,11 @@ package org.apache.spark.scheduler.cluster +import org.apache.spark.TaskContext import org.apache.spark.scheduler.{TaskLocation, Task} -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) { - override def run(attemptId: Long): Int = 0 +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { + override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala index 119ba30090..ee150a3107 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv} import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult} +import org.apache.spark.storage.TaskResultBlockId /** * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter. 
@@ -85,7 +86,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) - val RESULT_BLOCK_ID = "taskresult_0" + val RESULT_BLOCK_ID = TaskResultBlockId(0) assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0, "Expect result to be removed from the block manager.") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala index af76c843e8..1e676c1719 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.scheduler.local -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -import org.apache.spark._ -import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.{ConcurrentMap, HashMap} import java.util.concurrent.Semaphore import java.util.concurrent.CountDownLatch -import java.util.Properties + +import scala.collection.mutable.HashMap + +import org.scalatest.{BeforeAndAfterEach, FunSuite} + +import org.apache.spark._ + class Lock() { var finished = false @@ -63,7 +61,12 @@ object TaskThreadInfo { * 5. each pending task must use "sleep" to make sure it has been added to the taskSetManager queue, * thus it will be scheduled later when the cluster has free CPU cores. */ -class LocalSchedulerSuite extends FunSuite with LocalSparkContext { +class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach { + + override def afterEach() { + super.afterEach() + System.clearProperty("spark.scheduler.mode") + } def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) { @@ -148,12 +151,13 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext { } test("Local fair scheduler end-to-end test") { - sc = new SparkContext("local[8]", "LocalSchedulerSuite") - val sem = new Semaphore(0) System.setProperty("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() System.setProperty("spark.scheduler.allocation.file", xmlPath) + sc = new SparkContext("local[8]", "LocalSchedulerSuite") + val sem = new Semaphore(0) + createThread(10,"1",sc,sem) TaskThreadInfo.threadToStarted(10).await() createThread(20,"2",sc,sem) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala new file mode 100644 index 0000000000..cb76275e39 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.scalatest.FunSuite + +class BlockIdSuite extends FunSuite { + def assertSame(id1: BlockId, id2: BlockId) { + assert(id1.name === id2.name) + assert(id1.hashCode === id2.hashCode) + assert(id1 === id2) + } + + def assertDifferent(id1: BlockId, id2: BlockId) { + assert(id1.name != id2.name) + assert(id1.hashCode != id2.hashCode) + assert(id1 != id2) + } + + test("test-bad-deserialization") { + try { + // Try to deserialize an invalid block id. + BlockId("myblock") + fail() + } catch { + case e: IllegalStateException => // OK + case _ => fail() + } + } + + test("rdd") { + val id = RDDBlockId(1, 2) + assertSame(id, RDDBlockId(1, 2)) + assertDifferent(id, RDDBlockId(1, 1)) + assert(id.name === "rdd_1_2") + assert(id.asRDDId.get.rddId === 1) + assert(id.asRDDId.get.splitIndex === 2) + assert(id.isRDD) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle") { + val id = ShuffleBlockId(1, 2, 3) + assertSame(id, ShuffleBlockId(1, 2, 3)) + assertDifferent(id, ShuffleBlockId(3, 2, 3)) + assert(id.name === "shuffle_1_2_3") + assert(id.asRDDId === None) + assert(id.shuffleId === 1) + assert(id.mapId === 2) + assert(id.reduceId === 3) + assert(id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("broadcast") { + val id = BroadcastBlockId(42) + assertSame(id, BroadcastBlockId(42)) + assertDifferent(id, BroadcastBlockId(123)) + assert(id.name === "broadcast_42") + assert(id.asRDDId === None) + assert(id.broadcastId === 42) + assert(id.isBroadcast) + assertSame(id, BlockId(id.toString)) + } + + test("taskresult") { + val id = TaskResultBlockId(60) + assertSame(id, TaskResultBlockId(60)) + assertDifferent(id, TaskResultBlockId(61)) + assert(id.name === "taskresult_60") + assert(id.asRDDId === None) + assert(id.taskId === 60) + assert(!id.isRDD) + assertSame(id, BlockId(id.toString)) + } + + test("stream") { + val id = StreamBlockId(1, 100) + assertSame(id, StreamBlockId(1, 100)) + assertDifferent(id, StreamBlockId(2, 101)) + assert(id.name === "input-1-100") + assert(id.asRDDId === None) + assert(id.streamId === 1) + assert(id.uniqueId === 100) + assert(!id.isBroadcast) + assertSame(id, BlockId(id.toString)) + } + + test("test") { + val id = TestBlockId("abc") + assertSame(id, TestBlockId("abc")) + assertDifferent(id, TestBlockId("ab")) + assert(id.name === "test_abc") + assert(id.asRDDId === None) + assert(id.id === "abc") + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 038a9acb85..484a654108 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -32,7 +32,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} - class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { 
var store: BlockManager = null var store2: BlockManager = null @@ -46,6 +45,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT System.setProperty("spark.kryoserializer.buffer.mb", "1") val serializer = new KryoSerializer + // Implicitly convert strings to BlockIds for test clarity. + implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) + before { val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0) this.actorSystem = actorSystem @@ -229,31 +232,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) // Putting a1, a2 and a3 in memory. - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = false) eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 + store.getSingle(rdd(0, 0)) should be (None) + master.getLocations(rdd(0, 0)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 + store.getSingle(rdd(0, 1)) should be (None) + master.getLocations(rdd(0, 1)) should have size 0 } eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { store.getSingle("nonrddblock") should not be (None) master.getLocations("nonrddblock") should have size (1) } - store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY) master.removeRdd(0, blocking = true) - store.getSingle("rdd_0_0") should be (None) - master.getLocations("rdd_0_0") should have size 0 - store.getSingle("rdd_0_1") should be (None) - master.getLocations("rdd_0_1") should have size 0 + store.getSingle(rdd(0, 0)) should be (None) + master.getLocations(rdd(0, 0)) should have size 0 + store.getSingle(rdd(0, 1)) should be (None) + master.getLocations(rdd(0, 1)) should have size 0 } test("reregistration on heart beat") { @@ -372,41 +375,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) - store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY) // Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2 // from the same RDD - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") - assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store") 
+ assert(store.getSingle(rdd(0, 1)) != None, "rdd_0_1 was not in store") // Check that rdd_0_3 doesn't replace them even after further accesses - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") - assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") + assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store") } test("in-memory LRU for partitions of multiple RDDs") { store = new BlockManager("<driver>", actorSystem, master, serializer, 1200) - store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // At this point rdd_1_1 should've replaced rdd_0_1 - assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") + assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store") + assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") + assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store") // Do a get() on rdd_0_2 so that it is the most recently used item - assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store") + assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store") // Put in more partitions from RDD 0; they should replace rdd_1_1 - store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY) - store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 3), new Array[Byte](400), StorageLevel.MEMORY_ONLY) + store.putSingle(rdd(0, 4), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped // when we try to add rdd_0_4. 
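// (rdd_0_4 is the block that fails to fit: the memory store does not evict other partitions of the RDD currently being added, so rdd_0_2 and rdd_0_3 survive)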
- assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store") - assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store") - assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store") - assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store") - assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store") + assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store") + assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store") + assert(!store.memoryStore.contains(rdd(0, 4)), "rdd_0_4 was in store") + assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store") + assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } test("on-disk storage") { @@ -590,43 +593,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT try { System.setProperty("spark.shuffle.compress", "true") store = new BlockManager("exec1", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed") + store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, + "shuffle_0_0_0 was not compressed") store.stop() store = null System.setProperty("spark.shuffle.compress", "false") store = new BlockManager("exec2", actorSystem, master, serializer, 2000) - store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed") + store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, + "shuffle_0_0_0 was compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "true") store = new BlockManager("exec3", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed") + store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, + "broadcast_0 was not compressed") store.stop() store = null System.setProperty("spark.broadcast.compress", "false") store = new BlockManager("exec4", actorSystem, master, serializer, 2000) - store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed") + store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "true") store = new BlockManager("exec5", actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed") + store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null System.setProperty("spark.rdd.compress", "false") store = new BlockManager("exec6", 
actorSystem, master, serializer, 2000) - store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) - assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed") + store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) + assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null diff --git a/docs/configuration.md b/docs/configuration.md index 7940d41a27..97183bafdb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -149,7 +149,7 @@ Apart from these, the following properties are also available, and may be useful <td>spark.io.compression.codec</td> <td>org.apache.spark.io.<br />LZFCompressionCodec</td> <td> - The compression codec class to use for various compressions. By default, Spark provides two + The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, Spark provides two codecs: <code>org.apache.spark.io.LZFCompressionCodec</code> and <code>org.apache.spark.io.SnappyCompressionCodec</code>. </td> </tr> @@ -319,6 +319,14 @@ Apart from these, the following properties are also available, and may be useful Should be greater than or equal to 1. Number of allowed retries = this value - 1. </td> </tr> +<tr> + <td>spark.broadcast.blockSize</td> + <td>4096</td> + <td> + Size of each piece of a block in kilobytes for <code>TorrentBroadcastFactory</code>. + Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, <code>BlockManager</code> might take a performance hit. + </td> +</tr> </table> diff --git a/examples/pom.xml b/examples/pom.xml index b97e6af288..aee371fbc7 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -153,6 +153,14 @@ <groupId>org.apache.cassandra.deps</groupId> <artifactId>avro</artifactId> </exclusion> + <exclusion> + <groupId>org.sonatype.sisu.inject</groupId> + <artifactId>*</artifactId> + </exclusion> + <exclusion> + <groupId>org.xerial.snappy</groupId> + <artifactId>*</artifactId> + </exclusion> </exclusions> </dependency> </dependencies> diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 868ff81f67..529709c2f9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -22,12 +22,19 @@ import org.apache.spark.SparkContext object BroadcastTest { def main(args: Array[String]) { if (args.length == 0) { - System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]") + System.err.println("Usage: BroadcastTest <master> [slices] [numElem] [broadcastAlgo] [blockSize]") System.exit(1) } - val sc = new SparkContext(args(0), "Broadcast Test", + val bcName = if (args.length > 3) args(3) else "Http" + val blockSize = if (args.length > 4) args(4) else "4096" + + System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." 
+ bcName + "BroadcastFactory") + System.setProperty("spark.broadcast.blockSize", blockSize) + + val sc = new SparkContext(args(0), "Broadcast Test 2", System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR"))) + val slices = if (args.length > 1) args(1).toInt else 2 val num = if (args.length > 2) args(2).toInt else 1000000 @@ -36,13 +43,15 @@ object BroadcastTest { arr1(i) = i } - for (i <- 0 until 2) { + for (i <- 0 until 3) { println("Iteration " + i) println("===========") + val startTime = System.nanoTime val barr1 = sc.broadcast(arr1) sc.parallelize(1 to 10, slices).foreach { i => println(barr1.value.size) } + println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) } System.exit(0) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala index 884d6d6f34..de70c50473 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala @@ -17,17 +17,19 @@ package org.apache.spark.streaming.examples.clickstream -import java.net.{InetAddress,ServerSocket,Socket,SocketException} -import java.io.{InputStreamReader, BufferedReader, PrintWriter} +import java.net.ServerSocket +import java.io.PrintWriter import util.Random /** Represents a page view on a website with associated dimension data.*/ -class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) { +class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) + extends Serializable { override def toString() : String = { "%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID) } } -object PageView { + +object PageView extends Serializable { def fromString(in : String) : PageView = { val parts = in.split("\t") new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt) @@ -39,6 +41,9 @@ object PageView { * This should be used in tandem with PageViewStream.scala. Example: * $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10 * $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444 + * + * When running this, you may want to set the root logging level to ERROR in + * conf/log4j.properties to reduce the verbosity of the output. 
* */ object PageViewGenerator { val pages = Map("http://foo.com/" -> .7,
diff --git a/pom.xml b/pom.xml --- a/pom.xml +++ b/pom.xml
@@ -21,7 +21,7 @@ <parent> <groupId>org.apache</groupId> <artifactId>apache</artifactId> - <version>11</version> + <version>13</version> </parent> <groupId>org.apache.spark</groupId> <artifactId>spark-parent</artifactId> @@ -61,6 +61,29 @@ <maven>3.0.0</maven> </prerequisites> + <mailingLists> + <mailingList> + <name>Dev Mailing List</name> + <post>dev@spark.incubator.apache.org</post> + <subscribe>dev-subscribe@spark.incubator.apache.org</subscribe> + <unsubscribe>dev-unsubscribe@spark.incubator.apache.org</unsubscribe> + </mailingList> + + <mailingList> + <name>User Mailing List</name> + <post>user@spark.incubator.apache.org</post> + <subscribe>user-subscribe@spark.incubator.apache.org</subscribe> + <unsubscribe>user-unsubscribe@spark.incubator.apache.org</unsubscribe> + </mailingList> + + <mailingList> + <name>Commits Mailing List</name> + <post>commits@spark.incubator.apache.org</post> + <subscribe>commits-subscribe@spark.incubator.apache.org</subscribe> + <unsubscribe>commits-unsubscribe@spark.incubator.apache.org</unsubscribe> + </mailingList> + </mailingLists> + <modules> <module>core</module> <module>bagel</module> @@ -227,16 +250,34 @@ <groupId>com.typesafe.akka</groupId> <artifactId>akka-actor</artifactId> <version>${akka.version}</version> + <exclusions> + <exclusion> + <groupId>org.jboss.netty</groupId> + <artifactId>netty</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>com.typesafe.akka</groupId> <artifactId>akka-remote</artifactId> <version>${akka.version}</version> + <exclusions> + <exclusion> + <groupId>org.jboss.netty</groupId> + <artifactId>netty</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>com.typesafe.akka</groupId> <artifactId>akka-slf4j</artifactId> <version>${akka.version}</version> + <exclusions> + <exclusion> + <groupId>org.jboss.netty</groupId> + <artifactId>netty</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>it.unimi.dsi</groupId> @@ -371,19 +412,11 @@ </exclusion> <exclusion> <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-core-asl</artifactId> - </exclusion> - <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-mapper-asl</artifactId> + <artifactId>*</artifactId> </exclusion> <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-jaxrs</artifactId> - </exclusion> - <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-xc</artifactId> + <groupId>org.sonatype.sisu.inject</groupId> + <artifactId>*</artifactId> </exclusion> </exclusions> </dependency> @@ -407,19 +440,11 @@ </exclusion> <exclusion> <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-core-asl</artifactId> - </exclusion> - <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-mapper-asl</artifactId> + <artifactId>*</artifactId> </exclusion> <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-jaxrs</artifactId> - </exclusion> - <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-xc</artifactId> + <groupId>org.sonatype.sisu.inject</groupId> + <artifactId>*</artifactId> </exclusion> </exclusions> </dependency> @@ -438,19 +463,11 @@ </exclusion> <exclusion> <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-core-asl</artifactId> - </exclusion> - <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-mapper-asl</artifactId> - </exclusion> -
<exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-jaxrs</artifactId> + <artifactId>*</artifactId> </exclusion> <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-xc</artifactId> + <groupId>org.sonatype.sisu.inject</groupId> + <artifactId>*</artifactId> </exclusion> </exclusions> </dependency> @@ -469,19 +486,11 @@ </exclusion> <exclusion> <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-core-asl</artifactId> - </exclusion> - <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-mapper-asl</artifactId> - </exclusion> - <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-jaxrs</artifactId> + <artifactId>*</artifactId> </exclusion> <exclusion> - <groupId>org.codehaus.jackson</groupId> - <artifactId>jackson-xc</artifactId> + <groupId>org.sonatype.sisu.inject</groupId> + <artifactId>*</artifactId> </exclusion> </exclusions> </dependency> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b14970942b..17f480e3f0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -60,6 +60,8 @@ object SparkBuild extends Build { lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) .dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*) + lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and package Spark projects") + // A configuration to set an alternative publishLocalConfiguration lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") @@ -74,8 +76,11 @@ object SparkBuild extends Build { // Conditionally include the yarn sub-project lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]() lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]() - lazy val allProjects = Seq[ProjectReference]( - core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef + + // Everything except assembly, tools and examples belongs to packageProjects + lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef + + lazy val allProjects = packageProjects ++ Seq[ProjectReference](examples, tools, assemblyProj) def sharedSettings = Defaults.defaultSettings ++ Seq( organization := "org.apache.spark", @@ -306,7 +311,9 @@ object SparkBuild extends Build { def assemblyProjSettings = sharedSettings ++ Seq( name := "spark-assembly", - jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" } + assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn, + jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }, + jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" } ) ++ assemblySettings ++ extraAssemblySettings def extraAssemblySettings() = Seq( @@ -314,6 +321,7 @@ object SparkBuild extends Build { mergeStrategy in assembly := { case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard + case "log4j.properties" => MergeStrategy.discard case "META-INF/services/org.apache.hadoop.fs.FileSystem" => MergeStrategy.concat case "reference.conf" =>
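// reference.conf holds library defaults for Typesafe Config users such as Akka; concatenating keeps every jar's settings in the assembly instead of only the first copy found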
MergeStrategy.concat case _ => MergeStrategy.first diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index d367f91967..da3d96689a 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -42,6 +42,13 @@ >>> a.value 13 +>>> b = sc.accumulator(0) +>>> def g(x): +... b.add(x) +>>> rdd.foreach(g) +>>> b.value +6 + >>> from pyspark.accumulators import AccumulatorParam >>> class VectorAccumulatorParam(AccumulatorParam): ... def zero(self, value): @@ -139,9 +146,13 @@ class Accumulator(object): raise Exception("Accumulator.value cannot be accessed inside tasks") self._value = value + def add(self, term): + """Adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + def __iadd__(self, term): """The += operator; adds a term to this accumulator's value""" - self._value = self.accum_param.addInPlace(self._value, term) + self.add(term) return self def __str__(self): diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 36f54a22cf..48a8fa9328 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -845,7 +845,14 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master: val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')) .getOrElse(new Array[String](0)) .map(new java.io.File(_).getAbsolutePath) - sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) + try { + sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) + } catch { + case e: Exception => + e.printStackTrace() + echo("Failed to create SparkContext, exiting...") + sys.exit(1) + } sparkContext } diff --git a/streaming/pom.xml b/streaming/pom.xml index 14c043175d..339fcd2a39 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -84,12 +84,22 @@ <groupId>org.jboss.netty</groupId> <artifactId>netty</artifactId> </exclusion> + <exclusion> + <groupId>org.xerial.snappy</groupId> + <artifactId>*</artifactId> + </exclusion> </exclusions> </dependency> <dependency> <groupId>org.twitter4j</groupId> <artifactId>twitter4j-stream</artifactId> <version>3.0.3</version> + <exclusions> + <exclusion> + <groupId>org.jboss.netty</groupId> + <artifactId>netty</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>org.scala-lang</groupId> @@ -99,6 +109,12 @@ <groupId>com.typesafe.akka</groupId> <artifactId>akka-zeromq</artifactId> <version>2.0.3</version> + <exclusions> + <exclusion> + <groupId>org.jboss.netty</groupId> + <artifactId>netty</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>org.scalatest</groupId> diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 2d8f072624..bb9febad38 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.Logging import org.apache.spark.io.CompressionCodec +import org.apache.spark.util.MetadataCleaner private[streaming] @@ -40,6 +41,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration val 
pendingTimes = ssc.scheduler.jobManager.getPendingTimes() + val delaySeconds = MetadataCleaner.getDelaySeconds def validate() { assert(master != null, "Checkpoint.master is null") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala index aae79a4e6f..b97fb7e6e3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala @@ -30,10 +30,11 @@ import akka.actor._ import akka.pattern.ask import akka.util.duration._ import akka.dispatch._ +import org.apache.spark.storage.BlockId private[streaming] sealed trait NetworkInputTrackerMessage private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage -private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage +private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) extends NetworkInputTrackerMessage private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage /** @@ -48,7 +49,7 @@ class NetworkInputTracker( val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Queue[String]] + val receivedBlockIds = new HashMap[Int, Queue[BlockId]] val timeout = 5000.milliseconds var currentTime: Time = null @@ -67,9 +68,9 @@ class NetworkInputTracker( } /** Return all the blocks received from a receiver. */ - def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized { + def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized { val queue = receivedBlockIds.synchronized { - receivedBlockIds.getOrElse(receiverId, new Queue[String]()) + receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]()) } val result = queue.synchronized { queue.dequeueAll(x => true) @@ -92,7 +93,7 @@ class NetworkInputTracker( case AddBlocks(streamId, blockIds, metadata) => { val tmp = receivedBlockIds.synchronized { if (!receivedBlockIds.contains(streamId)) { - receivedBlockIds += ((streamId, new Queue[String])) + receivedBlockIds += ((streamId, new Queue[BlockId])) } receivedBlockIds(streamId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index dc60046805..ee265ab4e9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -100,6 +100,10 @@ class StreamingContext private ( "both SparkContext and checkpoint as null") } + if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds < 0) { + MetadataCleaner.setDelaySeconds(cp_.delaySeconds) + } + if (MetadataCleaner.getDelaySeconds < 0) { throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 31f9891560..8d3ac0fc65 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -31,7 +31,7 @@ import org.apache.spark.streaming.util.{RecurringTimer, SystemClock} import org.apache.spark.streaming._ import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.rdd.{RDD, BlockRDD} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} /** * Abstract class for defining any InputDStream that has to start a receiver on worker @@ -69,7 +69,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) Some(new BlockRDD[T](ssc.sc, blockIds)) } else { - Some(new BlockRDD[T](ssc.sc, Array[String]())) + Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) } } } @@ -77,7 +77,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming private[streaming] sealed trait NetworkReceiverMessage private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage -private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage +private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) extends NetworkReceiverMessage private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage /** @@ -158,7 +158,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log /** * Pushes a block (as an ArrayBuffer filled with data) into the block manager. */ - def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { + def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level) actor ! ReportBlock(blockId, metadata) } @@ -166,7 +166,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log /** * Pushes a block (as bytes) into the block manager. */ - def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { + def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { env.blockManager.putBytes(blockId, bytes, level) actor ! 
ReportBlock(blockId, metadata) } @@ -209,7 +209,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log class BlockGenerator(storageLevel: StorageLevel) extends Serializable with Logging { - case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null) + case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null) val clock = new SystemClock() val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong @@ -241,7 +241,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log val newBlockBuffer = currentBuffer currentBuffer = new ArrayBuffer[T] if (newBlockBuffer.size > 0) { - val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval) + val blockId = StreamBlockId(NetworkReceiver.this.streamId, time - blockInterval) val newBlock = new Block(blockId, newBlockBuffer) blocksForPushing.add(newBlock) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index c91f12ecd7..10ed4ef78d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -18,7 +18,7 @@ package org.apache.spark.streaming.dstream import org.apache.spark.Logging -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.StreamingContext import java.net.InetSocketAddress @@ -71,7 +71,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) var nextBlockNumber = 0 while (true) { val buffer = queue.take() - val blockId = "input-" + streamId + "-" + nextBlockNumber + val blockId = StreamBlockId(streamId, nextBlockNumber) nextBlockNumber += 1 pushBlock(blockId, buffer, null, storageLevel) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index 4b5d8c467e..ef0f85a717 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -21,7 +21,7 @@ import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy } import akka.actor.{ actorRef2Scala, ActorRef } import akka.actor.{ PossiblyHarmful, OneForOneStrategy } -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.dstream.NetworkReceiver import java.util.concurrent.atomic.AtomicInteger @@ -159,7 +159,7 @@ private[streaming] class ActorReceiver[T: ClassManifest]( protected def pushBlock(iter: Iterator[T]) { val buffer = new ArrayBuffer[T] buffer ++= iter - pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel) + pushBlock(StreamBlockId(streamId, System.nanoTime()), buffer, null, storageLevel) } protected def onStart() = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 6d6ef149cc..25da9aa917 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -22,7 +22,7 @@ import 
org.apache.spark.util.Utils import org.apache.spark.scheduler.SplitInfo import scala.collection import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container} -import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend} +import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend} import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} import org.apache.hadoop.yarn.util.{RackResolver, Records} import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} @@ -211,7 +211,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM val workerId = workerIdCounter.incrementAndGet().toString val driverUrl = "akka://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), - StandaloneSchedulerBackend.ACTOR_NAME) + CoarseGrainedSchedulerBackend.ACTOR_NAME) logInfo("launching container on " + containerId + " host " + workerHostname) // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
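Taken together, the storage-layer hunks above replace free-form string block names such as "rdd_0_1", "shuffle_0_0_0", "broadcast_0" and "input-<stream>-<time>" with typed BlockId subclasses, and thread them through the Netty transport, the streaming tracker (Seq[BlockId], Queue[BlockId]) and BlockRDD. The sketch below is illustrative only: it uses the constructors that actually appear in these hunks, but the object name BlockIdSketch is an invented placeholder, the wildcard import is assumed to cover the new classes, and the full BlockId API is not shown in this patch.

    import org.apache.spark.storage._

    object BlockIdSketch extends App {
      // Typed constructors replace hand-assembled strings:
      val rddBlock: BlockId       = RDDBlockId(0, 1)                      // formerly "rdd_0_1"
      val shuffleBlock: BlockId   = ShuffleBlockId(0, 0, 0)               // formerly "shuffle_0_0_0"
      val broadcastBlock: BlockId = BroadcastBlockId(0)                   // formerly "broadcast_0"
      val streamBlock: BlockId    = StreamBlockId(42, System.nanoTime())  // formerly "input-42-<time>"

      // Test-only escape hatch used by BlockManagerSuite: literal names
      // become TestBlockIds through an implicit conversion.
      implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
      val adHoc: BlockId = "nonrddblock"

      println(Seq(rddBlock, shuffleBlock, broadcastBlock, streamBlock, adHoc).mkString("\n"))
    }

One practical consequence visible in the signature changes (NetworkInputTracker.getBlockIds returning Array[BlockId], pushBlock taking a BlockId) is that receivers and trackers can no longer hand the block manager a malformed name by accident; ill-formed IDs fail at construction rather than deep inside the storage layer.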