aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/compute-classpath.sh22
-rwxr-xr-xbin/slaves.sh19
-rwxr-xr-xbin/spark-daemon.sh21
-rwxr-xr-xbin/spark-daemons.sh2
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java3
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java23
-rwxr-xr-xcore/src/main/java/org/apache/spark/network/netty/PathResolver.java11
-rw-r--r--core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala17
-rw-r--r--core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/CacheManager.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/FutureAction.scala250
-rw-r--r--core/src/main/scala/org/apache/spark/InterruptibleIterator.scala30
-rw-r--r--core/src/main/scala/org/apache/spark/ShuffleFetcher.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala156
-rw-r--r--core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/TaskEndReason.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala1058
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala410
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala54
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala247
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala603
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala (renamed from core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala)22
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala166
-rw-r--r--core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManager.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala27
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/package.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala122
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala116
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala (renamed from core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala)12
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala79
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala95
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala154
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala126
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala62
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Pool.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala44
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala53
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala63
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala44
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala51
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala96
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala)22
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala)26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala196
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala22
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockException.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockId.scala103
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala580
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockMessage.scala38
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala135
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockStore.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala184
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskStore.scala280
-rw-r--r--core/src/main/scala/org/apache/spark/storage/FileSegment.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/storage/MemoryStore.scala34
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala62
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala84
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageUtils.scala47
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala22
-rw-r--r--core/src/test/scala/org/apache/spark/BroadcastSuite.scala52
-rw-r--r--core/src/test/scala/org/apache/spark/CacheManagerSuite.scala21
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/DistributedSuite.scala16
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java2
-rw-r--r--core/src/test/scala/org/apache/spark/JobCancellationSuite.scala209
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala176
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala29
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala49
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala28
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala114
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala102
-rw-r--r--docs/configuration.md10
-rw-r--r--examples/pom.xml8
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala15
-rw-r--r--examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala13
-rw-r--r--pom.xml99
-rw-r--r--project/SparkBuild.scala14
-rw-r--r--python/pyspark/accumulators.py13
-rw-r--r--repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala9
-rw-r--r--streaming/pom.xml16
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala11
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala14
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala4
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala4
129 files changed, 3674 insertions, 4009 deletions
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index c7819d4932..c16afd6b36 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -32,12 +32,26 @@ fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
-if [ -f "$FWDIR/RELEASE" ]; then
- ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+
+# First check if we have a dependencies jar. If so, include binary classes with the deps jar
+if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
+
+ DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
+ CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
- ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ # Else use spark-assembly jar from either RELEASE or assembly directory
+ if [ -f "$FWDIR/RELEASE" ]; then
+ ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
+ else
+ ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
+ fi
+ CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
-CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
diff --git a/bin/slaves.sh b/bin/slaves.sh
index 752565b759..c367c2fd8e 100755
--- a/bin/slaves.sh
+++ b/bin/slaves.sh
@@ -28,7 +28,7 @@
# SPARK_SSH_OPTS Options passed to ssh when running remote commands.
##
-usage="Usage: slaves.sh [--config confdir] command..."
+usage="Usage: slaves.sh [--config <conf-dir>] command..."
# if no args specified, show usage
if [ $# -le 0 ]; then
@@ -46,6 +46,23 @@ bin=`cd "$bin"; pwd`
# spark-env.sh. Save it here.
HOSTLIST=$SPARK_SLAVES
+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+if [ "$1" == "--config" ]
+then
+ shift
+ conf_dir=$1
+ if [ ! -d "$conf_dir" ]
+ then
+ echo "ERROR : $conf_dir is not a directory"
+ echo $usage
+ exit 1
+ else
+ export SPARK_CONF_DIR=$conf_dir
+ fi
+ shift
+fi
+
if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
. "${SPARK_CONF_DIR}/spark-env.sh"
fi
diff --git a/bin/spark-daemon.sh b/bin/spark-daemon.sh
index 5bfe967fbf..a0c0d44b58 100755
--- a/bin/spark-daemon.sh
+++ b/bin/spark-daemon.sh
@@ -29,7 +29,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##
-usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
+usage="Usage: spark-daemon.sh [--config <conf-dir>] (start|stop) <spark-command> <spark-instance-number> <args...>"
# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -43,6 +43,25 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
# get arguments
+
+# Check if --config is passed as an argument. It is an optional parameter.
+# Exit if the argument is not a directory.
+
+if [ "$1" == "--config" ]
+then
+ shift
+ conf_dir=$1
+ if [ ! -d "$conf_dir" ]
+ then
+ echo "ERROR : $conf_dir is not a directory"
+ echo $usage
+ exit 1
+ else
+ export SPARK_CONF_DIR=$conf_dir
+ fi
+ shift
+fi
+
startStop=$1
shift
command=$1
diff --git a/bin/spark-daemons.sh b/bin/spark-daemons.sh
index 354eb905a1..64286cb2da 100755
--- a/bin/spark-daemons.sh
+++ b/bin/spark-daemons.sh
@@ -19,7 +19,7 @@
# Run a Spark command on all slave hosts.
-usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
+usage="Usage: spark-daemons.sh [--config <conf-dir>] [start|stop] command instance-number args..."
# if no args specified, show usage
if [ $# -le 1 ]; then
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
index c4aa2669e0..8a09210245 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileClientHandler.java
@@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
+import org.apache.spark.storage.BlockId;
abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
@@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
- public abstract void handleError(String blockId);
+ public abstract void handleError(BlockId blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
index d3d57a0255..172c6e4b1c 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServerHandler.java
@@ -24,6 +24,8 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -34,41 +36,36 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
}
@Override
- public void messageReceived(ChannelHandlerContext ctx, String blockId) {
- String path = pResolver.getAbsolutePath(blockId);
- // if getFilePath returns null, close the channel
- if (path == null) {
+ public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
+ BlockId blockId = BlockId.apply(blockIdString);
+ FileSegment fileSegment = pResolver.getBlockLocation(blockId);
+ // if getBlockLocation returns null, close the channel
+ if (fileSegment == null) {
//ctx.close();
return;
}
- File file = new File(path);
+ File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
- //logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
- long length = file.length();
+ long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
- //logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
- //logger.info("Sending block "+blockId+" filelen = "+len);
- //logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
- .getChannel(), 0, file.length()));
+ .getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
- //logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
- //logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
diff --git a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
index 94c034cad0..9f7ced44cf 100755
--- a/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
+++ b/core/src/main/java/org/apache/spark/network/netty/PathResolver.java
@@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.FileSegment;
public interface PathResolver {
- /**
- * Get the absolute path of the file
- *
- * @param fileId
- * @return the absolute path of file
- */
- public String getAbsolutePath(String fileId);
+ /** Get the file segment in which the given block resides. */
+ public FileSegment getBlockLocation(BlockId blockId);
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
index f87460039b..0c47afae54 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
@@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
+private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
- val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
+ "org.apache.hadoop.mapred.JobContext")
+ val ctor = klass.getDeclaredConstructor(classOf[JobConf],
+ classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
- val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
+ val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
+ "org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
index 93180307fa..32429f01ac 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
-import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
+import org.apache.hadoop.conf.Configuration
+private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
- def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
- val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
+ def newTaskAttemptID(
+ jtIdentifier: String,
+ jobId: Int,
+ isMap: Boolean,
+ taskId: Int,
+ attemptId: Int) = {
+ val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
- // first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
+ // First, attempt to use the old-style constructor that takes a boolean isMap
+ // (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
- classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ classOf[Int], classOf[Int])
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
- // failed, look for the new ctor that takes a TaskType (not available in 1.x)
- val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
- val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
+ // If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
+ val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
+ .asInstanceOf[Class[Enum[_]]]
+ val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
+ taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
- ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
- JInteger(attemptId)).asInstanceOf[TaskAttemptID]
+ ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
+ new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
index 908ff56a6b..d9ed572da6 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
@@ -22,13 +22,17 @@ import scala.collection.mutable.HashMap
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
- override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+ override def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
+ serializer: Serializer)
: Iterator[T] =
{
@@ -45,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
- (address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
+ (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
@@ -58,9 +62,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
block.asInstanceOf[Iterator[T]]
}
case None => {
- val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
- case regex(shufId, mapId, _) =>
+ case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
@@ -74,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
- CompletionIterator[T, Iterator[T]](itr, {
+ val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
- metrics.shuffleReadMetrics = Some(shuffleMetrics)
+ context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
})
+
+ new InterruptibleIterator[T](context, completionIter)
}
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 4cf7eb96da..519ecde50a 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -18,7 +18,7 @@
package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
import org.apache.spark.rdd.RDD
@@ -28,17 +28,17 @@ import org.apache.spark.rdd.RDD
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
/** Keys of RDD splits that are being computed/loaded. */
- private val loading = new HashSet[String]()
+ private val loading = new HashSet[RDDBlockId]()
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
- val key = "rdd_%d_%d".format(rdd.id, split.index)
+ val key = RDDBlockId(rdd.id, split.index)
logDebug("Looking for partition " + key)
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
- return values.asInstanceOf[Iterator[T]]
+ return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
loading.add(key)
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (context.runningLocally) { return computedValues }
val elements = new ArrayBuffer[Any]
elements ++= computedValues
- blockManager.put(key, elements, storageLevel, true)
+ blockManager.put(key, elements, storageLevel, tellMaster = true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
new file mode 100644
index 0000000000..1ad9240cfa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent._
+import scala.concurrent.duration.Duration
+import scala.util.Try
+
+import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
+import org.apache.spark.scheduler.JobFailed
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * A future for the result of an action. This is an extension of the Scala Future interface to
+ * support cancellation.
+ */
+trait FutureAction[T] extends Future[T] {
+ // Note that we redefine methods of the Future trait here explicitly so we can specify a different
+ // documentation (with reference to the word "action").
+
+ /**
+ * Cancels the execution of this action.
+ */
+ def cancel()
+
+ /**
+ * Blocks until this action completes.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @return this FutureAction
+ */
+ override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
+
+ /**
+ * Awaits and returns the result (of type T) of this action.
+ * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
+ * for unbounded waiting, or a finite positive duration
+ * @throws Exception exception during action execution
+ * @return the result value if the action is completed within the specific maximum wait time
+ */
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T
+
+ /**
+ * When this action is completed, either through an exception, or a value, applies the provided
+ * function.
+ */
+ def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
+
+ /**
+ * Returns whether the action has already been completed with a value or an exception.
+ */
+ override def isCompleted: Boolean
+
+ /**
+ * The value of this Future.
+ *
+ * If the future is not completed the returned value will be None. If the future is completed
+ * the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
+ * it contains an exception.
+ */
+ override def value: Option[Try[T]]
+
+ /**
+ * Blocks and returns the result of this job.
+ */
+ @throws(classOf[Exception])
+ def get(): T = Await.result(this, Duration.Inf)
+}
+
+
+/**
+ * The future holding the result of an action that triggers a single job. Examples include
+ * count, collect, reduce.
+ */
+class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
+ extends FutureAction[T] {
+
+ override def cancel() {
+ jobWaiter.cancel()
+ }
+
+ override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
+ if (!atMost.isFinite()) {
+ awaitResult()
+ } else {
+ val finishTime = System.currentTimeMillis() + atMost.toMillis
+ while (!isCompleted) {
+ val time = System.currentTimeMillis()
+ if (time >= finishTime) {
+ throw new TimeoutException
+ } else {
+ jobWaiter.wait(finishTime - time)
+ }
+ }
+ }
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ ready(atMost)(permit)
+ awaitResult() match {
+ case scala.util.Success(res) => res
+ case scala.util.Failure(e) => throw e
+ }
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
+ executor.execute(new Runnable {
+ override def run() {
+ func(awaitResult())
+ }
+ })
+ }
+
+ override def isCompleted: Boolean = jobWaiter.jobFinished
+
+ override def value: Option[Try[T]] = {
+ if (jobWaiter.jobFinished) {
+ Some(awaitResult())
+ } else {
+ None
+ }
+ }
+
+ private def awaitResult(): Try[T] = {
+ jobWaiter.awaitResult() match {
+ case JobSucceeded => scala.util.Success(resultFunc)
+ case JobFailed(e: Exception, _) => scala.util.Failure(e)
+ }
+ }
+}
+
+
+/**
+ * A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
+ * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
+ * action thread if it is being blocked by a job.
+ */
+class ComplexFutureAction[T] extends FutureAction[T] {
+
+ // Pointer to the thread that is executing the action. It is set when the action is run.
+ @volatile private var thread: Thread = _
+
+ // A flag indicating whether the future has been cancelled. This is used in case the future
+ // is cancelled before the action was even run (and thus we have no thread to interrupt).
+ @volatile private var _cancelled: Boolean = false
+
+ // A promise used to signal the future.
+ private val p = promise[T]()
+
+ override def cancel(): Unit = this.synchronized {
+ _cancelled = true
+ if (thread != null) {
+ thread.interrupt()
+ }
+ }
+
+ /**
+ * Executes some action enclosed in the closure. To properly enable cancellation, the closure
+ * should use runJob implementation in this promise. See takeAsync for example.
+ */
+ def run(func: => T)(implicit executor: ExecutionContext): this.type = {
+ scala.concurrent.future {
+ thread = Thread.currentThread
+ try {
+ p.success(func)
+ } catch {
+ case e: Exception => p.failure(e)
+ } finally {
+ thread = null
+ }
+ }
+ this
+ }
+
+ /**
+ * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
+ * to enable cancellation.
+ */
+ def runJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R) {
+ // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
+ // command need to be in an atomic block.
+ val job = this.synchronized {
+ if (!cancelled) {
+ rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
+ } else {
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ // Wait for the job to complete. If the action is cancelled (with an interrupt),
+ // cancel the job and stop the execution. This is not in a synchronized block because
+ // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
+ try {
+ Await.ready(job, Duration.Inf)
+ } catch {
+ case e: InterruptedException =>
+ job.cancel()
+ throw new SparkException("Action has been cancelled")
+ }
+ }
+
+ /**
+ * Returns whether the promise has been cancelled.
+ */
+ def cancelled: Boolean = _cancelled
+
+ @throws(classOf[InterruptedException])
+ @throws(classOf[scala.concurrent.TimeoutException])
+ override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
+ p.future.ready(atMost)(permit)
+ this
+ }
+
+ @throws(classOf[Exception])
+ override def result(atMost: Duration)(implicit permit: CanAwait): T = {
+ p.future.result(atMost)(permit)
+ }
+
+ override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
+ p.future.onComplete(func)(executor)
+ }
+
+ override def isCompleted: Boolean = p.isCompleted
+
+ override def value: Option[Try[T]] = p.future.value
+}
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
new file mode 100644
index 0000000000..56e0b8d2c0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * An iterator that wraps around an existing iterator to provide task killing functionality.
+ * It works by checking the interrupted flag in TaskContext.
+ */
+class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
+ extends Iterator[T] {
+
+ def hasNext: Boolean = !context.interrupted && delegate.hasNext
+
+ def next(): T = delegate.next()
+}
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
index 307c383a89..a85aa50a9b 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
- def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+ def fetch[T](
+ shuffleId: Int,
+ reduceId: Int,
+ context: TaskContext,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
/** Stop the fetcher */
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b884ae5879..564466cfd5 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -51,25 +51,20 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.LocalSparkCluster
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
- ClusterScheduler}
-import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+ SparkDeploySchedulerBackend, ClusterScheduler}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
+import org.apache.spark.scheduler.local.LocalScheduler
+import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util._
-import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
-import scala.Some
-import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.RDDInfo
-import org.apache.spark.storage.StorageStatus
+import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
+ TimeStampedHashMap, Utils}
+
+
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -125,7 +120,7 @@ class SparkContext(
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
- // Initalize the Spark UI
+ // Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
@@ -161,8 +156,8 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """spark://(.*)""".r
- //Regular expression for connection to Mesos cluster
- val MESOS_REGEX = """(mesos://.*)""".r
+ // Regular expression for connection to Mesos cluster
+ val MESOS_REGEX = """mesos://(.*)""".r
master match {
case "local" =>
@@ -213,25 +208,24 @@ class SparkContext(
throw new SparkException("YARN mode not available ?", th)
}
}
- val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
+ val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
- case _ =>
- if (MESOS_REGEX.findFirstIn(master).isEmpty) {
- logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
- }
+ case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
} else {
- new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
+ new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
}
scheduler.initialize(backend)
scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
}
}
taskScheduler.start()
@@ -288,15 +282,46 @@ class SparkContext(
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
+ @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
- setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+ setJobGroup("", value)
+ }
+
+ /**
+ * Assigns a group id to all the jobs started by this thread until the group id is set to a
+ * different value or cleared.
+ *
+ * Often, a unit of execution in an application consists of multiple Spark actions or jobs.
+ * Application programmers can use this method to group all those jobs together and give a
+ * group description. Once set, the Spark web UI will associate such jobs with this group.
+ *
+ * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
+ * running jobs in this group. For example,
+ * {{{
+ * // In the main thread:
+ * sc.setJobGroup("some_job_to_cancel", "some job description")
+ * sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ *
+ * // In a separate thread:
+ * sc.cancelJobGroup("some_job_to_cancel")
+ * }}}
+ */
+ def setJobGroup(groupId: String, description: String) {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+ }
+
+ /** Clear the job group id and its description. */
+ def clearJobGroup() {
+ setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+ setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
}
// Post init
taskScheduler.postStartHook()
- val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
- val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
+ private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
+ private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -335,7 +360,7 @@ class SparkContext(
}
/**
- * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
+ * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
@@ -361,24 +386,15 @@ class SparkContext(
): RDD[(K, V)] = {
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
- hadoopFile(path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
- }
-
- /**
- * Get an RDD for a Hadoop file with an arbitray InputFormat. Accept a Hadoop Configuration
- * that has already been broadcast, assuming that it's safe to use it to construct a
- * HadoopFileRDD (i.e., except for file 'path', all other configuration properties can be resued).
- */
- def hadoopFile[K, V](
- path: String,
- confBroadcast: Broadcast[SerializableWritable[Configuration]],
- inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K],
- valueClass: Class[V],
- minSplits: Int
- ): RDD[(K, V)] = {
- new HadoopFileRDD(
- this, path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
+ val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
+ new HadoopRDD(
+ this,
+ confBroadcast,
+ Some(setInputPathsFunc),
+ inputFormatClass,
+ keyClass,
+ valueClass,
+ minSplits)
}
/**
@@ -764,10 +780,11 @@ class SparkContext(
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.formatSparkCallSite
+ val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
- localProperties.get)
+ val result = dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
+ resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
@@ -857,6 +874,42 @@ class SparkContext(
}
/**
+ * Submit a job for execution and return a FutureJob holding the result.
+ */
+ def submitJob[T, U, R](
+ rdd: RDD[T],
+ processPartition: Iterator[T] => U,
+ partitions: Seq[Int],
+ resultHandler: (Int, U) => Unit,
+ resultFunc: => R): SimpleFutureAction[R] =
+ {
+ val cleanF = clean(processPartition)
+ val callSite = Utils.formatSparkCallSite
+ val waiter = dagScheduler.submitJob(
+ rdd,
+ (context: TaskContext, iter: Iterator[T]) => cleanF(iter),
+ partitions,
+ callSite,
+ allowLocal = false,
+ resultHandler,
+ localProperties.get)
+ new SimpleFutureAction(waiter, resultFunc)
+ }
+
+ /**
+ * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
+ * for more information.
+ */
+ def cancelJobGroup(groupId: String) {
+ dagScheduler.cancelJobGroup(groupId)
+ }
+
+ /** Cancel all jobs that have been scheduled or are running. */
+ def cancelAllJobs() {
+ dagScheduler.cancelAllJobs()
+ }
+
+ /**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*/
@@ -912,7 +965,10 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
- val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
+
+ private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
@@ -939,6 +995,8 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
+ implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 2bab9d6e3d..103a1c2051 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
-import org.apache.hadoop.fs.FileSystem
-import org.apache.hadoop.fs.Path
-
+import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
-import java.io.IOException
import java.util.Date
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
-class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
+private[apache]
+class SparkHadoopWriter(@transient jobConf: JobConf)
+ extends Logging
+ with SparkHadoopMapRedUtil
+ with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
getOutputCommitter().setupTask(getTaskContext())
- writer = getOutputFormat().getRecordWriter(
- fs, conf.value, outputName, Reporter.NULL)
+ writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
- if (writer!=null) {
- //println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
+ if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
}
+private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index c2c358c7ad..cae983ed4c 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -17,21 +17,30 @@
package org.apache.spark
-import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.executor.TaskMetrics
+
class TaskContext(
val stageId: Int,
- val splitId: Int,
+ val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
- val taskMetrics: TaskMetrics = TaskMetrics.empty()
+ @volatile var interrupted: Boolean = false,
+ private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
- @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @deprecated("use partitionId", "0.8.1")
+ def splitId = partitionId
+
+ // List of callback functions to execute when the task completes.
+ @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
- // Add a callback function to be executed on task completion. An example use
- // is for HadoopRDD to register a callback to close the input stream.
+ /**
+ * Add a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * @param f Callback function.
+ */
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8466c2a004..c1e5e04b31 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure(
*/
private[spark] case object TaskResultLost extends TaskEndReason
+private[spark] case object TaskKilled extends TaskEndReason
+
private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 1f8ad688a6..12b4d94a56 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
-class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
deleted file mode 100644
index f82dea9f3a..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/BitTorrentBroadcast.scala
+++ /dev/null
@@ -1,1058 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{BitSet, Comparator, Timer, TimerTask, UUID}
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.{BlockManager, StorageLevel}
-import org.apache.spark.util.Utils
-
-private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
- extends Broadcast[T](id)
- with Logging
- with Serializable {
-
- def value = value_
-
- def blockId: String = BlockManager.toBroadcastId(id)
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var hasBlocksBitVector: BitSet = null
- @transient var numCopiesSent: Array[Int] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = new AtomicInteger(0)
-
- // Used ONLY by driver to track how many unique blocks have been sent out
- @transient var sentBlocks = new AtomicInteger(0)
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
-
- // Used only in driver
- @transient var guideMR: GuideMultipleRequests = null
-
- // Used only in Workers
- @transient var ttGuide: TalkToGuide = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks.set(variableInfo.totalBlocks)
-
- // Guide has all the blocks
- hasBlocksBitVector = new BitSet(totalBlocks)
- hasBlocksBitVector.set(0, totalBlocks)
-
- // Guide still hasn't sent any block
- numCopiesSent = new Array[Int](totalBlocks)
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val driverSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- hasBlocksBitVector.synchronized {
- driverSource.hasBlocksBitVector = hasBlocksBitVector
- }
-
- // In the beginning, this is the only known source to Guide
- listOfSources += driverSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- // Start local ServeMultipleRequests thread first
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- // Initialize variables in the worker node. Driver sends everything as 0/null
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- hasBlocksBitVector = null
- numCopiesSent = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = new AtomicInteger(0)
-
- listenPortLock = new Object
- totalBlocksLock = new Object
-
- serveMR = null
- ttGuide = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- listOfSources = ListBuffer[SourceInfo]()
-
- stopBroadcast = false
- }
-
- private def getLocalSourceInfo: SourceInfo = {
- // Wait till hostName and listenPort are OK
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Wait till totalBlocks and totalBytes are OK
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- var localSourceInfo = SourceInfo(
- hostAddress, listenPort, totalBlocks, totalBytes)
-
- localSourceInfo.hasBlocks = hasBlocks.get
-
- hasBlocksBitVector.synchronized {
- localSourceInfo.hasBlocksBitVector = hasBlocksBitVector
- }
-
- return localSourceInfo
- }
-
- // Add new SourceInfo to the listOfSources. Update if it exists already.
- // Optimizing just by OR-ing the BitVectors was BAD for performance
- private def addToListOfSources(newSourceInfo: SourceInfo) {
- listOfSources.synchronized {
- if (listOfSources.contains(newSourceInfo)) {
- listOfSources = listOfSources - newSourceInfo
- }
- listOfSources += newSourceInfo
- }
- }
-
- private def addToListOfSources(newSourceInfos: ListBuffer[SourceInfo]) {
- newSourceInfos.foreach { newSourceInfo =>
- addToListOfSources(newSourceInfo)
- }
- }
-
- class TalkToGuide(gInfo: SourceInfo)
- extends Thread with Logging {
- override def run() {
-
- // Keep exchaning information until all blocks have been received
- while (hasBlocks.get < totalBlocks) {
- talkOnce
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
- }
-
- // Talk one more time to let the Guide know of reception completion
- talkOnce
- }
-
- // Connect to Guide and send this worker's information
- private def talkOnce {
- var clientSocketToGuide: Socket = null
- var oosGuide: ObjectOutputStream = null
- var oisGuide: ObjectInputStream = null
-
- clientSocketToGuide = new Socket(gInfo.hostAddress, gInfo.listenPort)
- oosGuide = new ObjectOutputStream(clientSocketToGuide.getOutputStream)
- oosGuide.flush()
- oisGuide = new ObjectInputStream(clientSocketToGuide.getInputStream)
-
- // Send local information
- oosGuide.writeObject(getLocalSourceInfo)
- oosGuide.flush()
-
- // Receive source information from Guide
- var suitableSources =
- oisGuide.readObject.asInstanceOf[ListBuffer[SourceInfo]]
- logDebug("Received suitableSources from Driver " + suitableSources)
-
- addToListOfSources(suitableSources)
-
- oisGuide.close()
- oosGuide.close()
- clientSocketToGuide.close()
- }
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Setup initial states of variables
- totalBlocks = gInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- hasBlocksBitVector = new BitSet(totalBlocks)
- numCopiesSent = new Array[Int](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = gInfo.totalBytes
-
- // Start ttGuide to periodically talk to the Guide
- var ttGuide = new TalkToGuide(gInfo)
- ttGuide.setDaemon(true)
- ttGuide.start()
- logInfo("TalkToGuide started...")
-
- // Start pController to run TalkToPeer threads
- var pcController = new PeerChatterController
- pcController.setDaemon(true)
- pcController.start()
- logInfo("PeerChatterController started...")
-
- // FIXME: Must fix this. This might never break if broadcast fails.
- // We should be able to break and send false. Also need to kill threads
- while (hasBlocks.get < totalBlocks) {
- Thread.sleep(MultiTracker.MaxKnockInterval)
- }
-
- return true
- }
-
- class PeerChatterController
- extends Thread with Logging {
- private var peersNowTalking = ListBuffer[SourceInfo]()
- // TODO: There is a possible bug with blocksInRequestBitVector when a
- // certain bit is NOT unset upon failure resulting in an infinite loop.
- private var blocksInRequestBitVector = new BitSet(totalBlocks)
-
- override def run() {
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- while (hasBlocks.get < totalBlocks) {
- var numThreadsToCreate = 0
- listOfSources.synchronized {
- numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
- threadPool.getActiveCount
- }
-
- while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
- var peerToTalkTo = pickPeerToTalkToRandom
-
- if (peerToTalkTo != null)
- logDebug("Peer chosen: " + peerToTalkTo + " with " + peerToTalkTo.hasBlocksBitVector)
- else
- logDebug("No peer chosen...")
-
- if (peerToTalkTo != null) {
- threadPool.execute(new TalkToPeer(peerToTalkTo))
-
- // Add to peersNowTalking. Remove in the thread. We have to do this
- // ASAP, otherwise pickPeerToTalkTo picks the same peer more than once
- peersNowTalking.synchronized { peersNowTalking += peerToTalkTo }
- }
-
- numThreadsToCreate = numThreadsToCreate - 1
- }
-
- // Sleep for a while before starting some more threads
- Thread.sleep(MultiTracker.MinKnockInterval)
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- // Right now picking the one that has the most blocks this peer wants
- // Also picking peer randomly if no one has anything interesting
- private def pickPeerToTalkToRandom: SourceInfo = {
- var curPeer: SourceInfo = null
- var curMax = 0
-
- logDebug("Picking peers to talk to...")
-
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Select the peer that has the most blocks that this receiver does not
- peersNotInUse.foreach { eachSource =>
- var tempHasBlocksBitVector: BitSet = null
- hasBlocksBitVector.synchronized {
- tempHasBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
- tempHasBlocksBitVector.flip(0, tempHasBlocksBitVector.size)
- tempHasBlocksBitVector.and(eachSource.hasBlocksBitVector)
-
- if (tempHasBlocksBitVector.cardinality > curMax) {
- curPeer = eachSource
- curMax = tempHasBlocksBitVector.cardinality
- }
- }
-
- // Always picking randomly
- if (curPeer == null && peersNotInUse.size > 0) {
- // Pick uniformly the i'th required peer
- var i = MultiTracker.ranGen.nextInt(peersNotInUse.size)
-
- var peerIter = peersNotInUse.iterator
- curPeer = peerIter.next
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
- }
-
- return curPeer
- }
-
- // Picking peer with the weight of rare blocks it has
- private def pickPeerToTalkToRarestFirst: SourceInfo = {
- // Find peers that are not connected right now
- var peersNotInUse = ListBuffer[SourceInfo]()
- listOfSources.synchronized {
- peersNowTalking.synchronized {
- peersNotInUse = listOfSources -- peersNowTalking
- }
- }
-
- // Count the number of copies of each block in the neighborhood
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // A block is considered rare if there are at most 2 copies of that block
- // This CONSTANT could be a function of the neighborhood size
- var rareBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) <= 2) {
- rareBlocksIndices += i
- }
- }
-
- // Find peers with rare blocks
- var peersWithRareBlocks = ListBuffer[(SourceInfo, Int)]()
- var totalRareBlocks = 0
-
- peersNotInUse.foreach { eachPeer =>
- var hasRareBlocks = 0
- rareBlocksIndices.foreach { rareBlock =>
- if (eachPeer.hasBlocksBitVector.get(rareBlock)) {
- hasRareBlocks += 1
- }
- }
-
- if (hasRareBlocks > 0) {
- peersWithRareBlocks += ((eachPeer, hasRareBlocks))
- }
- totalRareBlocks += hasRareBlocks
- }
-
- // Select a peer from peersWithRareBlocks based on weight calculated from
- // unique rare blocks
- var selectedPeerToTalkTo: SourceInfo = null
-
- if (peersWithRareBlocks.size > 0) {
- // Sort the peers based on how many rare blocks they have
- peersWithRareBlocks.sortBy(_._2)
-
- var randomNumber = MultiTracker.ranGen.nextDouble
- var tempSum = 0.0
-
- var i = 0
- do {
- tempSum += (1.0 * peersWithRareBlocks(i)._2 / totalRareBlocks)
- if (tempSum >= randomNumber) {
- selectedPeerToTalkTo = peersWithRareBlocks(i)._1
- }
- i += 1
- } while (i < peersWithRareBlocks.size && selectedPeerToTalkTo == null)
- }
-
- if (selectedPeerToTalkTo == null) {
- selectedPeerToTalkTo = pickPeerToTalkToRandom
- }
-
- return selectedPeerToTalkTo
- }
-
- class TalkToPeer(peerToTalkTo: SourceInfo)
- extends Thread with Logging {
- private var peerSocketToSource: Socket = null
- private var oosSource: ObjectOutputStream = null
- private var oisSource: ObjectInputStream = null
-
- override def run() {
- // TODO: There is a possible bug here regarding blocksInRequestBitVector
- var blockToAskFor = -1
-
- // Setup the timeout mechanism
- var timeOutTask = new TimerTask {
- override def run() {
- cleanUpConnections()
- }
- }
-
- var timeOutTimer = new Timer
- timeOutTimer.schedule(timeOutTask, MultiTracker.MaxKnockInterval)
-
- logInfo("TalkToPeer started... => " + peerToTalkTo)
-
- try {
- // Connect to the source
- peerSocketToSource =
- new Socket(peerToTalkTo.hostAddress, peerToTalkTo.listenPort)
- oosSource =
- new ObjectOutputStream(peerSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource =
- new ObjectInputStream(peerSocketToSource.getInputStream)
-
- // Receive latest SourceInfo from peerToTalkTo
- var newPeerToTalkTo = oisSource.readObject.asInstanceOf[SourceInfo]
- // Update listOfSources
- addToListOfSources(newPeerToTalkTo)
-
- // Turn the timer OFF, if the sender responds before timeout
- timeOutTimer.cancel()
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
-
- var keepReceiving = true
-
- while (hasBlocks.get < totalBlocks && keepReceiving) {
- blockToAskFor =
- pickBlockRandom(newPeerToTalkTo.hasBlocksBitVector)
-
- // No block to request
- if (blockToAskFor < 0) {
- // Nothing to receive from newPeerToTalkTo
- keepReceiving = false
- } else {
- // Let other threads know that blockToAskFor is being requested
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor)
- }
-
- // Start with sending the blockID
- oosSource.writeObject(blockToAskFor)
- oosSource.flush()
-
- // CHANGED: Driver might send some other block than the one
- // requested to ensure fast spreading of all blocks.
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + peerToTalkTo + " in " + receptionTime + " millis.")
-
- if (!hasBlocksBitVector.get(bcBlock.blockID)) {
- arrayOfBlocks(bcBlock.blockID) = bcBlock
-
- // Update the hasBlocksBitVector first
- hasBlocksBitVector.synchronized {
- hasBlocksBitVector.set(bcBlock.blockID)
- hasBlocks.getAndIncrement
- }
-
- // Some block(may NOT be blockToAskFor) has arrived.
- // In any case, blockToAskFor is not in request any more
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
-
- // Reset blockToAskFor to -1. Else it will be considered missing
- blockToAskFor = -1
- }
-
- // Send the latest SourceInfo
- oosSource.writeObject(getLocalSourceInfo)
- oosSource.flush()
- }
- }
- } catch {
- // EOFException is expected to happen because sender can break
- // connection due to timeout
- case eofe: java.io.EOFException => { }
- case e: Exception => {
- logError("TalktoPeer had a " + e)
- // FIXME: Remove 'newPeerToTalkTo' from listOfSources
- // We probably should have the following in some form, but not
- // really here. This exception can happen if the sender just breaks connection
- // listOfSources.synchronized {
- // logInfo("Exception in TalkToPeer. Removing source: " + peerToTalkTo)
- // listOfSources = listOfSources - peerToTalkTo
- // }
- }
- } finally {
- // blockToAskFor != -1 => there was an exception
- if (blockToAskFor != -1) {
- blocksInRequestBitVector.synchronized {
- blocksInRequestBitVector.set(blockToAskFor, false)
- }
- }
-
- cleanUpConnections()
- }
- }
-
- // Right now it picks a block uniformly that this peer does not have
- private def pickBlockRandom(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Pick uniformly the i'th required block
- var i = MultiTracker.ranGen.nextInt(needBlocksBitVector.cardinality)
- var pickedBlockIndex = needBlocksBitVector.nextSetBit(0)
-
- while (i > 0) {
- pickedBlockIndex =
- needBlocksBitVector.nextSetBit(pickedBlockIndex + 1)
- i -= 1
- }
-
- return pickedBlockIndex
- }
- }
-
- // Pick the block that seems to be the rarest across sources
- private def pickBlockRarestFirst(txHasBlocksBitVector: BitSet): Int = {
- var needBlocksBitVector: BitSet = null
-
- // Blocks already present
- hasBlocksBitVector.synchronized {
- needBlocksBitVector = hasBlocksBitVector.clone.asInstanceOf[BitSet]
- }
-
- // Include blocks already in transmission ONLY IF
- // MultiTracker.EndGameFraction has NOT been achieved
- if ((1.0 * hasBlocks.get / totalBlocks) < MultiTracker.EndGameFraction) {
- blocksInRequestBitVector.synchronized {
- needBlocksBitVector.or(blocksInRequestBitVector)
- }
- }
-
- // Find blocks that are neither here nor in transit
- needBlocksBitVector.flip(0, needBlocksBitVector.size)
-
- // Blocks that should/can be requested
- needBlocksBitVector.and(txHasBlocksBitVector)
-
- if (needBlocksBitVector.cardinality == 0) {
- return -1
- } else {
- // Count the number of copies for each block across all sources
- var numCopiesPerBlock = Array.tabulate [Int](totalBlocks)(_ => 0)
-
- listOfSources.synchronized {
- listOfSources.foreach { eachSource =>
- for (i <- 0 until totalBlocks) {
- numCopiesPerBlock(i) +=
- ( if (eachSource.hasBlocksBitVector.get(i)) 1 else 0 )
- }
- }
- }
-
- // Find the minimum
- var minVal = Integer.MAX_VALUE
- for (i <- 0 until totalBlocks) {
- if (numCopiesPerBlock(i) > 0 && numCopiesPerBlock(i) < minVal) {
- minVal = numCopiesPerBlock(i)
- }
- }
-
- // Find the blocks with the least copies that this peer does not have
- var minBlocksIndices = ListBuffer[Int]()
- for (i <- 0 until totalBlocks) {
- if (needBlocksBitVector.get(i) && numCopiesPerBlock(i) == minVal) {
- minBlocksIndices += i
- }
- }
-
- // Now select a random index from minBlocksIndices
- if (minBlocksIndices.size == 0) {
- return -1
- } else {
- // Pick uniformly the i'th index
- var i = MultiTracker.ranGen.nextInt(minBlocksIndices.size)
- return minBlocksIndices(i)
- }
- }
- }
-
- private def cleanUpConnections() {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (peerSocketToSource != null) {
- peerSocketToSource.close()
- }
-
- // Delete from peersNowTalking
- peersNowTalking.synchronized { peersNowTalking -= peerToTalkTo }
- }
- }
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close the socket here; else, thread will close it
- case ioe: IOException => {
- clientSocket.close()
- }
- }
- }
- }
-
- // Shutdown the thread pool
- threadPool.shutdown()
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- listOfSources.foreach { sourceInfo =>
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Throw away whatever comes in
- gisSource.readObject.asInstanceOf[SourceInfo]
-
- // Send stopBroadcast signal. listenPort = SourceInfo.StopBroadcast
- gosSource.writeObject(SourceInfo("", SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sourceInfo: SourceInfo = null
- private var selectedSources: ListBuffer[SourceInfo] = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its information
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Select a suitable source and send it back to the worker
- selectedSources = selectSuitableSources(sourceInfo)
- logDebug("Sending selectedSources:" + selectedSources)
- oos.writeObject(selectedSources)
- oos.flush()
-
- // Add this source to the listOfSources
- addToListOfSources(sourceInfo)
- } catch {
- case e: Exception => {
- // Assuming exception caused by receiver failure: remove
- if (listOfSources != null) {
- listOfSources.synchronized { listOfSources -= sourceInfo }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Randomly select some sources to send back
- private def selectSuitableSources(skipSourceInfo: SourceInfo): ListBuffer[SourceInfo] = {
- var selectedSources = ListBuffer[SourceInfo]()
-
- // If skipSourceInfo.hasBlocksBitVector has all bits set to 'true'
- // then add skipSourceInfo to setOfCompletedSources. Return blank.
- if (skipSourceInfo.hasBlocks == totalBlocks) {
- setOfCompletedSources.synchronized { setOfCompletedSources += skipSourceInfo }
- return selectedSources
- }
-
- listOfSources.synchronized {
- if (listOfSources.size <= MultiTracker.MaxPeersInGuideResponse) {
- selectedSources = listOfSources.clone
- } else {
- var picksLeft = MultiTracker.MaxPeersInGuideResponse
- var alreadyPicked = new BitSet(listOfSources.size)
-
- while (picksLeft > 0) {
- var i = -1
-
- do {
- i = MultiTracker.ranGen.nextInt(listOfSources.size)
- } while (alreadyPicked.get(i))
-
- var peerIter = listOfSources.iterator
- var curPeer = peerIter.next
-
- // Set the BitSet before i is decremented
- alreadyPicked.set(i)
-
- while (i > 0) {
- curPeer = peerIter.next
- i = i - 1
- }
-
- selectedSources += curPeer
-
- picksLeft = picksLeft - 1
- }
- }
- }
-
- // Remove the receiving source (if present)
- selectedSources = selectedSources - skipSourceInfo
-
- return selectedSources
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
- // Server at most MultiTracker.MaxChatSlots peers
- var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => { }
- }
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection:" + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- logInfo("new ServeSingleRequest is running")
-
- override def run() {
- try {
- // Send latest local SourceInfo to the receiver
- // In the case of receiver timeout and connection close, this will
- // throw a java.net.SocketException: Broken pipe
- oos.writeObject(getLocalSourceInfo)
- oos.flush()
-
- // Receive latest SourceInfo from the receiver
- var rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- if (rxSourceInfo.listenPort == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- addToListOfSources(rxSourceInfo)
- }
-
- val startTime = System.currentTimeMillis
- var curTime = startTime
- var keepSending = true
- var numBlocksToSend = MultiTracker.MaxChatBlocks
-
- while (!stopBroadcast && keepSending && numBlocksToSend > 0) {
- // Receive which block to send
- var blockToSend = ois.readObject.asInstanceOf[Int]
-
- // If it is driver AND at least one copy of each block has not been
- // sent out already, MODIFY blockToSend
- if (MultiTracker.isDriver && sentBlocks.get < totalBlocks) {
- blockToSend = sentBlocks.getAndIncrement
- }
-
- // Send the block
- sendBlock(blockToSend)
- rxSourceInfo.hasBlocksBitVector.set(blockToSend)
-
- numBlocksToSend -= 1
-
- // Receive latest SourceInfo from the receiver
- rxSourceInfo = ois.readObject.asInstanceOf[SourceInfo]
- logDebug("rxSourceInfo: " + rxSourceInfo + " with " + rxSourceInfo.hasBlocksBitVector)
- addToListOfSources(rxSourceInfo)
-
- curTime = System.currentTimeMillis
- // Revoke sending only if there is anyone waiting in the queue
- if (curTime - startTime >= MultiTracker.MaxChatTime &&
- threadPool.getQueue.size > 0) {
- keepSending = false
- }
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendBlock(blockToSend: Int) {
- try {
- oos.writeObject(arrayOfBlocks(blockToSend))
- oos.flush()
- } catch {
- case e: Exception => logError("sendBlock had a " + e)
- }
- logDebug("Sent block: " + blockToSend + " to " + clientSocket)
- }
- }
- }
-}
-
-private[spark] class BitTorrentBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new BitTorrentBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index a4ceb0d6af..609464e38d 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.{HttpServer, Logging, SparkEnv}
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.storage.{BlockManager, StorageLevel}
-import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashSet}
-
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
- def blockId: String = BlockManager.toBroadcastId(id)
+ def blockId = BroadcastBlockId(id)
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
}
def write(id: Long, value: Any) {
- val file = new File(broadcastDir, "broadcast-" + id)
+ val file = new File(broadcastDir, BroadcastBlockId(id).name)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
- val url = serverUri + "/broadcast-" + id
+ val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
if (compress) {
compressionCodec.compressedInputStream(new URL(url).openStream())
diff --git a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala b/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
deleted file mode 100644
index 21ec94659e..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/MultiTracker.scala
+++ /dev/null
@@ -1,410 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.Random
-
-import scala.collection.mutable.Map
-
-import org.apache.spark._
-import org.apache.spark.util.Utils
-
-private object MultiTracker
-extends Logging {
-
- // Tracker Messages
- val REGISTER_BROADCAST_TRACKER = 0
- val UNREGISTER_BROADCAST_TRACKER = 1
- val FIND_BROADCAST_TRACKER = 2
-
- // Map to keep track of guides of ongoing broadcasts
- var valueToGuideMap = Map[Long, SourceInfo]()
-
- // Random number generator
- var ranGen = new Random
-
- private var initialized = false
- private var _isDriver = false
-
- private var stopBroadcast = false
-
- private var trackMV: TrackMultipleValues = null
-
- def initialize(__isDriver: Boolean) {
- synchronized {
- if (!initialized) {
- _isDriver = __isDriver
-
- if (isDriver) {
- trackMV = new TrackMultipleValues
- trackMV.setDaemon(true)
- trackMV.start()
-
- // Set DriverHostAddress to the driver's IP address for the slaves to read
- System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
- }
-
- initialized = true
- }
- }
- }
-
- def stop() {
- stopBroadcast = true
- }
-
- // Load common parameters
- private var DriverHostAddress_ = System.getProperty(
- "spark.MultiTracker.DriverHostAddress", "")
- private var DriverTrackerPort_ = System.getProperty(
- "spark.broadcast.driverTrackerPort", "11111").toInt
- private var BlockSize_ = System.getProperty(
- "spark.broadcast.blockSize", "4096").toInt * 1024
- private var MaxRetryCount_ = System.getProperty(
- "spark.broadcast.maxRetryCount", "2").toInt
-
- private var TrackerSocketTimeout_ = System.getProperty(
- "spark.broadcast.trackerSocketTimeout", "50000").toInt
- private var ServerSocketTimeout_ = System.getProperty(
- "spark.broadcast.serverSocketTimeout", "10000").toInt
-
- private var MinKnockInterval_ = System.getProperty(
- "spark.broadcast.minKnockInterval", "500").toInt
- private var MaxKnockInterval_ = System.getProperty(
- "spark.broadcast.maxKnockInterval", "999").toInt
-
- // Load TreeBroadcast config params
- private var MaxDegree_ = System.getProperty(
- "spark.broadcast.maxDegree", "2").toInt
-
- // Load BitTorrentBroadcast config params
- private var MaxPeersInGuideResponse_ = System.getProperty(
- "spark.broadcast.maxPeersInGuideResponse", "4").toInt
-
- private var MaxChatSlots_ = System.getProperty(
- "spark.broadcast.maxChatSlots", "4").toInt
- private var MaxChatTime_ = System.getProperty(
- "spark.broadcast.maxChatTime", "500").toInt
- private var MaxChatBlocks_ = System.getProperty(
- "spark.broadcast.maxChatBlocks", "1024").toInt
-
- private var EndGameFraction_ = System.getProperty(
- "spark.broadcast.endGameFraction", "0.95").toDouble
-
- def isDriver = _isDriver
-
- // Common config params
- def DriverHostAddress = DriverHostAddress_
- def DriverTrackerPort = DriverTrackerPort_
- def BlockSize = BlockSize_
- def MaxRetryCount = MaxRetryCount_
-
- def TrackerSocketTimeout = TrackerSocketTimeout_
- def ServerSocketTimeout = ServerSocketTimeout_
-
- def MinKnockInterval = MinKnockInterval_
- def MaxKnockInterval = MaxKnockInterval_
-
- // TreeBroadcast configs
- def MaxDegree = MaxDegree_
-
- // BitTorrentBroadcast configs
- def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
-
- def MaxChatSlots = MaxChatSlots_
- def MaxChatTime = MaxChatTime_
- def MaxChatBlocks = MaxChatBlocks_
-
- def EndGameFraction = EndGameFraction_
-
- class TrackMultipleValues
- extends Thread with Logging {
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(DriverTrackerPort)
- logInfo("TrackMultipleValues started at " + serverSocket)
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(TrackerSocketTimeout)
- clientSocket = serverSocket.accept()
- } catch {
- case e: Exception => {
- if (stopBroadcast) {
- logInfo("Stopping TrackMultipleValues...")
- }
- }
- }
-
- if (clientSocket != null) {
- try {
- threadPool.execute(new Thread {
- override def run() {
- val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- try {
- // First, read message type
- val messageType = ois.readObject.asInstanceOf[Int]
-
- if (messageType == REGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
- // Receive hostAddress and listenPort
- val gInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- // Add to the map
- valueToGuideMap.synchronized {
- valueToGuideMap += (id -> gInfo)
- }
-
- logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- // Remove from the map
- valueToGuideMap.synchronized {
- valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
- }
-
- logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
-
- // Send dummy ACK
- oos.writeObject(-1)
- oos.flush()
- } else if (messageType == FIND_BROADCAST_TRACKER) {
- // Receive Long
- val id = ois.readObject.asInstanceOf[Long]
-
- var gInfo =
- if (valueToGuideMap.contains(id)) valueToGuideMap(id)
- else SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
-
- // Send reply back
- oos.writeObject(gInfo)
- oos.flush()
- } else {
- throw new SparkException("Undefined messageType at TrackMultipleValues")
- }
- } catch {
- case e: Exception => {
- logError("TrackMultipleValues had a " + e)
- }
- } finally {
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
- })
- } catch {
- // In failure, close socket here; else, client thread will close
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- serverSocket.close()
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
- }
-
- def getGuideInfo(variableLong: Long): SourceInfo = {
- var clientSocketToTracker: Socket = null
- var oosTracker: ObjectOutputStream = null
- var oisTracker: ObjectInputStream = null
-
- var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
-
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- try {
- // Connect to the tracker to find out GuideInfo
- clientSocketToTracker =
- new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
- oosTracker =
- new ObjectOutputStream(clientSocketToTracker.getOutputStream)
- oosTracker.flush()
- oisTracker =
- new ObjectInputStream(clientSocketToTracker.getInputStream)
-
- // Send messageType/intention
- oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
- oosTracker.flush()
-
- // Send Long and receive GuideInfo
- oosTracker.writeObject(variableLong)
- oosTracker.flush()
- gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
- } catch {
- case e: Exception => logError("getGuideInfo had a " + e)
- } finally {
- if (oisTracker != null) {
- oisTracker.close()
- }
- if (oosTracker != null) {
- oosTracker.close()
- }
- if (clientSocketToTracker != null) {
- clientSocketToTracker.close()
- }
- }
-
- Thread.sleep(MultiTracker.ranGen.nextInt(
- MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
- MultiTracker.MinKnockInterval)
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
-
- logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
- return gInfo
- }
-
- def registerBroadcast(id: Long, gInfo: SourceInfo) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(REGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Send this tracker's information
- oosST.writeObject(gInfo)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- def unregisterBroadcast(id: Long) {
- val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
- val oosST = new ObjectOutputStream(socket.getOutputStream)
- oosST.flush()
- val oisST = new ObjectInputStream(socket.getInputStream)
-
- // Send messageType/intention
- oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
- oosST.flush()
-
- // Send Long of this broadcast
- oosST.writeObject(id)
- oosST.flush()
-
- // Receive ACK and throw it away
- oisST.readObject.asInstanceOf[Int]
-
- // Shut stuff down
- oisST.close()
- oosST.close()
- socket.close()
- }
-
- // Helper method to convert an object to Array[BroadcastBlock]
- def blockifyObject[IN](obj: IN): VariableInfo = {
- val baos = new ByteArrayOutputStream
- val oos = new ObjectOutputStream(baos)
- oos.writeObject(obj)
- oos.close()
- baos.close()
- val byteArray = baos.toByteArray
- val bais = new ByteArrayInputStream(byteArray)
-
- var blockNum = (byteArray.length / BlockSize)
- if (byteArray.length % BlockSize != 0)
- blockNum += 1
-
- var retVal = new Array[BroadcastBlock](blockNum)
- var blockID = 0
-
- for (i <- 0 until (byteArray.length, BlockSize)) {
- val thisBlockSize = math.min(BlockSize, byteArray.length - i)
- var tempByteArray = new Array[Byte](thisBlockSize)
- val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
-
- retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
- blockID += 1
- }
- bais.close()
-
- var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
- variableInfo.hasBlocks = blockNum
-
- return variableInfo
- }
-
- // Helper method to convert Array[BroadcastBlock] to object
- def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
- totalBytes: Int,
- totalBlocks: Int): OUT = {
-
- var retByteArray = new Array[Byte](totalBytes)
- for (i <- 0 until totalBlocks) {
- System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
- i * BlockSize, arrayOfBlocks(i).byteArray.length)
- }
- byteArrayToObject(retByteArray)
- }
-
- private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
- val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
- override def resolveClass(desc: ObjectStreamClass) =
- Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
- }
- val retVal = in.readObject.asInstanceOf[OUT]
- in.close()
- return retVal
- }
-}
-
-private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
-extends Serializable
-
-private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
- totalBlocks: Int,
- totalBytes: Int)
-extends Serializable {
- @transient var hasBlocks = 0
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala b/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
deleted file mode 100644
index baa1fd6da4..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/SourceInfo.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.util.BitSet
-
-import org.apache.spark._
-
-/**
- * Used to keep and pass around information of peers involved in a broadcast
- */
-private[spark] case class SourceInfo (hostAddress: String,
- listenPort: Int,
- totalBlocks: Int = SourceInfo.UnusedParam,
- totalBytes: Int = SourceInfo.UnusedParam)
-extends Comparable[SourceInfo] with Logging {
-
- var currentLeechers = 0
- var receptionFailed = false
-
- var hasBlocks = 0
- var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
-
- // Ascending sort based on leecher count
- def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
-}
-
-/**
- * Helper Object of SourceInfo for its constants
- */
-private[spark] object SourceInfo {
- // Broadcast has not started yet! Should never happen.
- val TxNotStartedRetry = -1
- // Broadcast has already finished. Try default mechanism.
- val TxOverGoToDefault = -3
- // Other constants
- val StopBroadcast = -2
- val UnusedParam = 0
-}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
new file mode 100644
index 0000000000..073a0a5029
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io._
+
+import scala.math
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+
+private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+extends Broadcast[T](id) with Logging with Serializable {
+
+ def value = value_
+
+ def broadcastId = BroadcastBlockId(id)
+
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+ }
+
+ @transient var arrayOfBlocks: Array[TorrentBlock] = null
+ @transient var totalBlocks = -1
+ @transient var totalBytes = -1
+ @transient var hasBlocks = 0
+
+ if (!isLocal) {
+ sendBroadcast()
+ }
+
+ def sendBroadcast() {
+ var tInfo = TorrentBroadcast.blockifyObject(value_)
+
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ hasBlocks = tInfo.totalBlocks
+
+ // Store meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
+ }
+
+ // Store individual pieces
+ for (i <- 0 until totalBlocks) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
+ }
+ }
+ }
+
+ // Called by JVM when deserializing an object
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(broadcastId) match {
+ case Some(x) =>
+ value_ = x.asInstanceOf[T]
+
+ case None =>
+ val start = System.nanoTime
+ logInfo("Started reading broadcast variable " + id)
+
+ // Initialize @transient variables that will receive garbage values from the master.
+ resetWorkerVariables()
+
+ if (receiveBroadcast(id)) {
+ value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
+
+ // Store the merged copy in cache so that the next worker doesn't need to rebuild it.
+ // This creates a tradeoff between memory usage and latency.
+ // Storing copy doubles the memory footprint; not storing doubles deserialization cost.
+ SparkEnv.get.blockManager.putSingle(
+ broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
+
+ // Remove arrayOfBlocks from memory once value_ is on local cache
+ resetWorkerVariables()
+ } else {
+ logError("Reading broadcast variable " + id + " failed")
+ }
+
+ val time = (System.nanoTime - start) / 1e9
+ logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ }
+ }
+ }
+
+ private def resetWorkerVariables() {
+ arrayOfBlocks = null
+ totalBytes = -1
+ totalBlocks = -1
+ hasBlocks = 0
+ }
+
+ def receiveBroadcast(variableID: Long): Boolean = {
+ // Receive meta-info
+ val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+ var attemptId = 10
+ while (attemptId > 0 && totalBlocks == -1) {
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(metaId) match {
+ case Some(x) =>
+ val tInfo = x.asInstanceOf[TorrentInfo]
+ totalBlocks = tInfo.totalBlocks
+ totalBytes = tInfo.totalBytes
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+ hasBlocks = 0
+
+ case None =>
+ Thread.sleep(500)
+ }
+ }
+ attemptId -= 1
+ }
+ if (totalBlocks == -1) {
+ return false
+ }
+
+ // Receive actual blocks
+ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
+ for (pid <- recvOrder) {
+ val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+ TorrentBroadcast.synchronized {
+ SparkEnv.get.blockManager.getSingle(pieceId) match {
+ case Some(x) =>
+ arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+ hasBlocks += 1
+ SparkEnv.get.blockManager.putSingle(
+ pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
+
+ case None =>
+ throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+ }
+ }
+ }
+
+ (hasBlocks == totalBlocks)
+ }
+
+}
+
+private object TorrentBroadcast
+extends Logging {
+
+ private var initialized = false
+
+ def initialize(_isDriver: Boolean) {
+ synchronized {
+ if (!initialized) {
+ initialized = true
+ }
+ }
+ }
+
+ def stop() {
+ initialized = false
+ }
+
+ val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
+
+ def blockifyObject[T](obj: T): TorrentInfo = {
+ val byteArray = Utils.serialize[T](obj)
+ val bais = new ByteArrayInputStream(byteArray)
+
+ var blockNum = (byteArray.length / BLOCK_SIZE)
+ if (byteArray.length % BLOCK_SIZE != 0)
+ blockNum += 1
+
+ var retVal = new Array[TorrentBlock](blockNum)
+ var blockID = 0
+
+ for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
+ val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
+ var tempByteArray = new Array[Byte](thisBlockSize)
+ val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
+
+ retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
+ blockID += 1
+ }
+ bais.close()
+
+ var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
+ tInfo.hasBlocks = blockNum
+
+ return tInfo
+ }
+
+ def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
+ totalBytes: Int,
+ totalBlocks: Int): T = {
+ var retByteArray = new Array[Byte](totalBytes)
+ for (i <- 0 until totalBlocks) {
+ System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
+ i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
+ }
+ Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
+ }
+
+}
+
+private[spark] case class TorrentBlock(
+ blockID: Int,
+ byteArray: Array[Byte])
+ extends Serializable
+
+private[spark] case class TorrentInfo(
+ @transient arrayOfBlocks : Array[TorrentBlock],
+ totalBlocks: Int,
+ totalBytes: Int)
+ extends Serializable {
+
+ @transient var hasBlocks = 0
+}
+
+private[spark] class TorrentBroadcastFactory
+ extends BroadcastFactory {
+
+ def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
+
+ def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ new TorrentBroadcast[T](value_, isLocal, id)
+
+ def stop() { TorrentBroadcast.stop() }
+}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
deleted file mode 100644
index b664f28e42..0000000000
--- a/core/src/main/scala/org/apache/spark/broadcast/TreeBroadcast.scala
+++ /dev/null
@@ -1,603 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.broadcast
-
-import java.io._
-import java.net._
-import java.util.{Comparator, Random, UUID}
-
-import scala.collection.mutable.{ListBuffer, Map, Set}
-import scala.math
-
-import org.apache.spark._
-import org.apache.spark.storage.{BlockManager, StorageLevel}
-import org.apache.spark.util.Utils
-
-private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
-extends Broadcast[T](id) with Logging with Serializable {
-
- def value = value_
-
- def blockId = BlockManager.toBroadcastId(id)
-
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- }
-
- @transient var arrayOfBlocks: Array[BroadcastBlock] = null
- @transient var totalBytes = -1
- @transient var totalBlocks = -1
- @transient var hasBlocks = 0
-
- @transient var listenPortLock = new Object
- @transient var guidePortLock = new Object
- @transient var totalBlocksLock = new Object
- @transient var hasBlocksLock = new Object
-
- @transient var listOfSources = ListBuffer[SourceInfo]()
-
- @transient var serveMR: ServeMultipleRequests = null
- @transient var guideMR: GuideMultipleRequests = null
-
- @transient var hostAddress = Utils.localIpAddress
- @transient var listenPort = -1
- @transient var guidePort = -1
-
- @transient var stopBroadcast = false
-
- // Must call this after all the variables have been created/initialized
- if (!isLocal) {
- sendBroadcast()
- }
-
- def sendBroadcast() {
- logInfo("Local host address: " + hostAddress)
-
- // Create a variableInfo object and store it in valueInfos
- var variableInfo = MultiTracker.blockifyObject(value_)
-
- // Prepare the value being broadcasted
- arrayOfBlocks = variableInfo.arrayOfBlocks
- totalBytes = variableInfo.totalBytes
- totalBlocks = variableInfo.totalBlocks
- hasBlocks = variableInfo.totalBlocks
-
- guideMR = new GuideMultipleRequests
- guideMR.setDaemon(true)
- guideMR.start()
- logInfo("GuideMultipleRequests started...")
-
- // Must always come AFTER guideMR is created
- while (guidePort == -1) {
- guidePortLock.synchronized { guidePortLock.wait() }
- }
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- // Must always come AFTER serveMR is created
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- // Must always come AFTER listenPort is created
- val masterSource =
- SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
- listOfSources += masterSource
-
- // Register with the Tracker
- MultiTracker.registerBroadcast(id,
- SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
- }
-
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
- MultiTracker.synchronized {
- SparkEnv.get.blockManager.getSingle(blockId) match {
- case Some(x) =>
- value_ = x.asInstanceOf[T]
-
- case None =>
- logInfo("Started reading broadcast variable " + id)
- // Initializing everything because Driver will only send null/0 values
- // Only the 1st worker in a node can be here. Others will get from cache
- initializeWorkerVariables()
-
- logInfo("Local host address: " + hostAddress)
-
- serveMR = new ServeMultipleRequests
- serveMR.setDaemon(true)
- serveMR.start()
- logInfo("ServeMultipleRequests started...")
-
- val start = System.nanoTime
-
- val receptionSucceeded = receiveBroadcast(id)
- if (receptionSucceeded) {
- value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
- SparkEnv.get.blockManager.putSingle(
- blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
- } else {
- logError("Reading broadcast variable " + id + " failed")
- }
-
- val time = (System.nanoTime - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
- }
- }
- }
-
- private def initializeWorkerVariables() {
- arrayOfBlocks = null
- totalBytes = -1
- totalBlocks = -1
- hasBlocks = 0
-
- listenPortLock = new Object
- totalBlocksLock = new Object
- hasBlocksLock = new Object
-
- serveMR = null
-
- hostAddress = Utils.localIpAddress
- listenPort = -1
-
- stopBroadcast = false
- }
-
- def receiveBroadcast(variableID: Long): Boolean = {
- val gInfo = MultiTracker.getGuideInfo(variableID)
-
- if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
- return false
- }
-
- // Wait until hostAddress and listenPort are created by the
- // ServeMultipleRequests thread
- while (listenPort == -1) {
- listenPortLock.synchronized { listenPortLock.wait() }
- }
-
- var clientSocketToDriver: Socket = null
- var oosDriver: ObjectOutputStream = null
- var oisDriver: ObjectInputStream = null
-
- // Connect and receive broadcast from the specified source, retrying the
- // specified number of times in case of failures
- var retriesLeft = MultiTracker.MaxRetryCount
- do {
- // Connect to Driver and send this worker's Information
- clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
- oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
- oosDriver.flush()
- oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
-
- logDebug("Connected to Driver's guiding object")
-
- // Send local source information
- oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
- oosDriver.flush()
-
- // Receive source information from Driver
- var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
- totalBlocks = sourceInfo.totalBlocks
- arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
- totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
- totalBytes = sourceInfo.totalBytes
-
- logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
-
- val start = System.nanoTime
- val receptionSucceeded = receiveSingleTransmission(sourceInfo)
- val time = (System.nanoTime - start) / 1e9
-
- // Updating some statistics in sourceInfo. Driver will be using them later
- if (!receptionSucceeded) {
- sourceInfo.receptionFailed = true
- }
-
- // Send back statistics to the Driver
- oosDriver.writeObject(sourceInfo)
-
- if (oisDriver != null) {
- oisDriver.close()
- }
- if (oosDriver != null) {
- oosDriver.close()
- }
- if (clientSocketToDriver != null) {
- clientSocketToDriver.close()
- }
-
- retriesLeft -= 1
- } while (retriesLeft > 0 && hasBlocks < totalBlocks)
-
- return (hasBlocks == totalBlocks)
- }
-
- /**
- * Tries to receive broadcast from the source and returns Boolean status.
- * This might be called multiple times to retry a defined number of times.
- */
- private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
- var clientSocketToSource: Socket = null
- var oosSource: ObjectOutputStream = null
- var oisSource: ObjectInputStream = null
-
- var receptionSucceeded = false
- try {
- // Connect to the source to get the object itself
- clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
- oosSource.flush()
- oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
-
- logDebug("Inside receiveSingleTransmission")
- logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
-
- // Send the range
- oosSource.writeObject((hasBlocks, totalBlocks))
- oosSource.flush()
-
- for (i <- hasBlocks until totalBlocks) {
- val recvStartTime = System.currentTimeMillis
- val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
- val receptionTime = (System.currentTimeMillis - recvStartTime)
-
- logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
-
- arrayOfBlocks(hasBlocks) = bcBlock
- hasBlocks += 1
-
- // Set to true if at least one block is received
- receptionSucceeded = true
- hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
- }
- } catch {
- case e: Exception => logError("receiveSingleTransmission had a " + e)
- } finally {
- if (oisSource != null) {
- oisSource.close()
- }
- if (oosSource != null) {
- oosSource.close()
- }
- if (clientSocketToSource != null) {
- clientSocketToSource.close()
- }
- }
-
- return receptionSucceeded
- }
-
- class GuideMultipleRequests
- extends Thread with Logging {
- // Keep track of sources that have completed reception
- private var setOfCompletedSources = Set[SourceInfo]()
-
- override def run() {
- var threadPool = Utils.newDaemonCachedThreadPool()
- var serverSocket: ServerSocket = null
-
- serverSocket = new ServerSocket(0)
- guidePort = serverSocket.getLocalPort
- logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
-
- guidePortLock.synchronized { guidePortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => {
- // Stop broadcast if at least one worker has connected and
- // everyone connected so far are done. Comparing with
- // listOfSources.size - 1, because it includes the Guide itself
- listOfSources.synchronized {
- setOfCompletedSources.synchronized {
- if (listOfSources.size > 1 &&
- setOfCompletedSources.size == listOfSources.size - 1) {
- stopBroadcast = true
- logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
- }
- }
- }
- }
- }
- if (clientSocket != null) {
- logDebug("Guide: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new GuideSingleRequest(clientSocket))
- } catch {
- // In failure, close() the socket here; else, the thread will close() it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
-
- logInfo("Sending stopBroadcast notifications...")
- sendStopBroadcastNotifications
-
- MultiTracker.unregisterBroadcast(id)
- } finally {
- if (serverSocket != null) {
- logInfo("GuideMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- private def sendStopBroadcastNotifications() {
- listOfSources.synchronized {
- var listIter = listOfSources.iterator
- while (listIter.hasNext) {
- var sourceInfo = listIter.next
-
- var guideSocketToSource: Socket = null
- var gosSource: ObjectOutputStream = null
- var gisSource: ObjectInputStream = null
-
- try {
- // Connect to the source
- guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
- gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
- gosSource.flush()
- gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
-
- // Send stopBroadcast signal
- gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
- gosSource.flush()
- } catch {
- case e: Exception => {
- logError("sendStopBroadcastNotifications had a " + e)
- }
- } finally {
- if (gisSource != null) {
- gisSource.close()
- }
- if (gosSource != null) {
- gosSource.close()
- }
- if (guideSocketToSource != null) {
- guideSocketToSource.close()
- }
- }
- }
- }
- }
-
- class GuideSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var selectedSourceInfo: SourceInfo = null
- private var thisWorkerInfo:SourceInfo = null
-
- override def run() {
- try {
- logInfo("new GuideSingleRequest is running")
- // Connecting worker is sending in its hostAddress and listenPort it will
- // be listening to. Other fields are invalid (SourceInfo.UnusedParam)
- var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // Select a suitable source and send it back to the worker
- selectedSourceInfo = selectSuitableSource(sourceInfo)
- logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
- oos.writeObject(selectedSourceInfo)
- oos.flush()
-
- // Add this new (if it can finish) source to the list of sources
- thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
- sourceInfo.listenPort, totalBlocks, totalBytes)
- logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
- listOfSources += thisWorkerInfo
- }
-
- // Wait till the whole transfer is done. Then receive and update source
- // statistics in listOfSources
- sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
-
- listOfSources.synchronized {
- // This should work since SourceInfo is a case class
- assert(listOfSources.contains(selectedSourceInfo))
-
- // Remove first
- // (Currently removing a source based on just one failure notification!)
- listOfSources = listOfSources - selectedSourceInfo
-
- // Update sourceInfo and put it back in, IF reception succeeded
- if (!sourceInfo.receptionFailed) {
- // Add thisWorkerInfo to sources that have completed reception
- setOfCompletedSources.synchronized {
- setOfCompletedSources += thisWorkerInfo
- }
-
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
- }
- } catch {
- case e: Exception => {
- // Remove failed worker from listOfSources and update leecherCount of
- // corresponding source worker
- listOfSources.synchronized {
- if (selectedSourceInfo != null) {
- // Remove first
- listOfSources = listOfSources - selectedSourceInfo
- // Update leecher count and put it back in
- selectedSourceInfo.currentLeechers -= 1
- listOfSources += selectedSourceInfo
- }
-
- // Remove thisWorkerInfo
- if (listOfSources != null) {
- listOfSources = listOfSources - thisWorkerInfo
- }
- }
- }
- } finally {
- logInfo("GuideSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- // Assuming the caller to have a synchronized block on listOfSources
- // Select one with the most leechers. This will level-wise fill the tree
- private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
- var maxLeechers = -1
- var selectedSource: SourceInfo = null
-
- listOfSources.foreach { source =>
- if ((source.hostAddress != skipSourceInfo.hostAddress ||
- source.listenPort != skipSourceInfo.listenPort) &&
- source.currentLeechers < MultiTracker.MaxDegree &&
- source.currentLeechers > maxLeechers) {
- selectedSource = source
- maxLeechers = source.currentLeechers
- }
- }
-
- // Update leecher count
- selectedSource.currentLeechers += 1
- return selectedSource
- }
- }
- }
-
- class ServeMultipleRequests
- extends Thread with Logging {
-
- var threadPool = Utils.newDaemonCachedThreadPool()
-
- override def run() {
- var serverSocket = new ServerSocket(0)
- listenPort = serverSocket.getLocalPort
-
- logInfo("ServeMultipleRequests started with " + serverSocket)
-
- listenPortLock.synchronized { listenPortLock.notifyAll() }
-
- try {
- while (!stopBroadcast) {
- var clientSocket: Socket = null
- try {
- serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
- clientSocket = serverSocket.accept
- } catch {
- case e: Exception => { }
- }
-
- if (clientSocket != null) {
- logDebug("Serve: Accepted new client connection: " + clientSocket)
- try {
- threadPool.execute(new ServeSingleRequest(clientSocket))
- } catch {
- // In failure, close socket here; else, the thread will close it
- case ioe: IOException => clientSocket.close()
- }
- }
- }
- } finally {
- if (serverSocket != null) {
- logInfo("ServeMultipleRequests now stopping...")
- serverSocket.close()
- }
- }
- // Shutdown the thread pool
- threadPool.shutdown()
- }
-
- class ServeSingleRequest(val clientSocket: Socket)
- extends Thread with Logging {
- private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
- oos.flush()
- private val ois = new ObjectInputStream(clientSocket.getInputStream)
-
- private var sendFrom = 0
- private var sendUntil = totalBlocks
-
- override def run() {
- try {
- logInfo("new ServeSingleRequest is running")
-
- // Receive range to send
- var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
- sendFrom = rangeToSend._1
- sendUntil = rangeToSend._2
-
- // If not a valid range, stop broadcast
- if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
- stopBroadcast = true
- } else {
- sendObject
- }
- } catch {
- case e: Exception => logError("ServeSingleRequest had a " + e)
- } finally {
- logInfo("ServeSingleRequest is closing streams and sockets")
- ois.close()
- oos.close()
- clientSocket.close()
- }
- }
-
- private def sendObject() {
- // Wait till receiving the SourceInfo from Driver
- while (totalBlocks == -1) {
- totalBlocksLock.synchronized { totalBlocksLock.wait() }
- }
-
- for (i <- sendFrom until sendUntil) {
- while (i == hasBlocks) {
- hasBlocksLock.synchronized { hasBlocksLock.wait() }
- }
- try {
- oos.writeObject(arrayOfBlocks(i))
- oos.flush()
- } catch {
- case e: Exception => logError("sendObject had a " + e)
- }
- logDebug("Sent block: " + i + " to " + clientSocket)
- }
- }
- }
- }
-}
-
-private[spark] class TreeBroadcastFactory
-extends BroadcastFactory {
- def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
-
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
- new TreeBroadcast[T](value_, isLocal, id)
-
- def stop() { MultiTracker.stop() }
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 993ba6bd3d..83cd3df5fa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,26 +17,31 @@
package org.apache.spark.deploy
-import com.google.common.collect.MapMaker
-
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
+import com.google.common.collect.MapMaker
+
/**
- * Contains util methods to interact with Hadoop from spark.
+ * Contains util methods to interact with Hadoop from Spark.
*/
+private[spark]
class SparkHadoopUtil {
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
- // Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
- // subsystems
+ /**
+ * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
+ * subsystems.
+ */
def newConfiguration(): Configuration = new Configuration()
- // Add any user credentials to the job conf which are necessary for running on a secure Hadoop
- // cluster
+ /**
+ * Add any user credentials to the job conf which are necessary for running on a secure Hadoop
+ * cluster.
+ */
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
diff --git a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 7839023868..52b1c492b2 100644
--- a/core/src/main/scala/org/apache/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -24,11 +24,11 @@ import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClie
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
-private[spark] class StandaloneExecutorBackend(
+private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
hostPort: String,
@@ -63,12 +63,20 @@ private[spark] class StandaloneExecutorBackend(
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
if (executor == null) {
- logError("Received launchTask but executor was null")
+ logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}
+ case KillTask(taskId, _) =>
+ if (executor == null) {
+ logError("Received KillTask command but executor was null")
+ System.exit(1)
+ } else {
+ executor.killTask(taskId)
+ }
+
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
@@ -79,7 +87,7 @@ private[spark] class StandaloneExecutorBackend(
}
}
-private[spark] object StandaloneExecutorBackend {
+private[spark] object CoarseGrainedExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Debug code
Utils.checkHost(hostname)
@@ -91,7 +99,7 @@ private[spark] object StandaloneExecutorBackend {
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
val actor = actorSystem.actorOf(
- Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
+ Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
@@ -99,7 +107,9 @@ private[spark] object StandaloneExecutorBackend {
def main(args: Array[String]) {
if (args.length < 4) {
//the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
- System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
+ System.err.println(
+ "Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
+ "[<appid>]")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index acdb8d0343..b773346df3 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable.HashMap
import org.apache.spark.scheduler._
import org.apache.spark._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
/**
@@ -36,7 +36,8 @@ import org.apache.spark.util.Utils
private[spark] class Executor(
executorId: String,
slaveHostname: String,
- properties: Seq[(String, String)])
+ properties: Seq[(String, String)],
+ isLocal: Boolean = false)
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -73,46 +74,77 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
- // Make any thread terminations due to uncaught exceptions kill the entire
- // executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(
- new Thread.UncaughtExceptionHandler {
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ if (!isLocal) {
+ // Setup an uncaught exception handler for non-local mode.
+ // Make any thread terminations due to uncaught exceptions kill the entire
+ // executor process to avoid surprising stalls.
+ Thread.setDefaultUncaughtExceptionHandler(
+ new Thread.UncaughtExceptionHandler {
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(ExecutorExitCode.OOM)
+ } else {
+ System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
+ }
}
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
- }
- )
+ )
+ }
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
- SparkEnv.set(env)
- env.metricsSystem.registerSource(executorSource)
+ private val env = {
+ if (!isLocal) {
+ val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
+ isDriver = false, isLocal = false)
+ SparkEnv.set(_env)
+ _env.metricsSystem.registerSource(executorSource)
+ _env
+ } else {
+ SparkEnv.get
+ }
+ }
- private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+ // Akka's message frame size. If task result is bigger than this, we use the block manager
+ // to send the result back.
+ private val akkaFrameSize = {
+ env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+ }
// Start worker thread pool
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
+ val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
+
+ // Maintains the list of running tasks.
+ private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
- threadPool.execute(new TaskRunner(context, taskId, serializedTask))
+ val tr = new TaskRunner(context, taskId, serializedTask)
+ runningTasks.put(taskId, tr)
+ threadPool.execute(tr)
+ }
+
+ def killTask(taskId: Long) {
+ val tr = runningTasks.get(taskId)
+ if (tr != null) {
+ tr.kill()
+ // We remove the task also in the finally block in TaskRunner.run.
+ // The reason we need to remove it here is because killTask might be called before the task
+ // is even launched, and never reaching that finally block. ConcurrentHashMap's remove is
+ // idempotent.
+ runningTasks.remove(taskId)
+ }
}
/** Get the Yarn approved local directories. */
@@ -124,56 +156,87 @@ private[spark] class Executor(
.getOrElse(Option(System.getenv("LOCAL_DIRS"))
.getOrElse(""))
- if (localDirs.isEmpty()) {
+ if (localDirs.isEmpty) {
throw new Exception("Yarn Local dirs can't be empty")
}
- return localDirs
+ localDirs
}
- class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
+ class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
+ @volatile private var killed = false
+ @volatile private var task: Task[Any] = _
+
+ def kill() {
+ logInfo("Executor is trying to kill task " + taskId)
+ killed = true
+ if (task != null) {
+ task.kill()
+ }
+ }
+
override def run() {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
- context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
+ execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var attemptedTask: Option[Task[Any]] = None
var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
+ def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+ val startGCTime = gcTime
try {
SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
- val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+ task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
+
+ // If this task has been killed before we deserialized it, let's quit now. Otherwise,
+ // continue executing the task.
+ if (killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
attemptedTask = Some(task)
- logInfo("Its epoch is " + task.epoch)
+ logDebug("Task " + taskId +"'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
+
+ // Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
+
+ // If the task has been killed, let's fail it.
+ if (task.killed) {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ return
+ }
+
for (m <- task.metrics) {
- m.hostname = Utils.localHostName
+ m.hostname = Utils.localHostName()
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
- //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
- // we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
- // just change the relevants bytes in the byte buffer
+ // TODO I'd also like to track the time it takes to serialize the task results, but that is
+ // huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
+ // custom serialized format, we could just change the relevants bytes in the byte buffer
val accumUpdates = Accumulators.values
+
val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
logInfo("Storing result for " + taskId + " in local BlockManager")
- val blockId = "taskresult_" + taskId
+ val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
ser.serialize(new IndirectTaskResult[Any](blockId))
@@ -182,12 +245,13 @@ private[spark] class Executor(
serializedDirectResult
}
}
- context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
+
+ execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
case t: Throwable => {
@@ -195,10 +259,10 @@ private[spark] class Executor(
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.executorRunTime = serviceTime
- m.jvmGCTime = getTotalGCTime - startGCTime
+ m.jvmGCTime = gcTime - startGCTime
}
val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
// have left some weird state around depending on when the exception was thrown, but on
@@ -206,6 +270,8 @@ private[spark] class Executor(
logError("Exception in task ID " + taskId, t)
//System.exit(1)
}
+ } finally {
+ runningTasks.remove(taskId)
}
}
}
@@ -215,7 +281,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
- var loader = this.getClass.getClassLoader
+ val loader = this.getClass.getClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@@ -237,7 +303,7 @@ private[spark] class Executor(
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
- return constructor.newInstance(classUri, parent)
+ constructor.newInstance(classUri, parent)
} catch {
case _: ClassNotFoundException =>
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@@ -245,7 +311,7 @@ private[spark] class Executor(
null
}
} else {
- return parent
+ parent
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index da62091980..b56d8c9912 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -18,14 +18,18 @@
package org.apache.spark.executor
import java.nio.ByteBuffer
-import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
-import org.apache.spark.TaskState.TaskState
+
import com.google.protobuf.ByteString
-import org.apache.spark.{Logging}
+
+import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
+import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
+
+import org.apache.spark.Logging
import org.apache.spark.TaskState
+import org.apache.spark.TaskState.TaskState
import org.apache.spark.util.Utils
+
private[spark] class MesosExecutorBackend
extends MesosExecutor
with ExecutorBackend
@@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
}
override def killTask(d: ExecutorDriver, t: TaskID) {
- logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+ if (executor == null) {
+ logError("Received KillTask but executor was null")
+ } else {
+ executor.killTask(t.getValue.toLong)
+ }
}
override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index f311141148..0b4892f98f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable {
* Number of bytes written for a shuffle
*/
var shuffleBytesWritten: Long = _
+
+ /**
+ * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ */
+ var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index e15a839c4e..9c2fee4023 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
- implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
+ implicit val futureExecContext = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
index 3c29700920..1b9fa1e53a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala
@@ -20,17 +20,18 @@ package org.apache.spark.network.netty
import io.netty.buffer._
import org.apache.spark.Logging
+import org.apache.spark.storage.{TestBlockId, BlockId}
private[spark] class FileHeader (
val fileLen: Int,
- val blockId: String) extends Logging {
+ val blockId: BlockId) extends Logging {
lazy val buffer = {
val buf = Unpooled.buffer()
buf.capacity(FileHeader.HEADER_SIZE)
buf.writeInt(fileLen)
- buf.writeInt(blockId.length)
- blockId.foreach((x: Char) => buf.writeByte(x))
+ buf.writeInt(blockId.name.length)
+ blockId.name.foreach((x: Char) => buf.writeByte(x))
//padding the rest of header
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
@@ -57,18 +58,15 @@ private[spark] object FileHeader {
for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char]
}
- val blockId = idBuilder.toString()
+ val blockId = BlockId(idBuilder.toString())
new FileHeader(length, blockId)
}
-
- def main (args:Array[String]){
-
- val header = new FileHeader(25,"block_0");
- val buf = header.buffer;
- val newheader = FileHeader.create(buf);
- System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
-
+ def main (args:Array[String]) {
+ val header = new FileHeader(25, TestBlockId("my_block"))
+ val buf = header.buffer
+ val newHeader = FileHeader.create(buf)
+ System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index 9493ccffd9..481ff8c3e0 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -27,12 +27,13 @@ import org.apache.spark.Logging
import org.apache.spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
+import org.apache.spark.storage.BlockId
private[spark] class ShuffleCopier extends Logging {
- def getBlock(host: String, port: Int, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(host: String, port: Int, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
@@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
try {
fc.init()
fc.connect(host, port)
- fc.sendRequest(blockId)
+ fc.sendRequest(blockId.name)
fc.waitForClose()
fc.close()
} catch {
@@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
}
}
- def getBlock(cmId: ConnectionManagerId, blockId: String,
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
- blocks: Seq[(String, Long)],
- resultCollectCallback: (String, Long, ByteBuf) => Unit) {
+ blocks: Seq[(BlockId, Long)],
+ resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
for ((blockId, size) <- blocks) {
getBlock(cmId, blockId, resultCollectCallback)
@@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging {
private[spark] object ShuffleCopier extends Logging {
- private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
+ private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
@@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging {
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
- override def handleError(blockId: String) {
+ override def handleError(blockId: BlockId) {
if (!isComplete) {
resultCollectCallBack(blockId, -1, null)
}
}
}
- def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
+ def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
if (size != -1) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
@@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
}
val host = args(0)
val port = args(1).toInt
- val file = args(2)
+ val blockId = BlockId(args(2))
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
@@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging {
Executors.callable(new Runnable() {
def run() {
val copier = new ShuffleCopier()
- copier.getBlock(host, port, file, echoResultCollectCallBack)
+ copier.getBlock(host, port, blockId, echoResultCollectCallBack)
}
})
}).asJava
copiers.invokeAll(tasks)
- copiers.shutdown
+ copiers.shutdown()
System.exit(0)
}
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
index 0c5ded3145..546d921067 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleSender.scala
@@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
-import org.apache.spark.storage.ShuffleBlockManager
+import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@@ -54,8 +54,8 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!ShuffleBlockManager.isShuffle(blockId)) {
+ override def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
// Figure out which local directory it hashes to, and which subdirectory in that
@@ -63,8 +63,8 @@ private[spark] object ShuffleSender {
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
- val file = new File(subDir, blockId)
- return file.getAbsolutePath
+ val file = new File(subDir, blockId.name)
+ return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index f132e2b735..70a5a8caff 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+package org.apache
+
/**
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
new file mode 100644
index 0000000000..faaf837be0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+
+/**
+ * A set of asynchronous RDD actions available through an implicit conversion.
+ * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
+ */
+class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {
+
+ /**
+ * Returns a future for counting the number of elements in the RDD.
+ */
+ def countAsync(): FutureAction[Long] = {
+ val totalCount = new AtomicLong
+ self.context.submitJob(
+ self,
+ (iter: Iterator[T]) => {
+ var result = 0L
+ while (iter.hasNext) {
+ result += 1L
+ iter.next()
+ }
+ result
+ },
+ Range(0, self.partitions.size),
+ (index: Int, data: Long) => totalCount.addAndGet(data),
+ totalCount.get())
+ }
+
+ /**
+ * Returns a future for retrieving all elements of this RDD.
+ */
+ def collectAsync(): FutureAction[Seq[T]] = {
+ val results = new Array[Array[T]](self.partitions.size)
+ self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
+ (index, data) => results(index) = data, results.flatten.toSeq)
+ }
+
+ /**
+ * Returns a future for retrieving the first num elements of the RDD.
+ */
+ def takeAsync(num: Int): FutureAction[Seq[T]] = {
+ val f = new ComplexFutureAction[Seq[T]]
+
+ f.run {
+ val results = new ArrayBuffer[T](num)
+ val totalParts = self.partitions.length
+ var partsScanned = 0
+ while (results.size < num && partsScanned < totalParts) {
+ // The number of partitions to try in this iteration. It is ok for this number to be
+ // greater than totalParts because we actually cap it at totalParts in runJob.
+ var numPartsToTry = 1
+ if (partsScanned > 0) {
+ // If we didn't find any rows after the first iteration, just try all partitions next.
+ // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+ // by 50%.
+ if (results.size == 0) {
+ numPartsToTry = totalParts - 1
+ } else {
+ numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
+ }
+ }
+ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
+
+ val left = num - results.size
+ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+
+ val buf = new Array[Array[T]](p.size)
+ f.runJob(self,
+ (it: Iterator[T]) => it.take(left).toArray,
+ p,
+ (index: Int, data: Array[T]) => buf(index) = data,
+ Unit)
+
+ buf.foreach(results ++= _.take(num - results.size))
+ partsScanned += numPartsToTry
+ }
+ results.toSeq
+ }
+
+ f
+ }
+
+ /**
+ * Applies a function f to all elements of this RDD.
+ */
+ def foreachAsync(f: T => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+
+ /**
+ * Applies a function f to each partition of this RDD.
+ */
+ def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
+ self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
+ (index, data) => Unit, Unit)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index bca6956a18..44ea573a7c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -18,14 +18,14 @@
package org.apache.spark.rdd
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
-import org.apache.spark.storage.BlockManager
+import org.apache.spark.storage.{BlockId, BlockManager}
-private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
+private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
}
private[spark]
-class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
+class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 3311757189..ccaaecb85b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -85,7 +85,7 @@ private[spark] object CheckpointRDD extends Logging {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
- val finalOutputName = splitIdToFile(ctx.splitId)
+ val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index d237797aa6..911a002884 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -21,7 +21,7 @@ import java.io.{ObjectOutputStream, IOException}
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.util.AppendOnlyMap
@@ -125,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
- fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+ fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
kv => getSeq(kv._1)(depNum) += kv._2
}
}
}
- map.iterator
+ new InterruptibleIterator(context, map.iterator)
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index d3b3fffd40..fad042c7ae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -27,54 +27,18 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
-import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
- TaskContext}
+import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
-/**
- * An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file
- * system, or S3).
- * This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same
- * across multiple reads; the 'path' is the only variable that is different across new JobConfs
- * created from the Configuration.
- */
-class HadoopFileRDD[K, V](
- sc: SparkContext,
- path: String,
- broadcastedConf: Broadcast[SerializableWritable[Configuration]],
- inputFormatClass: Class[_ <: InputFormat[K, V]],
- keyClass: Class[K],
- valueClass: Class[V],
- minSplits: Int)
- extends HadoopRDD[K, V](sc, broadcastedConf, inputFormatClass, keyClass, valueClass, minSplits) {
-
- override def getJobConf(): JobConf = {
- if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
- // getJobConf() has been called previously, so there is already a local cache of the JobConf
- // needed by this RDD.
- return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
- } else {
- // Create a new JobConf, set the input file/directory paths to read from, and cache the
- // JobConf (i.e., in a shared hash map in the slave's JVM process that's accessible through
- // HadoopRDD.putCachedMetadata()), so that we only create one copy across multiple
- // getJobConf() calls for this RDD in the local process.
- // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
- val newJobConf = new JobConf(broadcastedConf.value.value)
- FileInputFormat.setInputPaths(newJobConf, path)
- HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
- return newJobConf
- }
- }
-}
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
extends Partition {
-
+
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@@ -83,11 +47,24 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
}
/**
- * A base class that provides core functionality for reading data partitions stored in Hadoop.
+ * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
+ * sources in HBase, or S3).
+ *
+ * @param sc The SparkContext to associate the RDD with.
+ * @param broadCastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
+ * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job.
+ * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
+ * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
+ * creates.
+ * @param inputFormatClass Storage format of the data to be read.
+ * @param keyClass Class of the key associated with the inputFormatClass.
+ * @param valueClass Class of the value associated with the inputFormatClass.
+ * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
*/
class HadoopRDD[K, V](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
+ initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
@@ -105,6 +82,7 @@ class HadoopRDD[K, V](
sc,
sc.broadcast(new SerializableWritable(conf))
.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
+ None /* initLocalJobConfFuncOpt */,
inputFormatClass,
keyClass,
valueClass,
@@ -130,6 +108,7 @@ class HadoopRDD[K, V](
// local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
// The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
val newJobConf = new JobConf(broadcastedConf.value.value)
+ initLocalJobConfFuncOpt.map(f => f(newJobConf))
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
return newJobConf
}
@@ -164,38 +143,41 @@ class HadoopRDD[K, V](
array
}
- override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
- val split = theSplit.asInstanceOf[HadoopPartition]
- logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
-
- val jobConf = getJobConf()
- val inputFormat = getInputFormat(jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => closeIfNeeded() }
-
- val key: K = reader.createKey()
- val value: V = reader.createValue()
-
- override def getNext() = {
- try {
- finished = !reader.next(key, value)
- } catch {
- case eof: EOFException =>
- finished = true
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new NextIterator[(K, V)] {
+ val split = theSplit.asInstanceOf[HadoopPartition]
+ logInfo("Input split: " + split.inputSplit)
+ var reader: RecordReader[K, V] = null
+
+ val jobConf = getJobConf()
+ val inputFormat = getInputFormat(jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback{ () => closeIfNeeded() }
+
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
+ override def getNext() = {
+ try {
+ finished = !reader.next(key, value)
+ } catch {
+ case eof: EOFException =>
+ finished = true
+ }
+ (key, value)
}
- (key, value)
- }
- override def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ override def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator[(K, V)](context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
index 3ed8339010..aea08ff81b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
@@ -21,14 +21,14 @@ import org.apache.spark.{Partition, TaskContext}
/**
- * A variant of the MapPartitionsRDD that passes the partition index into the
- * closure. This can be used to generate or collect partition specific
- * information such as the number of tuples in a partition.
+ * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
+ * TaskContext, the closure can either get access to the interruptible flag or get the index
+ * of the partition in the RDD.
*/
private[spark]
-class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
+class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
- f: (Int, Iterator[T]) => Iterator[U],
+ f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
@@ -37,5 +37,5 @@ class MapPartitionsWithIndexRDD[U: ClassManifest, T: ClassManifest](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def compute(split: Partition, context: TaskContext) =
- f(split.index, firstParent[T].iterator(split, context))
+ f(context, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7b3a89f7e0..2662d48c84 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
private[spark]
@@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
result
}
- override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
- val split = theSplit.asInstanceOf[NewHadoopPartition]
- logInfo("Input split: " + split.serializableHadoopSplit)
- val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- if (format.isInstanceOf[Configurable]) {
- format.asInstanceOf[Configurable].setConf(conf)
- }
- val reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-
- // Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => close())
-
- var havePair = false
- var finished = false
-
- override def hasNext: Boolean = {
- if (!finished && !havePair) {
- finished = !reader.nextKeyValue
- havePair = !finished
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ val iter = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopPartition]
+ logInfo("Input split: " + split.serializableHadoopSplit)
+ val conf = confBroadcast.value.value
+ val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ if (format.isInstanceOf[Configurable]) {
+ format.asInstanceOf[Configurable].setConf(conf)
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addOnCompleteCallback(() => close())
+
+ var havePair = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !havePair) {
+ finished = !reader.nextKeyValue
+ havePair = !finished
+ }
+ !finished
}
- !finished
- }
- override def next: (K, V) = {
- if (!hasNext) {
- throw new java.util.NoSuchElementException("End of stream")
+ override def next(): (K, V) = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ havePair = false
+ (reader.getCurrentKey, reader.getCurrentValue)
}
- havePair = false
- return (reader.getCurrentKey, reader.getCurrentValue)
- }
- private def close() {
- try {
- reader.close()
- } catch {
- case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ private def close() {
+ try {
+ reader.close()
+ } catch {
+ case e: Exception => logWarning("Exception in RecordReader.close()", e)
+ }
}
}
+ new InterruptibleIterator(context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index a47c512275..93b78e1232 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -84,18 +84,24 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
- self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ self.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
- partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+ partitioned.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
+ }, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
- values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+ values.mapPartitionsWithContext((context, iter) => {
+ new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
+ }, preservesPartitioning = true)
}
}
@@ -564,7 +570,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@@ -663,7 +669,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- writer.setup(context.stageId, context.splitId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
var count = 0
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 6dbd4309aa..cd96250389 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -94,8 +94,9 @@ private[spark] class ParallelCollectionRDD[T: ClassManifest](
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
- override def compute(s: Partition, context: TaskContext) =
- s.asInstanceOf[ParallelCollectionPartition[T]].iterator
+ override def compute(s: Partition, context: TaskContext) = {
+ new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
+ }
override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1893627ee2..0355618e43 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -418,26 +418,39 @@ abstract class RDD[T: ClassManifest](
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
- printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
+ printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
+ def mapPartitions[U: ClassManifest](
+ f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[U: ClassManifest](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
+ new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ }
+
+ /**
+ * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
+ * mapPartitions that also passes the TaskContext into the closure.
+ */
+ def mapPartitionsWithContext[U: ClassManifest](
+ f: (TaskContext, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false): RDD[U] = {
+ new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ }
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
@@ -445,22 +458,23 @@ abstract class RDD[T: ClassManifest](
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassManifest](
- f: (Int, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] =
- new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
+ f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ mapPartitionsWithIndex(f, preservesPartitioning)
+ }
/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => U): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.map(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def mapWith[A: ClassManifest, U: ClassManifest]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => U): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -468,13 +482,14 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
- (f:(T, A) => Seq[U]): RDD[U] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(index)
- iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+ def flatMapWith[A: ClassManifest, U: ClassManifest]
+ (constructA: Int => A, preservesPartitioning: Boolean = false)
+ (f: (T, A) => Seq[U]): RDD[U] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
+ val a = constructA(context.partitionId)
+ iter.flatMap(t => f(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@@ -482,13 +497,12 @@ abstract class RDD[T: ClassManifest](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassManifest](constructA: Int => A)
- (f:(T, A) => Unit) {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.map(t => {f(t, a); t})
- }
- (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+ def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.map(t => {f(t, a); t})
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}
/**
@@ -496,13 +510,12 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassManifest](constructA: Int => A)
- (p:(T, A) => Boolean): RDD[T] = {
- def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(index)
- iter.filter(t => p(t, a))
- }
- new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
+ def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
+ val a = constructA(context.partitionId)
+ iter.filter(t => p(t, a))
+ }
+ new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}
/**
@@ -541,16 +554,14 @@ abstract class RDD[T: ClassManifest](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- val cleanF = sc.clean(f)
- sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+ sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
@@ -675,6 +686,8 @@ abstract class RDD[T: ClassManifest](
*/
def count(): Long = {
sc.runJob(this, (iter: Iterator[T]) => {
+ // Use a while loop to count the number of elements rather than iter.size because
+ // iter.size uses a for loop, which is slightly slower in current version of Scala.
var result = 0L
while (iter.hasNext) {
result += 1L
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 9537152335..a5d751a7bd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -56,7 +56,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
+ SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
SparkEnv.get.serializerManager.get(serializerClass))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 8c1a29dfff..7af4d803e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -108,7 +108,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
}
case ShuffleCoGroupSplitDep(shuffleId) => {
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context.taskMetrics, serializer)
+ context, serializer)
iter.foreach(op)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4226617cfb..d84f5968df 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -28,8 +28,8 @@ import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
-import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
-import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -41,11 +41,11 @@ import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedH
* locations to run each task on, based on the current cache status, and passes these to the
* low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
* lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
- * not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task
+ * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
* THREADING: This class runs all its logic in a single thread executing the run() method, to which
- * events are submitted using a synchonized queue (eventQueue). The public API methods, such as
+ * events are submitted using a synchronized queue (eventQueue). The public API methods, such as
* runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
* should be private.
*/
@@ -55,20 +55,20 @@ class DAGScheduler(
mapOutputTracker: MapOutputTracker,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
- extends TaskSchedulerListener with Logging {
+ extends Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
}
- taskSched.setListener(this)
+ taskSched.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
- override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(BeginEvent(task, taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
- override def taskEnded(
+ def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -79,17 +79,18 @@ class DAGScheduler(
}
// Called by TaskScheduler when an executor fails.
- override def executorLost(execId: String) {
+ def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
- override def executorGained(execId: String, host: String) {
+ def executorGained(execId: String, host: String) {
eventQueue.put(ExecutorGained(execId, host))
}
- // Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
- override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
+ // cancellation of the job itself.
+ def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
}
@@ -104,13 +105,15 @@ class DAGScheduler(
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
- val nextJobId = new AtomicInteger(0)
+ private[scheduler] val nextJobId = new AtomicInteger(0)
- val nextStageId = new AtomicInteger(0)
+ def numTotalJobs: Int = nextJobId.get()
- val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+ private val nextStageId = new AtomicInteger(0)
- val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+ private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@@ -127,6 +130,7 @@ class DAGScheduler(
// stray messages to detect.
val failedEpoch = new HashMap[String, Long]
+ // stage id to the active job
val idToActiveJob = new HashMap[Int, ActiveJob]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
@@ -156,7 +160,7 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
if (!cacheLocs.contains(rdd.id)) {
- val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
+ val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
cacheLocs(rdd.id) = blockIds.map { id =>
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
@@ -261,32 +265,41 @@ class DAGScheduler(
}
/**
- * Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
- * JobWaiter whose getResult() method will return the result of the job when it is complete.
- *
- * The job is assumed to have at least one partition; zero partition jobs should be handled
- * without a JobSubmitted event.
+ * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
+ * can be used to block until the the job finishes executing or can be used to cancel the job.
*/
- private[scheduler] def prepareJob[T, U: ClassManifest](
- finalRdd: RDD[T],
+ def submitJob[T, U](
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
- properties: Properties = null)
- : (JobSubmitted, JobWaiter[U]) =
+ properties: Properties = null): JobWaiter[U] =
{
+ // Check to make sure we are not launching a task on a partition that does not exist.
+ val maxPartitions = rdd.partitions.length
+ partitions.find(p => p >= maxPartitions).foreach { p =>
+ throw new IllegalArgumentException(
+ "Attempting to access a non-existent partition: " + p + ". " +
+ "Total number of partitions: " + maxPartitions)
+ }
+
+ val jobId = nextJobId.getAndIncrement()
+ if (partitions.size == 0) {
+ return new JobWaiter[U](this, jobId, 0, resultHandler)
+ }
+
assert(partitions.size > 0)
- val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
- properties)
- (toSubmit, waiter)
+ val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
+ waiter, properties))
+ waiter
}
def runJob[T, U: ClassManifest](
- finalRdd: RDD[T],
+ rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
@@ -294,21 +307,7 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
- if (partitions.size == 0) {
- return
- }
-
- // Check to make sure we are not launching a task on a partition that does not exist.
- val maxPartitions = finalRdd.partitions.length
- partitions.find(p => p >= maxPartitions).foreach { p =>
- throw new IllegalArgumentException(
- "Attempting to access a non-existent partition: " + p + ". " +
- "Total number of partitions: " + maxPartitions)
- }
-
- val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
- finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
- eventQueue.put(toSubmit)
+ val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception, _) =>
@@ -329,19 +328,40 @@ class DAGScheduler(
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
+ val jobId = nextJobId.getAndIncrement()
+ eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
+ listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
/**
+ * Cancel a job that is running or waiting in the queue.
+ */
+ def cancelJob(jobId: Int) {
+ logInfo("Asked to cancel job " + jobId)
+ eventQueue.put(JobCancelled(jobId))
+ }
+
+ def cancelJobGroup(groupId: String) {
+ logInfo("Asked to cancel job group " + groupId)
+ eventQueue.put(JobGroupCancelled(groupId))
+ }
+
+ /**
+ * Cancel all jobs that are running or waiting in the queue.
+ */
+ def cancelAllJobs() {
+ eventQueue.put(AllJobsCancelled)
+ }
+
+ /**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
- val jobId = nextJobId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
+ case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
+ val finalStage = newStage(rdd, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -360,6 +380,29 @@ class DAGScheduler(
submitStage(finalStage)
}
+ case JobCancelled(jobId) =>
+ // Cancel a job: find all the running stages that are linked to this job, and cancel them.
+ running.filter(_.jobId == jobId).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
+ case JobGroupCancelled(groupId) =>
+ // Cancel all jobs belonging to this job group.
+ // First finds all active jobs with this group id, and then kill stages for them.
+ val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ .map(_.jobId)
+ if (!jobIds.isEmpty) {
+ running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+ }
+
+ case AllJobsCancelled =>
+ // Cancel all running jobs.
+ running.foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
@@ -578,6 +621,11 @@ class DAGScheduler(
*/
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
+
+ if (!stageIdToStage.contains(task.stageId)) {
+ // Skip all the actions if the stage has been cancelled.
+ return
+ }
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
@@ -626,7 +674,7 @@ class DAGScheduler(
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
- stage.addOutputLoc(smt.partition, status)
+ stage.addOutputLoc(smt.partitionId, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
markStageAsFinished(stage)
@@ -752,14 +800,14 @@ class DAGScheduler(
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
- * being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
+ * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
failedStage.completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
- val error = new SparkException("Job failed: " + reason)
+ val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
idToActiveJob -= resultStage.jobId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 10ff1b4376..a5769c6041 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
-private[spark] sealed trait DAGSchedulerEvent
+private[scheduler] sealed trait DAGSchedulerEvent
-private[spark] case class JobSubmitted(
+private[scheduler] case class JobSubmitted(
+ jobId: Int,
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
@@ -43,9 +44,16 @@ private[spark] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
-private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
-private[spark] case class CompletionEvent(
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
+private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
+
+private[scheduler]
+case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
+
+private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
@@ -54,10 +62,12 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler]
+case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
-private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
-private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler]
+case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
-private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
+private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
index 151514896f..7b5c0e29ad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala
@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
})
metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
- override def getValue: Int = dagScheduler.nextJobId.get()
+ override def getValue: Int = dagScheduler.numTotalJobs
})
metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 3628b1b078..19c0251690 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -24,56 +24,54 @@ import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.{Map, HashMap, ListBuffer}
-import scala.io.Source
+import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
-// Used to record runtime information for each job, including RDD graph
-// tasks' start/stop shuffle information and information from outside
-
+/**
+ * A logger class to record runtime information for jobs in Spark. This class outputs one log file
+ * per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
+ *
+ * @param logDirName The base directory for the log files.
+ */
class JobLogger(val logDirName: String) extends SparkListener with Logging {
- private val logDir =
- if (System.getenv("SPARK_LOG_DIR") != null)
- System.getenv("SPARK_LOG_DIR")
- else
- "/tmp/spark"
+
+ private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
+
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
-
+
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
-
- def getLogDir = logDir
- def getJobIDtoPrintWriter = jobIDToPrintWriter
- def getStageIDToJobID = stageIDToJobID
- def getJobIDToStages = jobIDToStages
- def getEventQueue = eventQueue
-
+
+ // The following 5 functions are used only in testing.
+ private[scheduler] def getLogDir = logDir
+ private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
+ private[scheduler] def getStageIDToJobID = stageIDToJobID
+ private[scheduler] def getJobIDToStages = jobIDToStages
+ private[scheduler] def getEventQueue = eventQueue
+
// Create a folder for log files, the folder's name is the creation time of the jobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
- if (dir.exists()) {
- return
- }
- if (dir.mkdirs() == false) {
- logError("create log directory error:" + logDir + "/" + logDirName + "/")
+ if (!dir.exists() && !dir.mkdirs()) {
+ logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job, the file name is the jobID
protected def createLogWriter(jobID: Int) {
- try{
+ try {
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
- } catch {
- case e: FileNotFoundException => e.printStackTrace()
- }
+ } catch {
+ case e: FileNotFoundException => e.printStackTrace()
+ }
}
// Close log file, and clean the stage relationship in stageIDToJobID
@@ -118,10 +116,9 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- case _ => rddList ++= getRddsInStage(dep.rdd)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
}
rddList
}
@@ -161,29 +158,27 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
- rdd.dependencies.foreach{ dep => dep match {
- case shufDep: ShuffleDependency[_,_] =>
- val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
- jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
- case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
- }
+ rdd.dependencies.foreach {
+ case shufDep: ShuffleDependency[_, _] =>
+ val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
+ jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
+ case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
- var stageInfo: String = ""
- if (stage.isShuffleMap) {
- stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
- stage.shuffleDep.get.shuffleId
- }else{
- stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
+ val stageInfo = if (stage.isShuffleMap) {
+ "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
+ } else {
+ "STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
- } else
+ } else {
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
+ }
}
// Record task metrics into job log files
@@ -193,39 +188,32 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
- val readMetrics =
- taskMetrics.shuffleReadMetrics match {
- case Some(metrics) =>
- " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
- " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
- " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
- " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
- " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
- case None => ""
- }
- val writeMetrics =
- taskMetrics.shuffleWriteMetrics match {
- case Some(metrics) =>
- " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
- case None => ""
- }
+ val readMetrics = taskMetrics.shuffleReadMetrics match {
+ case Some(metrics) =>
+ " SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
+ " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
+ " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
+ " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
+ " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
+ " REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ case None => ""
+ }
+ val writeMetrics = taskMetrics.shuffleWriteMetrics match {
+ case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
+ case None => ""
+ }
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
- stageLogInfo(
- stageSubmitted.stage.id,
- "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
- stageSubmitted.stage.id, stageSubmitted.taskSize))
+ stageLogInfo(stageSubmitted.stage.id, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+ stageSubmitted.stage.id, stageSubmitted.taskSize))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
- stageLogInfo(
- stageCompleted.stageInfo.stage.id,
+ stageLogInfo(stageCompleted.stageInfo.stage.id,
"STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
-
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 200d881799..58f238d8cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -17,48 +17,58 @@
package org.apache.spark.scheduler
-import scala.collection.mutable.ArrayBuffer
-
/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/
-private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
+private[spark] class JobWaiter[T](
+ dagScheduler: DAGScheduler,
+ jobId: Int,
+ totalTasks: Int,
+ resultHandler: (Int, T) => Unit)
extends JobListener {
private var finishedTasks = 0
- private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
- private var jobResult: JobResult = null // If the job is finished, this will be its result
+ // Is the job as a whole finished (succeeded or failed)?
+ private var _jobFinished = totalTasks == 0
- override def taskSucceeded(index: Int, result: Any) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
- }
- resultHandler(index, result.asInstanceOf[T])
- finishedTasks += 1
- if (finishedTasks == totalTasks) {
- jobFinished = true
- jobResult = JobSucceeded
- this.notifyAll()
- }
- }
+ def jobFinished = _jobFinished
+
+ // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
+ // partition RDDs), we set the jobResult directly to JobSucceeded.
+ private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
+
+ /**
+ * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
+ * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it
+ * will fail this job with a SparkException.
+ */
+ def cancel() {
+ dagScheduler.cancelJob(jobId)
}
- override def jobFailed(exception: Exception) {
- synchronized {
- if (jobFinished) {
- throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
- }
- jobFinished = true
- jobResult = JobFailed(exception, None)
+ override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
+ if (_jobFinished) {
+ throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
+ }
+ resultHandler(index, result.asInstanceOf[T])
+ finishedTasks += 1
+ if (finishedTasks == totalTasks) {
+ _jobFinished = true
+ jobResult = JobSucceeded
this.notifyAll()
}
}
+ override def jobFailed(exception: Exception): Unit = synchronized {
+ _jobFinished = true
+ jobResult = JobFailed(exception, None)
+ this.notifyAll()
+ }
+
def awaitResult(): JobResult = synchronized {
- while (!jobFinished) {
+ while (!_jobFinished) {
this.wait()
}
return jobResult
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
index 9eb8d48501..596f9adde9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala
@@ -43,7 +43,10 @@ private[spark] class Pool(
var runningTasks = 0
var priority = 0
- var stageId = 0
+
+ // A pool's stage id is used to break the tie in scheduling.
+ var stageId = -1
+
var name = poolName
var parent: Pool = null
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 6dd422bbf6..310ec62ca8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -38,17 +38,17 @@ private[spark] object ResultTask {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
- return old
+ old
} else {
val out = new ByteArrayOutputStream
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(func)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -56,11 +56,11 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
- val ser = SparkEnv.get.closureSerializer.newInstance
+ val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
- return (rdd, func)
+ (rdd, func)
}
def clearCache() {
@@ -71,29 +71,37 @@ private[spark] object ResultTask {
}
+/**
+ * A task that sends back the output to the driver application.
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd input to func
+ * @param func a function to apply on a partition of the RDD
+ * @param _partitionId index of the number in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
+ * input RDD's partitions).
+ */
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
- var partition: Int,
+ _partitionId: Int,
@transient locs: Seq[TaskLocation],
var outputId: Int)
- extends Task[U](stageId) with Externalizable {
+ extends Task[U](stageId, _partitionId) with Externalizable {
def this() = this(0, null, null, 0, null, 0)
- var split = if (rdd == null) {
- null
- } else {
- rdd.partitions(partition)
- }
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
- override def run(attemptId: Long): U = {
- val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
+ override def runTask(context: TaskContext): U = {
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
@@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U](
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ResultTask(" + stageId + ", " + partition + ")"
+ override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
@@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U](
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
- partition = in.readInt()
+ partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
index 4e25086ec9..356fe56bf3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala
@@ -30,7 +30,10 @@ import scala.xml.XML
* addTaskSetManager: build the leaf nodes(TaskSetManagers)
*/
private[spark] trait SchedulableBuilder {
+ def rootPool: Pool
+
def buildPools()
+
def addTaskSetManager(manager: Schedulable, properties: Properties)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 3b9d5679fb..24d97da6eb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -53,7 +53,7 @@ private[spark] object ShuffleMapTask {
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
- return bytes
+ bytes
}
}
}
@@ -66,7 +66,7 @@ private[spark] object ShuffleMapTask {
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
- return (rdd, dep)
+ (rdd, dep)
}
}
@@ -75,7 +75,7 @@ private[spark] object ShuffleMapTask {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val objIn = new ObjectInputStream(in)
val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
- return (HashMap(set.toSeq: _*))
+ HashMap(set.toSeq: _*)
}
def clearCache() {
@@ -85,13 +85,25 @@ private[spark] object ShuffleMapTask {
}
}
+/**
+ * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
+ * specified in the ShuffleDependency).
+ *
+ * See [[org.apache.spark.scheduler.Task]] for more information.
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param rdd the final RDD in this stage
+ * @param dep the ShuffleDependency
+ * @param _partitionId index of the number in the RDD
+ * @param locs preferred task execution locations for locality scheduling
+ */
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
- var partition: Int,
+ _partitionId: Int,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId)
+ extends Task[MapStatus](stageId, _partitionId)
with Externalizable
with Logging {
@@ -101,16 +113,16 @@ private[spark] class ShuffleMapTask(
if (locs == null) Nil else locs.toSet.toSeq
}
- var split = if (rdd == null) null else rdd.partitions(partition)
+ var split = if (rdd == null) null else rdd.partitions(partitionId)
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
- split = rdd.partitions(partition)
+ split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
out.writeInt(bytes.length)
out.write(bytes)
- out.writeInt(partition)
+ out.writeInt(partitionId)
out.writeLong(epoch)
out.writeObject(split)
}
@@ -124,16 +136,14 @@ private[spark] class ShuffleMapTask(
val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
rdd = rdd_
dep = dep_
- partition = in.readInt()
+ partitionId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
- override def run(attemptId: Long): MapStatus = {
+ override def runTask(context: TaskContext): MapStatus = {
val numOutputSplits = dep.partitioner.numPartitions
-
- val taskContext = new TaskContext(stageId, partition, attemptId, runningLocally = false)
- metrics = Some(taskContext.taskMetrics)
+ metrics = Some(context.taskMetrics)
val blockManager = SparkEnv.get.blockManager
var shuffle: ShuffleBlocks = null
@@ -143,10 +153,10 @@ private[spark] class ShuffleMapTask(
// Obtain all the block writers for shuffle blocks.
val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
- buckets = shuffle.acquireWriters(partition)
+ buckets = shuffle.acquireWriters(partitionId)
// Write the map output to its associated buckets.
- for (elem <- rdd.iterator(split, taskContext)) {
+ for (elem <- rdd.iterator(split, context)) {
val pair = elem.asInstanceOf[Product2[Any, Any]]
val bucketId = dep.partitioner.getPartition(pair._1)
buckets.writers(bucketId).write(pair)
@@ -154,20 +164,22 @@ private[spark] class ShuffleMapTask(
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
+ var totalTime = 0L
val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
writer.commit()
- writer.close()
- val size = writer.size()
+ val size = writer.fileSegment().length
totalBytes += size
+ totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
- return new MapStatus(blockManager.blockManagerId, compressedSizes)
+ new MapStatus(blockManager.blockManagerId, compressedSizes)
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
@@ -178,14 +190,15 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && buckets != null) {
+ buckets.writers.foreach(_.close())
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.
- taskContext.executeOnCompleteCallbacks()
+ context.executeOnCompleteCallbacks()
}
}
override def preferredLocations: Seq[TaskLocation] = preferredLocs
- override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
+ override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partitionId)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 62b521ad45..466baf9913 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -54,7 +54,7 @@ trait SparkListener {
/**
* Called when a task starts
*/
- def onTaskStart(taskEnd: SparkListenerTaskStart) { }
+ def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
* Called when a task ends
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 598d91752a..69b42e86ea 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,25 +17,74 @@
package org.apache.spark.scheduler
-import org.apache.spark.serializer.SerializerInstance
import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
-import org.apache.spark.util.ByteBufferInputStream
+
import scala.collection.mutable.HashMap
+
+import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+
+import org.apache.spark.TaskContext
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.ByteBufferInputStream
+
/**
- * A task to execute on a worker node.
+ * A unit of execution. We have two kinds of Task's in Spark:
+ * - [[org.apache.spark.scheduler.ShuffleMapTask]]
+ * - [[org.apache.spark.scheduler.ResultTask]]
+ *
+ * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
+ * ResultTask's, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
+ * and sends the task output back to the driver application. A ShuffleMapTask executes the task
+ * and divides the task output to multiple buckets (based on the task's partitioner).
+ *
+ * @param stageId id of the stage this task belongs to
+ * @param partitionId index of the number in the RDD
*/
-private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
- def run(attemptId: Long): T
+private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+
+ final def run(attemptId: Long): T = {
+ context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ if (_killed) {
+ kill()
+ }
+ runTask(context)
+ }
+
+ def runTask(context: TaskContext): T
+
def preferredLocations: Seq[TaskLocation] = Nil
- var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
+ // Map output tracker epoch. Will be set by TaskScheduler.
+ var epoch: Long = -1
var metrics: Option[TaskMetrics] = None
+ // Task context, to be initialized in run().
+ @transient protected var context: TaskContext = _
+
+ // A flag to indicate whether the task is killed. This is used in case context is not yet
+ // initialized when kill() is invoked.
+ @volatile @transient private var _killed = false
+
+ /**
+ * Whether the task has been killed.
+ */
+ def killed: Boolean = _killed
+
+ /**
+ * Kills a task by setting the interrupted flag to true. This relies on the upper level Spark
+ * code and user code to properly handle the flag. This function should be idempotent so it can
+ * be called multiple times.
+ */
+ def kill() {
+ _killed = true
+ if (context != null) {
+ context.interrupted = true
+ }
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index db3954a9d3..7e468d0d67 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -24,13 +24,14 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.{SparkEnv}
import java.nio.ByteBuffer
import org.apache.spark.util.Utils
+import org.apache.spark.storage.BlockId
// Task result. Also contains updates to accumulator variables.
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
private[spark]
-case class IndirectTaskResult[T](val blockId: String) extends TaskResult[T] with Serializable
+case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 7c2a9f03d7..10e0478108 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
* Each TaskScheduler schedulers task for a single SparkContext.
* These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
* and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler through
- * the TaskSchedulerListener interface.
+ * are failures, and mitigating stragglers. They return events to the DAGScheduler.
*/
private[spark] trait TaskScheduler {
@@ -45,8 +44,11 @@ private[spark] trait TaskScheduler {
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit
- // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
- def setListener(listener: TaskSchedulerListener): Unit
+ // Cancel a stage.
+ def cancelTasks(stageId: Int)
+
+ // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
+ def setDAGScheduler(dagScheduler: DAGScheduler): Unit
// Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
def defaultParallelism(): Int
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
deleted file mode 100644
index 593fa9fb93..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import scala.collection.mutable.Map
-
-import org.apache.spark.TaskEndReason
-import org.apache.spark.executor.TaskMetrics
-
-/**
- * Interface for getting events back from the TaskScheduler.
- */
-private[spark] trait TaskSchedulerListener {
- // A task has started.
- def taskStarted(task: Task[_], taskInfo: TaskInfo)
-
- // A task has finished or failed.
- def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
-
- // A node was added to the cluster.
- def executorGained(execId: String, host: String): Unit
-
- // A node was lost from the cluster.
- def executorLost(execId: String): Unit
-
- // The TaskScheduler wants to abort an entire task set.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156..03bf760837 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -31,5 +31,9 @@ private[spark] class TaskSet(
val properties: Properties) {
val id: String = stageId + "." + attempt
+ def kill() {
+ tasks.foreach(_.kill())
+ }
+
override def toString: String = "TaskSet " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 1a844b7e7e..4ea8bf8853 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -17,7 +17,6 @@
package org.apache.spark.scheduler.cluster
-import java.lang.{Boolean => JBoolean}
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
import java.util.{TimerTask, Timer}
@@ -79,14 +78,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
private val executorIdToHost = new HashMap[String, String]
- // JAR server, if any JARs were added by the user to the SparkContext
- var jarServer: HttpServer = null
-
- // URIs of JARs to pass to executor
- var jarUris: String = ""
-
// Listener object to pass upcalls into
- var listener: TaskSchedulerListener = null
+ var dagScheduler: DAGScheduler = null
var backend: SchedulerBackend = null
@@ -101,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// This is a var so that we can reset it for testing purposes.
private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
def initialize(context: SchedulerBackend) {
@@ -171,8 +164,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}
- def taskSetFinished(manager: TaskSetManager) {
- this.synchronized {
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
+ def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
+ // Check to see if the given task set has been removed. This is possible in the case of
+ // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
+ // more than one running tasks).
+ if (activeTaskSets.contains(manager.taskSet.id)) {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
@@ -281,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
// Update the DAGScheduler without holding a lock on this, since that can deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
if (taskFailed) {
@@ -334,9 +350,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (backend != null) {
backend.stop()
}
- if (jarServer != null) {
- jarServer.stop()
- }
if (taskResultGetter != null) {
taskResultGetter.stop()
}
@@ -384,9 +397,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
logError("Lost an executor " + executorId + " (already removed): " + reason)
}
}
- // Call listener.executorLost without holding the lock on this to prevent deadlock
+ // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock
if (failedExecutor != None) {
- listener.executorLost(failedExecutor.get)
+ dagScheduler.executorLost(failedExecutor.get)
backend.reviveOffers()
}
}
@@ -405,7 +418,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
def executorGained(execId: String, host: String) {
- listener.executorGained(execId, host)
+ dagScheduler.executorGained(execId, host)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 936167c13f..29093e3b4f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -17,18 +17,16 @@
package org.apache.spark.scheduler.cluster
-import java.nio.ByteBuffer
-import java.util.{Arrays, NoSuchElementException}
+import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
-import scala.Some
import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- SparkException, Success, TaskEndReason, TaskResultLost, TaskState}
+ Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler._
import org.apache.spark.util.{SystemClock, Clock}
@@ -417,11 +415,11 @@ private[spark] class ClusterTaskSetManager(
}
private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
}
/**
- * Marks the task as successful and notifies the listener that a task has ended.
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
val info = taskInfos(tid)
@@ -431,7 +429,7 @@ private[spark] class ClusterTaskSetManager(
if (!successful(index)) {
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.listener.taskEnded(
+ sched.dagScheduler.taskEnded(
tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
// Mark successful and stop if all the tasks have succeeded.
@@ -447,7 +445,8 @@ private[spark] class ClusterTaskSetManager(
}
/**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
*/
def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
val info = taskInfos(tid)
@@ -458,54 +457,57 @@ private[spark] class ClusterTaskSetManager(
val index = info.index
info.markFailed()
if (!successful(index)) {
- logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
reason.foreach {
- _ match {
- case fetchFailed: FetchFailed =>
- logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- successful(index) = true
- tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case ef: ExceptionFailure =>
- sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- val key = ef.description
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
+ case fetchFailed: FetchFailed =>
+ logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ successful(index) = true
+ tasksSuccessful += 1
+ sched.taskSetFinished(this)
+ removeAllRunningTasks()
+ return
+
+ case TaskKilled =>
+ logWarning("Task %d was killed.".format(tid))
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+ return
+
+ case ef: ExceptionFailure =>
+ sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ val key = ef.description
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
}
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ recentExceptions(key) = (0, now)
+ (true, 0)
}
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logWarning("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
- case TaskResultLost =>
- logInfo("Lost result for TID %s on host %s".format(tid, info.host))
- sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+ case TaskResultLost =>
+ logWarning("Lost result for TID %s on host %s".format(tid, info.host))
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
- case _ => {}
- }
+ case _ => {}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
@@ -532,7 +534,7 @@ private[spark] class ClusterTaskSetManager(
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.listener.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message)
removeAllRunningTasks()
sched.taskSetFinished(this)
}
@@ -605,7 +607,7 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
- sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index c0b836bf1a..a8230ec6bc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -24,26 +24,28 @@ import org.apache.spark.scheduler.TaskDescription
import org.apache.spark.util.{Utils, SerializableBuffer}
-private[spark] sealed trait StandaloneClusterMessage extends Serializable
+private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable
-private[spark] object StandaloneClusterMessages {
+private[spark] object CoarseGrainedClusterMessages {
// Driver to executors
- case class LaunchTask(task: TaskDescription) extends StandaloneClusterMessage
+ case class LaunchTask(task: TaskDescription) extends CoarseGrainedClusterMessage
+
+ case class KillTask(taskId: Long, executor: String) extends CoarseGrainedClusterMessage
case class RegisteredExecutor(sparkProperties: Seq[(String, String)])
- extends StandaloneClusterMessage
+ extends CoarseGrainedClusterMessage
- case class RegisterExecutorFailed(message: String) extends StandaloneClusterMessage
+ case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
// Executors to driver
case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
- extends StandaloneClusterMessage {
+ extends CoarseGrainedClusterMessage {
Utils.checkHostPort(hostPort, "Expected host port")
}
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
- data: SerializableBuffer) extends StandaloneClusterMessage
+ data: SerializableBuffer) extends CoarseGrainedClusterMessage
object StatusUpdate {
/** Alternate factory method that takes a ByteBuffer directly for the data field */
@@ -54,10 +56,10 @@ private[spark] object StandaloneClusterMessages {
}
// Internal messages in driver
- case object ReviveOffers extends StandaloneClusterMessage
+ case object ReviveOffers extends CoarseGrainedClusterMessage
- case object StopDriver extends StandaloneClusterMessage
+ case object StopDriver extends CoarseGrainedClusterMessage
- case class RemoveExecutor(executorId: String, reason: String) extends StandaloneClusterMessage
+ case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f3aeea43d5..c0f1c6dbad 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -30,16 +30,19 @@ import akka.util.duration._
import org.apache.spark.{SparkException, Logging, TaskState}
import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.Utils
/**
- * A standalone scheduler backend, which waits for standalone executors to connect to it through
- * Akka. These may be executed in a variety of ways, such as Mesos tasks for the coarse-grained
- * Mesos mode or standalone processes for Spark's standalone deploy mode (spark.deploy.*).
+ * A scheduler backend that waits for coarse grained executors to connect to it through Akka.
+ * This backend holds onto each executor for the duration of the Spark job rather than relinquishing
+ * executors whenever a task is done and asking the scheduler to launch a new executor for
+ * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the
+ * coarse-grained Mesos mode or standalone processes for Spark's standalone deploy mode
+ * (spark.deploy.*).
*/
private[spark]
-class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
@@ -91,6 +94,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
case ReviveOffers =>
makeOffers()
+ case KillTask(taskId, executorId) =>
+ executorActor(executorId) ! KillTask(taskId, executorId)
+
case StopDriver =>
sender ! true
context.stop(self)
@@ -159,7 +165,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
driverActor = actorSystem.actorOf(
- Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
+ Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME)
}
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
@@ -180,6 +186,10 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
driverActor ! ReviveOffers
}
+ override def killTask(taskId: Long, executorId: String) {
+ driverActor ! KillTask(taskId, executorId)
+ }
+
override def defaultParallelism() = Option(System.getProperty("spark.default.parallelism"))
.map(_.toInt).getOrElse(math.max(totalCoreCount.get(), 2))
@@ -195,6 +205,6 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
}
-private[spark] object StandaloneSchedulerBackend {
- val ACTOR_NAME = "StandaloneScheduler"
+private[spark] object CoarseGrainedSchedulerBackend {
+ val ACTOR_NAME = "CoarseGrainedScheduler"
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
index d57eb3276f..5367218faa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
@@ -17,7 +17,7 @@
package org.apache.spark.scheduler.cluster
-import org.apache.spark.{SparkContext}
+import org.apache.spark.SparkContext
/**
* A backend interface for cluster scheduling systems that allows plugging in different ones under
@@ -30,8 +30,8 @@ private[spark] trait SchedulerBackend {
def reviveOffers(): Unit
def defaultParallelism(): Int
+ def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException
+
// Memory used by each executor (in megabytes)
protected val executorMemory: Int = SparkContext.executorMemoryRequested
-
- // TODO: Probably want to add a killTask too
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index cb88159b8d..cefa970bb9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -28,7 +28,7 @@ private[spark] class SparkDeploySchedulerBackend(
sc: SparkContext,
masters: Array[String],
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with ClientListener
with Logging {
@@ -44,10 +44,10 @@ private[spark] class SparkDeploySchedulerBackend(
// The endpoint for executors to talk to us
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val command = Command(
- "org.apache.spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
+ "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(null)
val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome,
"http://" + sc.ui.appUIAddress)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
index feec8ecfe4..4312c46cc1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -24,33 +24,16 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
+import org.apache.spark.util.Utils
/**
* Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
*/
private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
extends Logging {
- private val MIN_THREADS = System.getProperty("spark.resultGetter.minThreads", "4").toInt
- private val MAX_THREADS = System.getProperty("spark.resultGetter.maxThreads", "4").toInt
- private val getTaskResultExecutor = new ThreadPoolExecutor(
- MIN_THREADS,
- MAX_THREADS,
- 0L,
- TimeUnit.SECONDS,
- new LinkedBlockingDeque[Runnable],
- new ResultResolverThreadFactory)
-
- class ResultResolverThreadFactory extends ThreadFactory {
- private var counter = 0
- private var PREFIX = "Result resolver thread"
-
- override def newThread(r: Runnable): Thread = {
- val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
- counter += 1
- thread.setDaemon(true)
- return thread
- }
- }
+ private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
+ private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
+ THREADS, "Result resolver thread")
protected val serializer = new ThreadLocal[SerializerInstance] {
override def initialValue(): SerializerInstance = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 8f2eef9a53..300fe693f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,13 +30,14 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
* onto each Mesos node for the duration of the Spark job instead of relinquishing cores whenever
* a task is done. It launches Spark tasks within the coarse-grained Mesos tasks using the
- * StandaloneBackend mechanism. This class is useful for lower and more predictable latency.
+ * CoarseGrainedSchedulerBackend mechanism. This class is useful for lower and more predictable
+ * latency.
*
* Unfortunately this has a bit of duplication from MesosSchedulerBackend, but it seems hard to
* remove this.
@@ -46,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend(
sc: SparkContext,
master: String,
appName: String)
- extends StandaloneSchedulerBackend(scheduler, sc.env.actorSystem)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with MScheduler
with Logging {
@@ -122,20 +123,20 @@ private[spark] class CoarseMesosSchedulerBackend(
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
val uri = System.getProperty("spark.executor.uri")
if (uri == null) {
val runScript = new File(sparkHome, "spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
+ "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format(
runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- "cd %s*; ./spark-class org.apache.spark.executor.StandaloneExecutorBackend %s %s %s %d".format(
- basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
+ "cd %s*; ./spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d"
+ .format(basename, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
return command.build()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 4d1bb1c639..2699f0b33e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -17,23 +17,19 @@
package org.apache.spark.scheduler.local
-import java.io.File
-import java.lang.management.ManagementFactory
-import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+
+import akka.actor._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-import akka.actor._
-import org.apache.spark.util.Utils
+
/**
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -41,52 +37,57 @@ import org.apache.spark.util.Utils
* testing fault recovery.
*/
-private[spark]
+private[local]
case class LocalReviveOffers()
-private[spark]
+private[local]
case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
+private[local]
+case class KillTask(taskId: Long)
+
private[spark]
-class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
+class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
+ extends Actor with Logging {
+
+ val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true)
def receive = {
case LocalReviveOffers =>
launchTask(localScheduler.resourceOffer(freeCores))
+
case LocalStatusUpdate(taskId, state, serializeData) =>
- freeCores += 1
- localScheduler.statusUpdate(taskId, state, serializeData)
- launchTask(localScheduler.resourceOffer(freeCores))
+ if (TaskState.isFinished(state)) {
+ freeCores += 1
+ launchTask(localScheduler.resourceOffer(freeCores))
+ }
+
+ case KillTask(taskId) =>
+ executor.killTask(taskId)
}
- def launchTask(tasks : Seq[TaskDescription]) {
+ private def launchTask(tasks: Seq[TaskDescription]) {
for (task <- tasks) {
freeCores -= 1
- localScheduler.threadPool.submit(new Runnable {
- def run() {
- localScheduler.runTask(task.taskId, task.serializedTask)
- }
- })
+ executor.launchTask(localScheduler, task.taskId, task.serializedTask)
}
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
+ with ExecutorBackend
with Logging {
- var attemptId = new AtomicInteger(0)
- var threadPool = Utils.newDaemonFixedThreadPool(threads)
val env = SparkEnv.get
- var listener: TaskSchedulerListener = null
+ val attemptId = new AtomicInteger
+ var dagScheduler: DAGScheduler = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
- val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
-
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
val schedulingMode: SchedulingMode = SchedulingMode.withName(
@@ -113,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
- override def setListener(listener: TaskSchedulerListener) {
- this.listener = listener
+ override def setDAGScheduler(dagScheduler: DAGScheduler) {
+ this.dagScheduler = dagScheduler
}
override def submitTasks(taskSet: TaskSet) {
@@ -127,6 +128,26 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
+ activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+ // simply abort the stage.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ localActor ! KillTask(tid)
+ }
+ }
+ tsm.error("Stage %d was cancelled".format(stageId))
+ }
+ }
+
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
@@ -166,107 +187,32 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
- def runTask(taskId: Long, bytes: ByteBuffer) {
- logInfo("Running " + taskId)
- val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
- // Set the Spark execution environment for the worker thread
- SparkEnv.set(env)
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val objectSer = SparkEnv.get.serializer.newInstance()
- var attemptedTask: Option[Task[_]] = None
- val start = System.currentTimeMillis()
- var taskStart: Long = 0
- def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
- val startGCTime = getTotalGCTime
-
- try {
- Accumulators.clear()
- Thread.currentThread().setContextClassLoader(classLoader)
-
- // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
- // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
- val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
- updateDependencies(taskFiles, taskJars) // Download any files added with addFile
- val deserializedTask = ser.deserialize[Task[_]](
- taskBytes, Thread.currentThread.getContextClassLoader)
- attemptedTask = Some(deserializedTask)
- val deserTime = System.currentTimeMillis() - start
- taskStart = System.currentTimeMillis()
-
- // Run it
- val result: Any = deserializedTask.run(taskId)
-
- // Serialize and deserialize the result to emulate what the Mesos
- // executor does. This is useful to catch serialization errors early
- // on in development (so when users move their local Spark programs
- // to the cluster, they don't get surprised by serialization errors).
- val serResult = objectSer.serialize(result)
- deserializedTask.metrics.get.resultSize = serResult.limit()
- val resultToReturn = objectSer.deserialize[Any](serResult)
- val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
- ser.serialize(Accumulators.values))
- val serviceTime = System.currentTimeMillis() - taskStart
- logInfo("Finished " + taskId)
- deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
- deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
- deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
- val taskResult = new DirectTaskResult(
- result, accumUpdates, deserializedTask.metrics.getOrElse(null))
- val serializedResult = ser.serialize(taskResult)
- localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
- } catch {
- case t: Throwable => {
- val serviceTime = System.currentTimeMillis() - taskStart
- val metrics = attemptedTask.flatMap(t => t.metrics)
- for (m <- metrics) {
- m.executorRunTime = serviceTime.toInt
- m.jvmGCTime = getTotalGCTime - startGCTime
- }
- val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
- localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
- }
- }
- }
-
- /**
- * Download any missing dependencies if we receive a new set of files and JARs from the
- * SparkContext. Also adds any new JARs we fetched to the class loader.
- */
- private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- synchronized {
- // Fetch missing dependencies
- for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentFiles(name) = timestamp
- }
-
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
- currentJars(name) = timestamp
- // Add it to our class loader
- val localName = name.split("/").last
- val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
- if (!classLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- classLoader.addURL(url)
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
+ if (TaskState.isFinished(state)) {
+ synchronized {
+ taskIdToTaskSetId.get(taskId) match {
+ case Some(taskSetId) =>
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+
+ state match {
+ case TaskState.FINISHED =>
+ taskSetManager.taskEnded(taskId, state, serializedData)
+ case TaskState.FAILED =>
+ taskSetManager.taskFailed(taskId, state, serializedData)
+ case TaskState.KILLED =>
+ taskSetManager.error("Task %d was killed".format(taskId))
+ case _ => {}
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
}
}
+ localActor ! LocalStatusUpdate(taskId, state, serializedData)
}
}
- def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
- synchronized {
- val taskSetId = taskIdToTaskSetId(taskId)
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
- taskSetManager.statusUpdate(taskId, state, serializedData)
- }
- }
-
- override def stop() {
- threadPool.shutdownNow()
+ override def stop() {
}
override def defaultParallelism() = threads
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index c2e2399ccb..55f8313e87 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -132,19 +132,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- SparkEnv.set(env)
- state match {
- case TaskState.FINISHED =>
- taskEnded(tid, state, serializedData)
- case TaskState.FAILED =>
- taskFailed(tid, state, serializedData)
- case _ => {}
- }
- }
-
def taskStarted(task: Task[_], info: TaskInfo) {
- sched.listener.taskStarted(task, info)
+ sched.dagScheduler.taskStarted(task, info)
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -159,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
}
result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
+ result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
@@ -176,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
serializedData, getClass.getClassLoader)
- sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
+ sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
@@ -187,7 +177,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
taskSet.id, index, 4, reason.description)
decreaseRunningTasks(runningTasks)
- sched.listener.taskSetFailed(taskSet, errorMessage)
+ sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
@@ -195,5 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
}
override def error(message: String) {
+ sched.dagScheduler.taskSetFailed(taskSet, message)
+ sched.taskSetFinished(this)
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index e936b1cfed..55b25f145a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -26,9 +26,8 @@ import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar}
import org.apache.spark.{SerializableWritable, Logging}
-import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel}
-
import org.apache.spark.broadcast.HttpBroadcast
+import org.apache.spark.storage.{GetBlock,GotBlock, PutBlock, StorageLevel, TestBlockId}
/**
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
@@ -43,13 +42,14 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging
val kryo = instantiator.newKryo()
val classLoader = Thread.currentThread.getContextClassLoader
+ val blockId = TestBlockId("1")
// Register some commonly used classes
val toRegister: Seq[AnyRef] = Seq(
ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY,
- PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
- GotBlock("1", ByteBuffer.allocate(1)),
- GetBlock("1"),
+ PutBlock(blockId, ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
+ GotBlock(blockId, ByteBuffer.allocate(1)),
+ GetBlock(blockId),
1 to 10,
1 until 10,
1L to 10L,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockException.scala b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
index 290dbce4f5..0d0a2dadc7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockException.scala
@@ -18,5 +18,5 @@
package org.apache.spark.storage
private[spark]
-case class BlockException(blockId: String, message: String) extends Exception(message)
+case class BlockException(blockId: BlockId, message: String) extends Exception(message)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 3aeda3879d..e51c5b30a3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -47,7 +47,7 @@ import org.apache.spark.util.Utils
*/
private[storage]
-trait BlockFetcherIterator extends Iterator[(String, Option[Iterator[Any]])]
+trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])]
with Logging with BlockFetchTracker {
def initialize()
}
@@ -57,20 +57,20 @@ private[storage]
object BlockFetcherIterator {
// A request to fetch one or more blocks, complete with their sizes
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+ class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
// A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
// the block (since we want all deserializaton to happen in the calling thread); can also
// represent a fetch failure if size == -1.
- class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+ class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
def failed: Boolean = size == -1
}
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
- val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BlockFetcherIterator {
@@ -92,12 +92,12 @@ object BlockFetcherIterator {
// This represents the number of local blocks, also counting zero-sized blocks
private var numLocal = 0
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
- protected val localBlocksToFetch = new ArrayBuffer[String]()
+ protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
// This represents the number of remote blocks, also counting zero-sized blocks
private var numRemote = 0
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
- protected val remoteBlocksToFetch = new HashSet[String]()
+ protected val remoteBlocksToFetch = new HashSet[BlockId]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@@ -167,7 +167,7 @@ object BlockFetcherIterator {
logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
val iterator = blockInfos.iterator
var curRequestSize = 0L
- var curBlocks = new ArrayBuffer[(String, Long)]
+ var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
// Skip empty blocks
@@ -183,7 +183,7 @@ object BlockFetcherIterator {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curRequestSize = 0
- curBlocks = new ArrayBuffer[(String, Long)]
+ curBlocks = new ArrayBuffer[(BlockId, Long)]
}
}
// Add in the final request
@@ -241,7 +241,7 @@ object BlockFetcherIterator {
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val startFetchWait = System.currentTimeMillis()
val result = results.take()
@@ -267,7 +267,7 @@ object BlockFetcherIterator {
class NettyBlockFetcherIterator(
blockManager: BlockManager,
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])],
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
@@ -303,7 +303,7 @@ object BlockFetcherIterator {
override protected def sendRequest(req: FetchRequest) {
- def putResult(blockId: String, blockSize: Long, blockData: ByteBuf) {
+ def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
val fetchResult = new FetchResult(blockId, blockSize,
() => dataDeserialize(blockId, blockData.nioBuffer, serializer))
results.put(fetchResult)
@@ -337,7 +337,7 @@ object BlockFetcherIterator {
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}
- override def next(): (String, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
// If all the results has been retrieved, copiers will exit automatically
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
new file mode 100644
index 0000000000..7156d855d8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+/**
+ * Identifies a particular Block of data, usually associated with a single file.
+ * A Block can be uniquely identified by its filename, but each type of Block has a different
+ * set of keys which produce its unique name.
+ *
+ * If your BlockId should be serializable, be sure to add it to the BlockId.fromString() method.
+ */
+private[spark] sealed abstract class BlockId {
+ /** A globally unique identifier for this Block. Can be used for ser/de. */
+ def name: String
+
+ // convenience methods
+ def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
+ def isRDD = isInstanceOf[RDDBlockId]
+ def isShuffle = isInstanceOf[ShuffleBlockId]
+ def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+
+ override def toString = name
+ override def hashCode = name.hashCode
+ override def equals(other: Any): Boolean = other match {
+ case o: BlockId => getClass == o.getClass && name.equals(o.name)
+ case _ => false
+ }
+}
+
+private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
+ def name = "rdd_" + rddId + "_" + splitIndex
+}
+
+private[spark]
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+ def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
+}
+
+private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
+ def name = "broadcast_" + broadcastId
+}
+
+private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
+ def name = broadcastId.name + "_" + hType
+}
+
+private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
+ def name = "taskresult_" + taskId
+}
+
+private[spark] case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
+ def name = "input-" + streamId + "-" + uniqueId
+}
+
+// Intended only for testing purposes
+private[spark] case class TestBlockId(id: String) extends BlockId {
+ def name = "test_" + id
+}
+
+private[spark] object BlockId {
+ val RDD = "rdd_([0-9]+)_([0-9]+)".r
+ val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+ val BROADCAST = "broadcast_([0-9]+)".r
+ val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+ val TASKRESULT = "taskresult_([0-9]+)".r
+ val STREAM = "input-([0-9]+)-([0-9]+)".r
+ val TEST = "test_(.*)".r
+
+ /** Converts a BlockId "name" String back into a BlockId. */
+ def apply(id: String) = id match {
+ case RDD(rddId, splitIndex) =>
+ RDDBlockId(rddId.toInt, splitIndex.toInt)
+ case SHUFFLE(shuffleId, mapId, reduceId) =>
+ ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+ case BROADCAST(broadcastId) =>
+ BroadcastBlockId(broadcastId.toLong)
+ case BROADCAST_HELPER(broadcastId, hType) =>
+ BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+ case TASKRESULT(taskId) =>
+ TaskResultBlockId(taskId.toLong)
+ case STREAM(streamId, uniqueId) =>
+ StreamBlockId(streamId.toInt, uniqueId.toLong)
+ case TEST(value) =>
+ TestBlockId(value)
+ case _ =>
+ throw new IllegalStateException("Unrecognized BlockId: " + id)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 2322922f75..e6329cbd47 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,14 +20,15 @@ package org.apache.spark.storage
import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
+import scala.collection.mutable.{HashMap, ArrayBuffer}
+import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import akka.dispatch.{Await, Future}
import akka.util.Duration
import akka.util.duration._
-import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
+import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@@ -37,7 +38,6 @@ import org.apache.spark.util._
import sun.nio.ch.DirectBuffer
-
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -102,18 +102,19 @@ private[spark] class BlockManager(
}
val shuffleBlockManager = new ShuffleBlockManager(this)
+ val diskBlockManager = new DiskBlockManager(
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
- private val blockInfo = new TimeStampedHashMap[String, BlockInfo]
+ private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
- private[storage] val diskStore: DiskStore =
- new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
+ private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
- if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
+ if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@@ -249,7 +250,7 @@ private[spark] class BlockManager(
/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
- def getLevel(blockId: String): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
+ def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull
/**
* Tell the master about the current storage status of a block. This will send a block update
@@ -259,7 +260,7 @@ private[spark] class BlockManager(
* droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
* This ensures that update in master will compensate for the increase in memory on slave.
*/
- def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
+ def reportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L) {
val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
@@ -270,11 +271,11 @@ private[spark] class BlockManager(
}
/**
- * Actually send a UpdateBlockInfo message. Returns the mater's response,
+ * Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
- private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
+ private def tryToReportBlockStatus(blockId: BlockId, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@@ -299,7 +300,7 @@ private[spark] class BlockManager(
/**
* Get locations of an array of blocks.
*/
- def getLocationBlockIds(blockIds: Array[String]): Array[Seq[BlockManagerId]] = {
+ def getLocationBlockIds(blockIds: Array[BlockId]): Array[Seq[BlockManagerId]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).toArray
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
@@ -311,7 +312,7 @@ private[spark] class BlockManager(
* shuffle blocks. It is safe to do so without a lock on block info since disk store
* never deletes (recent) items.
*/
- def getLocalFromDisk(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
diskStore.getValues(blockId, serializer).orElse(
sys.error("Block " + blockId + " not found on disk, though it should be"))
}
@@ -319,94 +320,19 @@ private[spark] class BlockManager(
/**
* Get block from local block manager.
*/
- def getLocal(blockId: String): Option[Iterator[Any]] = {
+ def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
- val info = blockInfo.get(blockId).orNull
- if (info != null) {
- info.synchronized {
-
- // In the another thread is writing the block, wait for it to become ready.
- if (!info.waitForReady()) {
- // If we get here, the block write failed.
- logWarning("Block " + blockId + " was marked as failure.")
- return None
- }
-
- val level = info.level
- logDebug("Level for block " + blockId + " is " + level)
-
- // Look for the block in memory
- if (level.useMemory) {
- logDebug("Getting block " + blockId + " from memory")
- memoryStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- logDebug("Block " + blockId + " not found in memory")
- }
- }
-
- // Look for block on disk, potentially loading it back into memory if required
- if (level.useDisk) {
- logDebug("Getting block " + blockId + " from disk")
- if (level.useMemory && level.deserialized) {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- // Put the block back in memory before returning it
- // TODO: Consider creating a putValues that also takes in a iterator ?
- val elements = new ArrayBuffer[Any]
- elements ++= iterator
- memoryStore.putValues(blockId, elements, level, true).data match {
- case Left(iterator2) =>
- return Some(iterator2)
- case _ =>
- throw new Exception("Memory store did not return back an iterator")
- }
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else if (level.useMemory && !level.deserialized) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- // Put a copy of the block back in memory before returning it. Note that we can't
- // put the ByteBuffer returned by the disk store as that's a memory-mapped file.
- // The use of rewind assumes this.
- assert (0 == bytes.position())
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- bytes.rewind()
- return Some(dataDeserialize(blockId, bytes))
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- } else {
- diskStore.getValues(blockId) match {
- case Some(iterator) =>
- return Some(iterator)
- case None =>
- throw new Exception("Block " + blockId + " not found on disk, though it should be")
- }
- }
- }
- }
- } else {
- logDebug("Block " + blockId + " not registered locally")
- }
- return None
+ doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
}
/**
* Get block from the local block manager as serialized bytes.
*/
- def getLocalBytes(blockId: String): Option[ByteBuffer] = {
- // TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
+ def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
logDebug("Getting local block " + blockId + " as bytes")
-
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
- if (ShuffleBlockManager.isShuffle(blockId)) {
+ if (blockId.isShuffle) {
return diskStore.getBytes(blockId) match {
case Some(bytes) =>
Some(bytes)
@@ -414,12 +340,15 @@ private[spark] class BlockManager(
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
+ doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+ private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = {
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
- // In the another thread is writing the block, wait for it to become ready.
+ // If another thread is writing the block, wait for it to become ready.
if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure.")
@@ -432,62 +361,104 @@ private[spark] class BlockManager(
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
- memoryStore.getBytes(blockId) match {
- case Some(bytes) =>
- return Some(bytes)
+ val result = if (asValues) {
+ memoryStore.getValues(blockId)
+ } else {
+ memoryStore.getBytes(blockId)
+ }
+ result match {
+ case Some(values) =>
+ return Some(values)
case None =>
logDebug("Block " + blockId + " not found in memory")
}
}
- // Look for block on disk
+ // Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
- // Read it as a byte buffer into memory first, then return it
- diskStore.getBytes(blockId) match {
- case Some(bytes) =>
- assert (0 == bytes.position())
- if (level.useMemory) {
- if (level.deserialized) {
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- // The memory store will hang onto the ByteBuffer, so give it a copy instead of
- // the memory-mapped file buffer we got from the disk store
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
- memoryStore.putBytes(blockId, copyForMemory, level)
- }
- }
- bytes.rewind()
- return Some(bytes)
+ logDebug("Getting block " + blockId + " from disk")
+ val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
+ case Some(bytes) => bytes
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
+ assert (0 == bytes.position())
+
+ if (!level.useMemory) {
+ // If the block shouldn't be stored in memory, we can just return it:
+ if (asValues) {
+ return Some(dataDeserialize(blockId, bytes))
+ } else {
+ return Some(bytes)
+ }
+ } else {
+ // Otherwise, we also have to store something in the memory store:
+ if (!level.deserialized || !asValues) {
+ // We'll store the bytes in memory if the block's storage level includes
+ // "memory serialized", or if it should be cached as objects in memory
+ // but we only requested its serialized bytes:
+ val copyForMemory = ByteBuffer.allocate(bytes.limit)
+ copyForMemory.put(bytes)
+ memoryStore.putBytes(blockId, copyForMemory, level)
+ bytes.rewind()
+ }
+ if (!asValues) {
+ return Some(bytes)
+ } else {
+ val values = dataDeserialize(blockId, bytes)
+ if (level.deserialized) {
+ // Cache the values before returning them:
+ // TODO: Consider creating a putValues that also takes in a iterator?
+ val valuesBuffer = new ArrayBuffer[Any]
+ valuesBuffer ++= values
+ memoryStore.putValues(blockId, valuesBuffer, level, true).data match {
+ case Left(values2) =>
+ return Some(values2)
+ case _ =>
+ throw new Exception("Memory store did not return back an iterator")
+ }
+ } else {
+ return Some(values)
+ }
+ }
+ }
}
}
} else {
logDebug("Block " + blockId + " not registered locally")
}
- return None
+ None
}
/**
* Get block from remote block managers.
*/
- def getRemote(blockId: String): Option[Iterator[Any]] = {
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
+ def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting remote block " + blockId)
- // Get locations of block
- val locations = master.getLocations(blockId)
+ doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
+ }
- // Get block from remote locations
+ /**
+ * Get block from remote block managers as serialized bytes.
+ */
+ def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ logDebug("Getting remote block " + blockId + " as bytes")
+ doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
+ }
+
+ private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = {
+ require(blockId != null, "BlockId is null")
+ val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
- return Some(dataDeserialize(blockId, data))
+ if (asValues) {
+ return Some(dataDeserialize(blockId, data))
+ } else {
+ return Some(data)
+ }
}
logDebug("The value of block " + blockId + " is null")
}
@@ -496,34 +467,9 @@ private[spark] class BlockManager(
}
/**
- * Get block from remote block managers as serialized bytes.
- */
- def getRemoteBytes(blockId: String): Option[ByteBuffer] = {
- // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
- // refactored.
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- logDebug("Getting remote block " + blockId + " as bytes")
-
- val locations = master.getLocations(blockId)
- for (loc <- locations) {
- logDebug("Getting remote block " + blockId + " from " + loc)
- val data = BlockManagerWorker.syncGetBlock(
- GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
- if (data != null) {
- return Some(data)
- }
- logDebug("The value of block " + blockId + " is null")
- }
- logDebug("Block " + blockId + " not found")
- return None
- }
-
- /**
* Get a block from the block manager (either local or remote).
*/
- def get(blockId: String): Option[Iterator[Any]] = {
+ def get(blockId: BlockId): Option[Iterator[Any]] = {
val local = getLocal(blockId)
if (local.isDefined) {
logInfo("Found block %s locally".format(blockId))
@@ -544,7 +490,7 @@ private[spark] class BlockManager(
* so that we can control the maxMegabytesInFlight for the fetch.
*/
def getMultiple(
- blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])], serializer: Serializer)
+ blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer)
: BlockFetcherIterator = {
val iter =
@@ -558,7 +504,7 @@ private[spark] class BlockManager(
iter
}
- def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
+ def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
: Long = {
val elements = new ArrayBuffer[Any]
elements ++= values
@@ -567,16 +513,20 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
+ * The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
- def getDiskBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
+ def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
- val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
+ val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
+ val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
+ val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
writer.registerCloseEventHandler(() => {
+ diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
blockInfo.put(blockId, myInfo)
- myInfo.markReady(writer.size())
+ myInfo.markReady(writer.fileSegment().length)
})
writer
}
@@ -584,18 +534,25 @@ private[spark] class BlockManager(
/**
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
- def put(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
- tellMaster: Boolean = true) : Long = {
+ def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
+ tellMaster: Boolean = true) : Long = {
+ require(values != null, "Values is null")
+ doPut(blockId, Left(values), level, tellMaster)
+ }
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (values == null) {
- throw new IllegalArgumentException("Values is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
+ /**
+ * Put a new block of serialized bytes to the block manager.
+ */
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
+ tellMaster: Boolean = true) {
+ require(bytes != null, "Bytes is null")
+ doPut(blockId, Right(bytes), level, tellMaster)
+ }
+
+ private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+ level: StorageLevel, tellMaster: Boolean = true): Long = {
+ require(blockId != null, "BlockId is null")
+ require(level != null && level.isValid, "StorageLevel is null or invalid")
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
@@ -611,7 +568,8 @@ private[spark] class BlockManager(
return oldBlockOpt.get.size
}
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
+ // TODO: So the block info exists - but previous attempt to load it (?) failed.
+ // What do we do now ? Retry on it ?
oldBlockOpt.get
} else {
tinfo
@@ -620,10 +578,10 @@ private[spark] class BlockManager(
val startTimeMs = System.currentTimeMillis
- // If we need to replicate the data, we'll want access to the values, but because our
- // put will read the whole iterator, there will be no values left. For the case where
- // the put serializes data, we'll remember the bytes, above; but for the case where it
- // doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
+ // If we're storing values and we need to replicate the data, we'll want access to the values,
+ // but because our put will read the whole iterator, there will be no values left. For the
+ // case where the put serializes data, we'll remember the bytes, above; but for the case where
+ // it doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
// Ditto for the bytes after the put
@@ -632,30 +590,51 @@ private[spark] class BlockManager(
// Size of the block in bytes (to return to caller)
var size = 0L
+ // If we're storing bytes, then initiate the replication before storing them locally.
+ // This is faster as data is already serialized and ready to send.
+ val replicationFuture = if (data.isRight && level.replication > 1) {
+ val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
+ Future {
+ replicate(blockId, bufferView, level)
+ }
+ } else {
+ null
+ }
+
myInfo.synchronized {
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
- if (level.useMemory) {
- // Save it just to memory first, even if it also has useDisk set to true; we will later
- // drop it to disk if the memory store can't hold it.
- val res = memoryStore.putValues(blockId, values, level, true)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case Left(newIterator) => valuesAfterPut = newIterator
+ data match {
+ case Left(values) => {
+ if (level.useMemory) {
+ // Save it just to memory first, even if it also has useDisk set to true; we will
+ // drop it to disk later if the memory store can't hold it.
+ val res = memoryStore.putValues(blockId, values, level, true)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case Left(newIterator) => valuesAfterPut = newIterator
+ }
+ } else {
+ // Save directly to disk.
+ // Don't get back the bytes unless we replicate them.
+ val askForBytes = level.replication > 1
+ val res = diskStore.putValues(blockId, values, level, askForBytes)
+ size = res.size
+ res.data match {
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
+ }
}
- } else {
- // Save directly to disk.
- // Don't get back the bytes unless we replicate them.
- val askForBytes = level.replication > 1
- val res = diskStore.putValues(blockId, values, level, askForBytes)
- size = res.size
- res.data match {
- case Right(newBytes) => bytesAfterPut = newBytes
- case _ =>
+ case Right(bytes) => {
+ bytes.rewind()
+ // Store it only in memory at first, even if useDisk is also set to true
+ (if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
+ size = bytes.limit
}
}
@@ -680,132 +659,46 @@ private[spark] class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
- // Replicate block if required
+ // Either we're storing bytes and we asynchronously started replication, or we're storing
+ // values and need to serialize and replicate them now:
if (level.replication > 1) {
- val remoteStartTime = System.currentTimeMillis
- // Serialize the block if not already done
- if (bytesAfterPut == null) {
- if (valuesAfterPut == null) {
- throw new SparkException(
- "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
- }
- bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
- }
- replicate(blockId, bytesAfterPut, level)
- logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
- }
- BlockManager.dispose(bytesAfterPut)
-
- return size
- }
-
-
- /**
- * Put a new block of serialized bytes to the block manager.
- */
- def putBytes(
- blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
-
- if (blockId == null) {
- throw new IllegalArgumentException("Block Id is null")
- }
- if (bytes == null) {
- throw new IllegalArgumentException("Bytes is null")
- }
- if (level == null || !level.isValid) {
- throw new IllegalArgumentException("Storage level is null or invalid")
- }
-
- // Remember the block's storage level so that we can correctly drop it to disk if it needs
- // to be dropped right after it got put into memory. Note, however, that other threads will
- // not be able to get() this block until we call markReady on its BlockInfo.
- val myInfo = {
- val tinfo = new BlockInfo(level, tellMaster)
- // Do atomically !
- val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
-
- if (oldBlockOpt.isDefined) {
- if (oldBlockOpt.get.waitForReady()) {
- logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
- return
- }
-
- // TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
- oldBlockOpt.get
- } else {
- tinfo
- }
- }
-
- val startTimeMs = System.currentTimeMillis
-
- // Initiate the replication before storing it locally. This is faster as
- // data is already serialized and ready for sending
- val replicationFuture = if (level.replication > 1) {
- val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
- Future {
- replicate(blockId, bufferView, level)
- }
- } else {
- null
- }
-
- myInfo.synchronized {
- logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
- + " to get into synchronized block")
-
- var marked = false
- try {
- if (level.useMemory) {
- // Store it only in memory at first, even if useDisk is also set to true
- bytes.rewind()
- memoryStore.putBytes(blockId, bytes, level)
- } else {
- bytes.rewind()
- diskStore.putBytes(blockId, bytes, level)
- }
-
- // assert (0 == bytes.position(), "" + bytes)
-
- // Now that the block is in either the memory or disk store, let other threads read it,
- // and tell the master about it.
- marked = true
- myInfo.markReady(bytes.limit)
- if (tellMaster) {
- reportBlockStatus(blockId, myInfo)
- }
- } finally {
- // If we failed at putting the block to memory/disk, notify other possible readers
- // that it has failed, and then remove it from the block info map.
- if (! marked) {
- // Note that the remove must happen before markFailure otherwise another thread
- // could've inserted a new BlockInfo before we remove it.
- blockInfo.remove(blockId)
- myInfo.markFailure()
- logWarning("Putting block " + blockId + " failed")
+ data match {
+ case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
+ case Left(values) => {
+ val remoteStartTime = System.currentTimeMillis
+ // Serialize the block if not already done
+ if (bytesAfterPut == null) {
+ if (valuesAfterPut == null) {
+ throw new SparkException(
+ "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ }
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ }
+ replicate(blockId, bytesAfterPut, level)
+ logDebug("Put block " + blockId + " remotely took " +
+ Utils.getUsedTimeMs(remoteStartTime))
}
}
}
- // If replication had started, then wait for it to finish
- if (level.replication > 1) {
- Await.ready(replicationFuture, Duration.Inf)
- }
+ BlockManager.dispose(bytesAfterPut)
if (level.replication > 1) {
- logDebug("PutBytes for block " + blockId + " with replication took " +
+ logDebug("Put for block " + blockId + " with replication took " +
Utils.getUsedTimeMs(startTimeMs))
} else {
- logDebug("PutBytes for block " + blockId + " without replication took " +
+ logDebug("Put for block " + blockId + " without replication took " +
Utils.getUsedTimeMs(startTimeMs))
}
+
+ size
}
/**
* Replicate block to another node.
*/
var cachedPeers: Seq[BlockManagerId] = null
- private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) {
val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
if (cachedPeers == null) {
cachedPeers = master.getPeers(blockManagerId, level.replication - 1)
@@ -828,14 +721,14 @@ private[spark] class BlockManager(
/**
* Read a block consisting of a single object.
*/
- def getSingle(blockId: String): Option[Any] = {
+ def getSingle(blockId: BlockId): Option[Any] = {
get(blockId).map(_.next())
}
/**
* Write a block consisting of a single object.
*/
- def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
+ def putSingle(blockId: BlockId, value: Any, level: StorageLevel, tellMaster: Boolean = true) {
put(blockId, Iterator(value), level, tellMaster)
}
@@ -843,7 +736,7 @@ private[spark] class BlockManager(
* Drop a block from memory, possibly putting it on disk if applicable. Called when the memory
* store reaches its limit and needs to free up space.
*/
- def dropFromMemory(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]) {
+ def dropFromMemory(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer]) {
logInfo("Dropping block " + blockId + " from memory")
val info = blockInfo.get(blockId).orNull
if (info != null) {
@@ -892,16 +785,15 @@ private[spark] class BlockManager(
// TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps
// from RDD.id to blocks.
logInfo("Removing RDD " + rddId)
- val rddPrefix = "rdd_" + rddId + "_"
- val blocksToRemove = blockInfo.filter(_._1.startsWith(rddPrefix)).map(_._1)
- blocksToRemove.foreach(blockId => removeBlock(blockId, false))
+ val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
+ blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false))
blocksToRemove.size
}
/**
* Remove a block from both memory and disk.
*/
- def removeBlock(blockId: String, tellMaster: Boolean = true) {
+ def removeBlock(blockId: BlockId, tellMaster: Boolean = true) {
logInfo("Removing block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) info.synchronized {
@@ -924,34 +816,20 @@ private[spark] class BlockManager(
private def dropOldNonBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping non broadcast blocks older than " + cleanupTime)
- val iterator = blockInfo.internalMap.entrySet().iterator()
- while (iterator.hasNext) {
- val entry = iterator.next()
- val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime && ! BlockManager.isBroadcastBlock(id) ) {
- info.synchronized {
- val level = info.level
- if (level.useMemory) {
- memoryStore.remove(id)
- }
- if (level.useDisk) {
- diskStore.remove(id)
- }
- iterator.remove()
- logInfo("Dropped block " + id)
- }
- reportBlockStatus(id, info)
- }
- }
+ dropOldBlocks(cleanupTime, !_.isBroadcast)
}
private def dropOldBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping broadcast blocks older than " + cleanupTime)
+ dropOldBlocks(cleanupTime, _.isBroadcast)
+ }
+
+ private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
- if (time < cleanupTime && BlockManager.isBroadcastBlock(id) ) {
+ if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
if (level.useMemory) {
@@ -968,39 +846,45 @@ private[spark] class BlockManager(
}
}
- def shouldCompress(blockId: String): Boolean = {
- if (ShuffleBlockManager.isShuffle(blockId)) {
- compressShuffle
- } else if (BlockManager.isBroadcastBlock(blockId)) {
- compressBroadcast
- } else if (blockId.startsWith("rdd_")) {
- compressRdds
- } else {
- false // Won't happen in a real cluster, but it can in tests
- }
+ def shouldCompress(blockId: BlockId): Boolean = blockId match {
+ case ShuffleBlockId(_, _, _) => compressShuffle
+ case BroadcastBlockId(_) => compressBroadcast
+ case RDDBlockId(_, _) => compressRdds
+ case _ => false
}
/**
* Wrap an output stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: OutputStream): OutputStream = {
+ def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s
}
/**
* Wrap an input stream for compression if block compression is enabled for its block type
*/
- def wrapForCompression(blockId: String, s: InputStream): InputStream = {
+ def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = {
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
+ /** Serializes into a stream. */
+ def dataSerializeStream(
+ blockId: BlockId,
+ outputStream: OutputStream,
+ values: Iterator[Any],
+ serializer: Serializer = defaultSerializer) {
+ val byteStream = new FastBufferedOutputStream(outputStream)
+ val ser = serializer.newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ }
+
+ /** Serializes into a byte buffer. */
def dataSerialize(
- blockId: String,
+ blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
- val ser = serializer.newInstance()
- ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
+ dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
@@ -1010,7 +894,7 @@ private[spark] class BlockManager(
* the iterator is reached.
*/
def dataDeserialize(
- blockId: String,
+ blockId: BlockId,
bytes: ByteBuffer,
serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
@@ -1065,10 +949,10 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToBlockManagers(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[BlockManagerId]] =
+ : Map[BlockId, Seq[BlockManagerId]] =
{
// env == null and blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
@@ -1078,7 +962,7 @@ private[spark] object BlockManager extends Logging {
blockManagerMaster.getLocations(blockIds)
}
- val blockManagers = new HashMap[String, Seq[BlockManagerId]]
+ val blockManagers = new HashMap[BlockId, Seq[BlockManagerId]]
for (i <- 0 until blockIds.length) {
blockManagers(blockIds(i)) = blockLocations(i)
}
@@ -1086,25 +970,21 @@ private[spark] object BlockManager extends Logging {
}
def blockIdsToExecutorIds(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
}
def blockIdsToHosts(
- blockIds: Array[String],
+ blockIds: Array[BlockId],
env: SparkEnv,
blockManagerMaster: BlockManagerMaster = null)
- : Map[String, Seq[String]] =
+ : Map[BlockId, Seq[String]] =
{
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
-
- def isBroadcastBlock(blockId: String): Boolean = null != blockId && blockId.startsWith("broadcast_")
-
- def toBroadcastId(id: Long): String = "broadcast_" + id
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index cf463d6ffc..94038649b3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -60,7 +60,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): Boolean = {
@@ -71,12 +71,12 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
}
/** Get locations of the blockId from the driver */
- def getLocations(blockId: String): Seq[BlockManagerId] = {
+ def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
}
/** Get locations of multiple blockIds from the driver */
- def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
}
@@ -94,7 +94,7 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
*/
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
askDriverWithReply(RemoveBlock(blockId))
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index c7b23ab094..f8cf14b503 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -48,7 +48,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId]
// Mapping from block id to the set of block managers that have the block.
- private val blockLocations = new JHashMap[String, mutable.HashSet[BlockManagerId]]
+ private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
val akkaTimeout = Duration.create(
System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
@@ -129,10 +129,9 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// First remove the metadata for the given RDD, and then asynchronously remove the blocks
// from the slaves.
- val prefix = "rdd_" + rddId + "_"
// Find all blocks for the given RDD, remove the block from both blockLocations and
// the blockManagerInfo that is tracking the blocks.
- val blocks = blockLocations.keySet().filter(_.startsWith(prefix))
+ val blocks = blockLocations.keys.flatMap(_.asRDDId).filter(_.rddId == rddId)
blocks.foreach { blockId =>
val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId)
bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId)))
@@ -198,7 +197,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- private def removeBlockFromWorkers(blockId: String) {
+ private def removeBlockFromWorkers(blockId: BlockId) {
val locations = blockLocations.get(blockId)
if (locations != null) {
locations.foreach { blockManagerId: BlockManagerId =>
@@ -228,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
- if (id.executorId == "<driver>" && !isLocal) {
- // Got a register message from the master node; don't register it
- } else if (!blockManagerInfo.contains(id)) {
+ if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.
@@ -247,7 +244,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
private def updateBlockInfo(
blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long) {
@@ -292,11 +289,11 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
sender ! true
}
- private def getLocations(blockId: String): Seq[BlockManagerId] = {
+ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty
}
- private def getLocationsMultipleBlockIds(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+ private def getLocationsMultipleBlockIds(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
blockIds.map(blockId => getLocations(blockId))
}
@@ -330,7 +327,7 @@ object BlockManagerMasterActor {
private var _remainingMem: Long = maxMem
// Mapping from block id to its status.
- private val _blocks = new JHashMap[String, BlockStatus]
+ private val _blocks = new JHashMap[BlockId, BlockStatus]
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.bytesToString(maxMem)))
@@ -339,7 +336,7 @@ object BlockManagerMasterActor {
_lastSeenMs = System.currentTimeMillis()
}
- def updateBlockInfo(blockId: String, storageLevel: StorageLevel, memSize: Long,
+ def updateBlockInfo(blockId: BlockId, storageLevel: StorageLevel, memSize: Long,
diskSize: Long) {
updateLastSeenMs()
@@ -383,7 +380,7 @@ object BlockManagerMasterActor {
}
}
- def removeBlock(blockId: String) {
+ def removeBlock(blockId: BlockId) {
if (_blocks.containsKey(blockId)) {
_remainingMem += _blocks.get(blockId).memSize
_blocks.remove(blockId)
@@ -394,7 +391,7 @@ object BlockManagerMasterActor {
def lastSeenMs: Long = _lastSeenMs
- def blocks: JHashMap[String, BlockStatus] = _blocks
+ def blocks: JHashMap[BlockId, BlockStatus] = _blocks
override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 24333a179c..45f51da288 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -30,7 +30,7 @@ private[storage] object BlockManagerMessages {
// Remove a block from the slaves that have it. This can only be used to remove
// blocks that the master knows about.
- case class RemoveBlock(blockId: String) extends ToBlockManagerSlave
+ case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave
// Remove all blocks belonging to a specific RDD.
case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
@@ -51,7 +51,7 @@ private[storage] object BlockManagerMessages {
class UpdateBlockInfo(
var blockManagerId: BlockManagerId,
- var blockId: String,
+ var blockId: BlockId,
var storageLevel: StorageLevel,
var memSize: Long,
var diskSize: Long)
@@ -62,7 +62,7 @@ private[storage] object BlockManagerMessages {
override def writeExternal(out: ObjectOutput) {
blockManagerId.writeExternal(out)
- out.writeUTF(blockId)
+ out.writeUTF(blockId.name)
storageLevel.writeExternal(out)
out.writeLong(memSize)
out.writeLong(diskSize)
@@ -70,7 +70,7 @@ private[storage] object BlockManagerMessages {
override def readExternal(in: ObjectInput) {
blockManagerId = BlockManagerId(in)
- blockId = in.readUTF()
+ blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
memSize = in.readLong()
diskSize = in.readLong()
@@ -79,7 +79,7 @@ private[storage] object BlockManagerMessages {
object UpdateBlockInfo {
def apply(blockManagerId: BlockManagerId,
- blockId: String,
+ blockId: BlockId,
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long): UpdateBlockInfo = {
@@ -87,14 +87,14 @@ private[storage] object BlockManagerMessages {
}
// For pattern-matching
- def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = {
+ def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = {
Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize))
}
}
- case class GetLocations(blockId: String) extends ToBlockManagerMaster
+ case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster
- case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster
+ case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster
case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index 951503019f..3a65e55733 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._
* An actor to take commands from the master to execute options. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
+private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
index 678c38203c..0c66addf9d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala
@@ -77,7 +77,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
}
}
- private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes)
blockManager.putBytes(id, bytes, level)
@@ -85,7 +85,7 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
+ " with data size: " + bytes.limit)
}
- private def getBlock(id: String): ByteBuffer = {
+ private def getBlock(id: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + id + " started from " + startTimeMs)
val buffer = blockManager.getLocalBytes(id) match {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
index d8fa6a91d1..80dcb5a207 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.network._
-private[spark] case class GetBlock(id: String)
-private[spark] case class GotBlock(id: String, data: ByteBuffer)
-private[spark] case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel)
+private[spark] case class GetBlock(id: BlockId)
+private[spark] case class GotBlock(id: BlockId, data: ByteBuffer)
+private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel)
private[spark] class BlockMessage() {
// Un-initialized: typ = 0
@@ -34,7 +34,7 @@ private[spark] class BlockMessage() {
// GotBlock: typ = 2
// PutBlock: typ = 3
private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED
- private var id: String = null
+ private var id: BlockId = null
private var data: ByteBuffer = null
private var level: StorageLevel = null
@@ -74,7 +74,7 @@ private[spark] class BlockMessage() {
for (i <- 1 to idLength) {
idBuilder += buffer.getChar()
}
- id = idBuilder.toString()
+ id = BlockId(idBuilder.toString)
if (typ == BlockMessage.TYPE_PUT_BLOCK) {
@@ -109,28 +109,17 @@ private[spark] class BlockMessage() {
set(buffer)
}
- def getType: Int = {
- return typ
- }
-
- def getId: String = {
- return id
- }
-
- def getData: ByteBuffer = {
- return data
- }
-
- def getLevel: StorageLevel = {
- return level
- }
+ def getType: Int = typ
+ def getId: BlockId = id
+ def getData: ByteBuffer = data
+ def getLevel: StorageLevel = level
def toBufferMessage: BufferMessage = {
val startTime = System.currentTimeMillis
val buffers = new ArrayBuffer[ByteBuffer]()
- var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2)
- buffer.putInt(typ).putInt(id.length())
- id.foreach((x: Char) => buffer.putChar(x))
+ var buffer = ByteBuffer.allocate(4 + 4 + id.name.length * 2)
+ buffer.putInt(typ).putInt(id.name.length)
+ id.name.foreach((x: Char) => buffer.putChar(x))
buffer.flip()
buffers += buffer
@@ -212,7 +201,8 @@ private[spark] object BlockMessage {
def main(args: Array[String]) {
val B = new BlockMessage()
- B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
+ val blockId = TestBlockId("ABC")
+ B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2))
val bMsg = B.toBufferMessage
val C = new BlockMessage()
C.set(bMsg)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
index 0aaf846b5b..6ce9127c74 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala
@@ -111,14 +111,15 @@ private[spark] object BlockMessageArray {
}
def main(args: Array[String]) {
- val blockMessages =
+ val blockMessages =
(0 until 10).map { i =>
if (i % 2 == 0) {
val buffer = ByteBuffer.allocate(100)
buffer.clear
- BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY_SER))
+ BlockMessage.fromPutBlock(PutBlock(TestBlockId(i.toString), buffer,
+ StorageLevel.MEMORY_ONLY_SER))
} else {
- BlockMessage.fromGetBlock(GetBlock(i.toString))
+ BlockMessage.fromGetBlock(GetBlock(TestBlockId(i.toString)))
}
}
val blockMessageArray = new BlockMessageArray(blockMessages)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 39f103297f..32d2dd0694 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -17,6 +17,13 @@
package org.apache.spark.storage
+import java.io.{FileOutputStream, File, OutputStream}
+import java.nio.channels.FileChannel
+
+import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
+
+import org.apache.spark.Logging
+import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@@ -25,7 +32,7 @@ package org.apache.spark.storage
*
* This interface does not support concurrent writes.
*/
-abstract class BlockObjectWriter(val blockId: String) {
+abstract class BlockObjectWriter(val blockId: BlockId) {
var closeEventHandler: () => Unit = _
@@ -59,7 +66,129 @@ abstract class BlockObjectWriter(val blockId: String) {
def write(value: Any)
/**
- * Size of the valid writes, in bytes.
+ * Returns the file segment of committed data that this Writer has written.
+ */
+ def fileSegment(): FileSegment
+
+ /**
+ * Cumulative time spent performing blocking writes, in ns.
*/
- def size(): Long
+ def timeWriting(): Long
+}
+
+/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+class DiskBlockObjectWriter(
+ blockId: BlockId,
+ file: File,
+ serializer: Serializer,
+ bufferSize: Int,
+ compressStream: OutputStream => OutputStream)
+ extends BlockObjectWriter(blockId)
+ with Logging
+{
+
+ /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
+ private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
+ def timeWriting = _timeWriting
+ private var _timeWriting = 0L
+
+ private def callWithTiming(f: => Unit) = {
+ val start = System.nanoTime()
+ f
+ _timeWriting += (System.nanoTime() - start)
+ }
+
+ def write(i: Int): Unit = callWithTiming(out.write(i))
+ override def write(b: Array[Byte]) = callWithTiming(out.write(b))
+ override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ }
+
+ private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
+
+ /** The file channel, used for repositioning / truncating the file. */
+ private var channel: FileChannel = null
+ private var bs: OutputStream = null
+ private var fos: FileOutputStream = null
+ private var ts: TimeTrackingOutputStream = null
+ private var objOut: SerializationStream = null
+ private var initialPosition = 0L
+ private var lastValidPosition = 0L
+ private var initialized = false
+ private var _timeWriting = 0L
+
+ override def open(): BlockObjectWriter = {
+ fos = new FileOutputStream(file, true)
+ ts = new TimeTrackingOutputStream(fos)
+ channel = fos.getChannel()
+ initialPosition = channel.position
+ lastValidPosition = initialPosition
+ bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
+ objOut = serializer.newInstance().serializeStream(bs)
+ initialized = true
+ this
+ }
+
+ override def close() {
+ if (initialized) {
+ if (syncWrites) {
+ // Force outstanding writes to disk and track how long it takes
+ objOut.flush()
+ val start = System.nanoTime()
+ fos.getFD.sync()
+ _timeWriting += System.nanoTime() - start
+ }
+ objOut.close()
+
+ _timeWriting += ts.timeWriting
+
+ channel = null
+ bs = null
+ fos = null
+ ts = null
+ objOut = null
+ }
+ // Invoke the close callback handler.
+ super.close()
+ }
+
+ override def isOpen: Boolean = objOut != null
+
+ override def commit(): Long = {
+ if (initialized) {
+ // NOTE: Flush the serializer first and then the compressed/buffered output stream
+ objOut.flush()
+ bs.flush()
+ val prevPos = lastValidPosition
+ lastValidPosition = channel.position()
+ lastValidPosition - prevPos
+ } else {
+ // lastValidPosition is zero if stream is uninitialized
+ lastValidPosition
+ }
+ }
+
+ override def revertPartialWrites() {
+ if (initialized) {
+ // Discard current writes. We do this by flushing the outstanding writes and
+ // truncate the file to the last valid position.
+ objOut.flush()
+ bs.flush()
+ channel.truncate(lastValidPosition)
+ }
+ }
+
+ override def write(value: Any) {
+ if (!initialized) {
+ open()
+ }
+ objOut.writeObject(value)
+ }
+
+ override def fileSegment(): FileSegment = {
+ val bytesWritten = lastValidPosition - initialPosition
+ new FileSegment(file, initialPosition, bytesWritten)
+ }
+
+ // Only valid if called after close()
+ override def timeWriting() = _timeWriting
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index fa834371f4..ea42656240 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -27,7 +27,7 @@ import org.apache.spark.Logging
*/
private[spark]
abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel)
+ def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel)
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -36,26 +36,26 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
* @return a PutResult that contains the size of the data, as well as the values put if
* returnValues is true (if not, the result's data field can be null)
*/
- def putValues(blockId: String, values: ArrayBuffer[Any], level: StorageLevel,
+ def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
returnValues: Boolean) : PutResult
/**
* Return the size of a block in bytes.
*/
- def getSize(blockId: String): Long
+ def getSize(blockId: BlockId): Long
- def getBytes(blockId: String): Option[ByteBuffer]
+ def getBytes(blockId: BlockId): Option[ByteBuffer]
- def getValues(blockId: String): Option[Iterator[Any]]
+ def getValues(blockId: BlockId): Option[Iterator[Any]]
/**
* Remove a block, if it exists.
* @param blockId the block to remove.
* @return True if the block was found and removed, False otherwise.
*/
- def remove(blockId: String): Boolean
+ def remove(blockId: BlockId): Boolean
- def contains(blockId: String): Boolean
+ def contains(blockId: BlockId): Boolean
def clear() { }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
new file mode 100644
index 0000000000..bcb58ad946
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+import java.text.SimpleDateFormat
+import java.util.{Date, Random}
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.Logging
+import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+
+/**
+ * Creates and maintains the logical mapping between logical blocks and physical on-disk
+ * locations. By default, one block is mapped to one file with a name given by its BlockId.
+ * However, it is also possible to have a block map to only a segment of a file, by calling
+ * mapBlockToFileSegment().
+ *
+ * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
+ */
+private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+
+ private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
+ private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
+
+ // Create one local directory for each path mentioned in spark.local.dir; then, inside this
+ // directory, create multiple subdirectories that we will hash files into, in order to avoid
+ // having really large inodes at the top level.
+ private val localDirs: Array[File] = createLocalDirs()
+ private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ private var shuffleSender : ShuffleSender = null
+
+ // Stores only Blocks which have been specifically mapped to segments of files
+ // (rather than the default, which maps a Block to a whole file).
+ // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
+ private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
+
+ val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
+
+ addShutdownHook()
+
+ /**
+ * Creates a logical mapping from the given BlockId to a segment of a file.
+ * This will cause any accesses of the logical BlockId to be directed to the specified
+ * physical location.
+ */
+ def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
+ blockToFileSegmentMap.put(blockId, fileSegment)
+ }
+
+ /**
+ * Returns the phyiscal file segment in which the given BlockId is located.
+ * If the BlockId has been mapped to a specific FileSegment, that will be returned.
+ * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+ */
+ def getBlockLocation(blockId: BlockId): FileSegment = {
+ if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
+ blockToFileSegmentMap.get(blockId).get
+ } else {
+ val file = getFile(blockId.name)
+ new FileSegment(file, 0, file.length())
+ }
+ }
+
+ /**
+ * Simply returns a File to place the given Block into. This does not physically create the file.
+ * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
+ * a unique filename.
+ */
+ def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
+ val actualFilename = if (filename == "") blockId.name else filename
+ val file = getFile(actualFilename)
+ if (!allowAppending && file.exists()) {
+ throw new IllegalStateException(
+ "Attempted to create file that already exists: " + actualFilename)
+ }
+ file
+ }
+
+ private def getFile(filename: String): File = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % localDirs.length
+ val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+
+ // Create the subdirectory if it doesn't already exist
+ var subDir = subDirs(dirId)(subDirId)
+ if (subDir == null) {
+ subDir = subDirs(dirId).synchronized {
+ val old = subDirs(dirId)(subDirId)
+ if (old != null) {
+ old
+ } else {
+ val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
+ newDir.mkdir()
+ subDirs(dirId)(subDirId) = newDir
+ newDir
+ }
+ }
+ }
+
+ new File(subDir, filename)
+ }
+
+ private def createLocalDirs(): Array[File] = {
+ logDebug("Creating local directories at root dirs '" + rootDirs + "'")
+ val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
+ rootDirs.split(",").map { rootDir =>
+ var foundLocalDir = false
+ var localDir: File = null
+ var localDirId: String = null
+ var tries = 0
+ val rand = new Random()
+ while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
+ tries += 1
+ try {
+ localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
+ localDir = new File(rootDir, "spark-local-" + localDirId)
+ if (!localDir.exists) {
+ foundLocalDir = localDir.mkdirs()
+ }
+ } catch {
+ case e: Exception =>
+ logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
+ }
+ }
+ if (!foundLocalDir) {
+ logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
+ " attempts to create local dir in " + rootDir)
+ System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
+ }
+ logInfo("Created local directory at " + localDir)
+ localDir
+ }
+ }
+
+ private def cleanup(cleanupTime: Long) {
+ blockToFileSegmentMap.clearOldValues(cleanupTime)
+ }
+
+ private def addShutdownHook() {
+ localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
+ Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ override def run() {
+ logDebug("Shutdown hook called")
+ localDirs.foreach { localDir =>
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case t: Throwable =>
+ logError("Exception while deleting local spark dir: " + localDir, t)
+ }
+ }
+
+ if (shuffleSender != null) {
+ shuffleSender.stop()
+ }
+ }
+ })
+ }
+
+ private[storage] def startShuffleBlockSender(port: Int): Int = {
+ shuffleSender = new ShuffleSender(port, this)
+ logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
+ shuffleSender.port
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 63447baf8c..a3c496f9e0 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,153 +17,46 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
+import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
-import java.util.{Random, Date}
-import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
-import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
-
-import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
-import org.apache.spark.network.netty.ShuffleSender
-import org.apache.spark.network.netty.PathResolver
+import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
-private class DiskStore(blockManager: BlockManager, rootDirs: String)
+private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- class DiskBlockObjectWriter(blockId: String, serializer: Serializer, bufferSize: Int)
- extends BlockObjectWriter(blockId) {
-
- private val f: File = createFile(blockId /*, allowAppendExisting */)
-
- // The file channel, used for repositioning / truncating the file.
- private var channel: FileChannel = null
- private var bs: OutputStream = null
- private var objOut: SerializationStream = null
- private var lastValidPosition = 0L
- private var initialized = false
-
- override def open(): DiskBlockObjectWriter = {
- val fos = new FileOutputStream(f, true)
- channel = fos.getChannel()
- bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
- objOut = serializer.newInstance().serializeStream(bs)
- initialized = true
- this
- }
-
- override def close() {
- if (initialized) {
- objOut.close()
- channel = null
- bs = null
- objOut = null
- }
- // Invoke the close callback handler.
- super.close()
- }
-
- override def isOpen: Boolean = objOut != null
-
- // Flush the partial writes, and set valid length to be the length of the entire file.
- // Return the number of bytes written for this commit.
- override def commit(): Long = {
- if (initialized) {
- // NOTE: Flush the serializer first and then the compressed/buffered output stream
- objOut.flush()
- bs.flush()
- val prevPos = lastValidPosition
- lastValidPosition = channel.position()
- lastValidPosition - prevPos
- } else {
- // lastValidPosition is zero if stream is uninitialized
- lastValidPosition
- }
- }
-
- override def revertPartialWrites() {
- if (initialized) {
- // Discard current writes. We do this by flushing the outstanding writes and
- // truncate the file to the last valid position.
- objOut.flush()
- bs.flush()
- channel.truncate(lastValidPosition)
- }
- }
-
- override def write(value: Any) {
- if (!initialized) {
- open()
- }
- objOut.writeObject(value)
- }
-
- override def size(): Long = lastValidPosition
- }
-
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
-
- private var shuffleSender : ShuffleSender = null
- // Create one local directory for each path mentioned in spark.local.dir; then, inside this
- // directory, create multiple subdirectories that we will hash files into, in order to avoid
- // having really large inodes at the top level.
- private val localDirs: Array[File] = createLocalDirs()
- private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
-
- addShutdownHook()
-
- def getBlockWriter(blockId: String, serializer: Serializer, bufferSize: Int)
- : BlockObjectWriter = {
- new DiskBlockObjectWriter(blockId, serializer, bufferSize)
+ override def getSize(blockId: BlockId): Long = {
+ diskManager.getBlockLocation(blockId).length
}
- override def getSize(blockId: String): Long = {
- getFile(blockId).length()
- }
-
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val channel = new RandomAccessFile(file, "rw").getChannel()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
- }
-
- private def getFileBytes(file: File): ByteBuffer = {
- val length = file.length()
- val channel = new RandomAccessFile(file, "r").getChannel()
- val buffer = try {
- channel.map(MapMode.READ_ONLY, 0, length)
- } finally {
- channel.close()
- }
-
- buffer
+ file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -171,159 +64,62 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
- val file = createFile(blockId)
- val fileOut = blockManager.wrapForCompression(blockId,
- new FastBufferedOutputStream(new FileOutputStream(file)))
- val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
- objOut.writeAll(values.iterator)
- objOut.close()
- val length = file.length()
+ val file = diskManager.createBlockFile(blockId, allowAppending = false)
+ val outputStream = new FileOutputStream(file)
+ blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+ val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
- blockId, Utils.bytesToString(length), timeTaken))
+ file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
- val buffer = getFileBytes(file)
+ val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
- val file = getFile(blockId)
- val bytes = getFileBytes(file)
- Some(bytes)
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ val segment = diskManager.getBlockLocation(blockId)
+ val channel = new RandomAccessFile(segment.file, "r").getChannel()
+ val buffer = try {
+ channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
+ } finally {
+ channel.close()
+ }
+ Some(buffer)
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
- getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
+ getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
* A version of getValues that allows a custom serializer. This is used as part of the
* shuffle short-circuit code.
*/
- def getValues(blockId: String, serializer: Serializer): Option[Iterator[Any]] = {
+ def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
}
- override def remove(blockId: String): Boolean = {
- val file = getFile(blockId)
- if (file.exists()) {
+ override def remove(blockId: BlockId): Boolean = {
+ val fileSegment = diskManager.getBlockLocation(blockId)
+ val file = fileSegment.file
+ if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
- false
- }
- }
-
- override def contains(blockId: String): Boolean = {
- getFile(blockId).exists()
- }
-
- private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
- val file = getFile(blockId)
- if (!allowAppendExisting && file.exists()) {
- // NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
- // was rescheduled on the same machine as the old task.
- logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
- file.delete()
- }
- file
- }
-
- private def getFile(blockId: String): File = {
- logDebug("Getting file for block " + blockId)
-
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(blockId)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
-
- // Create the subdirectory if it doesn't already exist
- var subDir = subDirs(dirId)(subDirId)
- if (subDir == null) {
- subDir = subDirs(dirId).synchronized {
- val old = subDirs(dirId)(subDirId)
- if (old != null) {
- old
- } else {
- val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
- subDirs(dirId)(subDirId) = newDir
- newDir
- }
- }
- }
-
- new File(subDir, blockId)
- }
-
- private def createLocalDirs(): Array[File] = {
- logDebug("Creating local directories at root dirs '" + rootDirs + "'")
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
- rootDirs.split(",").map { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, "spark-local-" + localDirId)
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
- }
+ if (fileSegment.length < file.length()) {
+ logWarning("Could not delete block associated with only a part of a file: " + blockId)
}
- if (!foundLocalDir) {
- logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
- " attempts to create local dir in " + rootDir)
- System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
- }
- logInfo("Created local directory at " + localDir)
- localDir
+ false
}
}
- private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
- override def run() {
- logDebug("Shutdown hook called")
- localDirs.foreach { localDir =>
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case t: Throwable =>
- logError("Exception while deleting local spark dir: " + localDir, t)
- }
- }
- if (shuffleSender != null) {
- shuffleSender.stop
- }
- }
- })
- }
-
- private[storage] def startShuffleBlockSender(port: Int): Int = {
- val pResolver = new PathResolver {
- override def getAbsolutePath(blockId: String): String = {
- if (!blockId.startsWith("shuffle_")) {
- return null
- }
- DiskStore.this.getFile(blockId).getAbsolutePath()
- }
- }
- shuffleSender = new ShuffleSender(port, pResolver)
- logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
- shuffleSender.port
+ override def contains(blockId: BlockId): Boolean = {
+ val file = diskManager.getBlockLocation(blockId).file
+ file.exists()
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
new file mode 100644
index 0000000000..555486830a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.io.File
+
+/**
+ * References a particular segment of a file (potentially the entire file),
+ * based off an offset and a length.
+ */
+private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
+ override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 77a39c71ed..05f676c6e2 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -32,7 +32,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
case class Entry(value: Any, size: Long, deserialized: Boolean)
- private val entries = new LinkedHashMap[String, Entry](32, 0.75f, true)
+ private val entries = new LinkedHashMap[BlockId, Entry](32, 0.75f, true)
@volatile private var currentMemory = 0L
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
// blocks from the memory store.
@@ -42,13 +42,13 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
def freeMemory: Long = maxMemory - currentMemory
- override def getSize(blockId: String): Long = {
+ override def getSize(blockId: BlockId): Long = {
entries.synchronized {
entries.get(blockId).size
}
}
- override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
+ override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -64,7 +64,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
override def putValues(
- blockId: String,
+ blockId: BlockId,
values: ArrayBuffer[Any],
level: StorageLevel,
returnValues: Boolean)
@@ -81,7 +81,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getBytes(blockId: String): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -94,7 +94,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getValues(blockId: String): Option[Iterator[Any]] = {
+ override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -108,7 +108,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def remove(blockId: String): Boolean = {
+ override def remove(blockId: BlockId): Boolean = {
entries.synchronized {
val entry = entries.remove(blockId)
if (entry != null) {
@@ -131,14 +131,10 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
/**
- * Return the RDD ID that a given block ID is from, or null if it is not an RDD block.
+ * Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
*/
- private def getRddId(blockId: String): String = {
- if (blockId.startsWith("rdd_")) {
- blockId.split('_')(1)
- } else {
- null
- }
+ private def getRddId(blockId: BlockId): Option[Int] = {
+ blockId.asRDDId.map(_.rddId)
}
/**
@@ -151,7 +147,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* blocks to free memory for one block, another thread may use up the freed space for
* another block.
*/
- private def tryToPut(blockId: String, value: Any, size: Long, deserialized: Boolean): Boolean = {
+ private def tryToPut(blockId: BlockId, value: Any, size: Long, deserialized: Boolean): Boolean = {
// TODO: Its possible to optimize the locking by locking entries only when selecting blocks
// to be dropped. Once the to-be-dropped blocks have been selected, and lock on entries has been
// released, it must be ensured that those to-be-dropped blocks are not double counted for
@@ -195,7 +191,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
* Assumes that a lock is held by the caller to ensure only one thread is dropping blocks.
* Otherwise, the freed space may fill up before the caller puts in their new value.
*/
- private def ensureFreeSpace(blockIdToAdd: String, space: Long): Boolean = {
+ private def ensureFreeSpace(blockIdToAdd: BlockId, space: Long): Boolean = {
logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format(
space, currentMemory, maxMemory))
@@ -207,7 +203,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - currentMemory < space) {
val rddToAdd = getRddId(blockIdToAdd)
- val selectedBlocks = new ArrayBuffer[String]()
+ val selectedBlocks = new ArrayBuffer[BlockId]()
var selectedMemory = 0L
// This is synchronized to ensure that the set of entries is not changed
@@ -218,7 +214,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
while (maxMemory - (currentMemory - selectedMemory) < space && iterator.hasNext) {
val pair = iterator.next()
val blockId = pair.getKey
- if (rddToAdd != null && rddToAdd == getRddId(blockId)) {
+ if (rddToAdd != None && rddToAdd == getRddId(blockId)) {
logInfo("Will not store " + blockIdToAdd + " as it would require dropping another " +
"block from the same RDD")
return false
@@ -252,7 +248,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
return true
}
- override def contains(blockId: String): Boolean = {
+ override def contains(blockId: BlockId): Boolean = {
entries.synchronized { entries.containsKey(blockId) }
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 9da11efb57..229178c095 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,12 +17,13 @@
package org.apache.spark.storage
-import org.apache.spark.serializer.Serializer
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
+import org.apache.spark.serializer.Serializer
private[spark]
-class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
-
+class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
private[spark]
trait ShuffleBlocks {
@@ -30,38 +31,61 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup)
}
-
+/**
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
+ * per reducer.
+ *
+ * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
+ * it releases them for another task.
+ * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
+ * - shuffleId: The unique id given to the entire shuffle stage.
+ * - bucketId: The id of the output partition (i.e., reducer id)
+ * - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
+ * time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
+ // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
+ // TODO: Remove this once the shuffle file consolidation feature is stable.
+ val consolidateShuffleFiles =
+ System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
- def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
+ var nextFileId = new AtomicInteger(0)
+ val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
+
+ def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleBlocks {
// Get a group of writers for a map task.
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+ val fileId = getUnusedFileId()
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
- val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
- blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
+ val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+ val filename = physicalFileName(shuffleId, bucketId, fileId)
+ blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
}
- new ShuffleWriterGroup(mapId, writers)
+ new ShuffleWriterGroup(mapId, fileId, writers)
}
- override def releaseWriters(group: ShuffleWriterGroup) = {
- // Nothing really to release here.
+ override def releaseWriters(group: ShuffleWriterGroup) {
+ recycleFileId(group.fileId)
}
}
}
-}
-
-private[spark]
-object ShuffleBlockManager {
+ private def getUnusedFileId(): Int = {
+ val fileId = unusedFileIds.poll()
+ if (fileId == null) nextFileId.getAndIncrement() else fileId
+ }
- // Returns the block id for a given shuffle block.
- def blockId(shuffleId: Int, bucketId: Int, groupId: Int): String = {
- "shuffle_" + shuffleId + "_" + groupId + "_" + bucketId
+ private def recycleFileId(fileId: Int) {
+ if (!consolidateShuffleFiles) { return } // ensures we always generate new file id
+ unusedFileIds.add(fileId)
}
- // Returns true if the block is a shuffle block.
- def isShuffle(blockId: String): Boolean = blockId.startsWith("shuffle_")
+ private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
+ "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
new file mode 100644
index 0000000000..1b074e5ec7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -0,0 +1,84 @@
+package org.apache.spark.storage
+
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.{CountDownLatch, Executors}
+
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+
+/** Utility for micro-benchmarking shuffle write performance.
+ *
+ * Writes simulated shuffle output from several threads and records the observed throughput*/
+object StoragePerfTester {
+ def main(args: Array[String]) = {
+ /** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
+ val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g"))
+
+ /** Number of map tasks. All tasks execute concurrently. */
+ val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8)
+
+ /** Number of reduce splits for each map task. */
+ val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500)
+
+ val recordLength = 1000 // ~1KB records
+ val totalRecords = dataSizeMb * 1000
+ val recordsPerMap = totalRecords / numMaps
+
+ val writeData = "1" * recordLength
+ val executor = Executors.newFixedThreadPool(numMaps)
+
+ System.setProperty("spark.shuffle.compress", "false")
+ System.setProperty("spark.shuffle.sync", "true")
+
+ // This is only used to instantiate a BlockManager. All thread scheduling is done manually.
+ val sc = new SparkContext("local[4]", "Write Tester")
+ val blockManager = sc.env.blockManager
+
+ def writeOutputBytes(mapId: Int, total: AtomicLong) = {
+ val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
+ new KryoSerializer())
+ val buckets = shuffle.acquireWriters(mapId)
+ for (i <- 1 to recordsPerMap) {
+ buckets.writers(i % numOutputSplits).write(writeData)
+ }
+ buckets.writers.map {w =>
+ w.commit()
+ total.addAndGet(w.fileSegment().length)
+ w.close()
+ }
+
+ shuffle.releaseWriters(buckets)
+ }
+
+ val start = System.currentTimeMillis()
+ val latch = new CountDownLatch(numMaps)
+ val totalBytes = new AtomicLong()
+ for (task <- 1 to numMaps) {
+ executor.submit(new Runnable() {
+ override def run() = {
+ try {
+ writeOutputBytes(task, totalBytes)
+ latch.countDown()
+ } catch {
+ case e: Exception =>
+ println("Exception in child thread: " + e + " " + e.getMessage)
+ System.exit(1)
+ }
+ }
+ })
+ }
+ latch.await()
+ val end = System.currentTimeMillis()
+ val time = (end - start) / 1000.0
+ val bytesPerSecond = totalBytes.get() / time
+ val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong
+
+ System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits))
+ System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile)))
+ System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong)))
+
+ executor.shutdown()
+ sc.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 2bb7715696..1720007e4e 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -23,20 +23,24 @@ import org.apache.spark.util.Utils
private[spark]
case class StorageStatus(blockManagerId: BlockManagerId, maxMem: Long,
- blocks: Map[String, BlockStatus]) {
+ blocks: Map[BlockId, BlockStatus]) {
- def memUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.memSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsed() = blocks.values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
- def diskUsed(blockPrefix: String = "") = {
- blocks.filterKeys(_.startsWith(blockPrefix)).values.map(_.diskSize).
- reduceOption(_+_).getOrElse(0l)
- }
+ def memUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.memSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsed() = blocks.values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
+
+ def diskUsedByRDD(rddId: Int) =
+ rddBlocks.filterKeys(_.rddId == rddId).values.map(_.diskSize).reduceOption(_+_).getOrElse(0L)
def memRemaining : Long = maxMem - memUsed()
+ def rddBlocks = blocks.flatMap {
+ case (rdd: RDDBlockId, status) => Some(rdd, status)
+ case _ => None
+ }
}
case class RDDInfo(id: Int, name: String, storageLevel: StorageLevel,
@@ -60,7 +64,7 @@ object StorageUtils {
/* Returns RDD-level information, compiled from a list of StorageStatus objects */
def rddInfoFromStorageStatus(storageStatusList: Seq[StorageStatus],
sc: SparkContext) : Array[RDDInfo] = {
- rddInfoFromBlockStatusList(storageStatusList.flatMap(_.blocks).toMap, sc)
+ rddInfoFromBlockStatusList(storageStatusList.flatMap(_.rddBlocks).toMap[RDDBlockId, BlockStatus], sc)
}
/* Returns a map of blocks to their locations, compiled from a list of StorageStatus objects */
@@ -71,26 +75,21 @@ object StorageUtils {
}
/* Given a list of BlockStatus objets, returns information for each RDD */
- def rddInfoFromBlockStatusList(infos: Map[String, BlockStatus],
+ def rddInfoFromBlockStatusList(infos: Map[RDDBlockId, BlockStatus],
sc: SparkContext) : Array[RDDInfo] = {
// Group by rddId, ignore the partition name
- val groupedRddBlocks = infos.filterKeys(_.startsWith("rdd_")).groupBy { case(k, v) =>
- k.substring(0,k.lastIndexOf('_'))
- }.mapValues(_.values.toArray)
+ val groupedRddBlocks = infos.groupBy { case(k, v) => k.rddId }.mapValues(_.values.toArray)
// For each RDD, generate an RDDInfo object
- val rddInfos = groupedRddBlocks.map { case (rddKey, rddBlocks) =>
+ val rddInfos = groupedRddBlocks.map { case (rddId, rddBlocks) =>
// Add up memory and disk sizes
val memSize = rddBlocks.map(_.memSize).reduce(_ + _)
val diskSize = rddBlocks.map(_.diskSize).reduce(_ + _)
- // Find the id of the RDD, e.g. rdd_1 => 1
- val rddId = rddKey.split("_").last.toInt
-
// Get the friendly name and storage level for the RDD, if available
sc.persistentRdds.get(rddId).map { r =>
- val rddName = Option(r.name).getOrElse(rddKey)
+ val rddName = Option(r.name).getOrElse(rddId.toString)
val rddStorageLevel = r.getStorageLevel
RDDInfo(rddId, rddName, rddStorageLevel, rddBlocks.length, r.partitions.size, memSize, diskSize)
}
@@ -101,16 +100,14 @@ object StorageUtils {
rddInfos
}
- /* Removes all BlockStatus object that are not part of a block prefix */
- def filterStorageStatusByPrefix(storageStatusList: Array[StorageStatus],
- prefix: String) : Array[StorageStatus] = {
+ /* Filters storage status by a given RDD id. */
+ def filterStorageStatusByRDD(storageStatusList: Array[StorageStatus], rddId: Int)
+ : Array[StorageStatus] = {
storageStatusList.map { status =>
- val newBlocks = status.blocks.filterKeys(_.startsWith(prefix))
+ val newBlocks = status.rddBlocks.filterKeys(_.rddId == rddId).toMap[BlockId, BlockStatus]
//val newRemainingMem = status.maxMem - newBlocks.values.map(_.memSize).reduce(_ + _)
StorageStatus(status.blockManagerId, status.maxMem, newBlocks)
}
-
}
-
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index f2ae8dd97d..860e680576 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -36,11 +36,11 @@ private[spark] object ThreadingTest {
val numBlocksPerProducer = 20000
private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread {
- val queue = new ArrayBlockingQueue[(String, Seq[Int])](100)
+ val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100)
override def run() {
for (i <- 1 to numBlocksPerProducer) {
- val blockId = "b-" + id + "-" + i
+ val blockId = TestBlockId("b-" + id + "-" + i)
val blockSize = Random.nextInt(1000)
val block = (1 to blockSize).map(_ => Random.nextInt())
val level = randomLevel()
@@ -64,7 +64,7 @@ private[spark] object ThreadingTest {
private[spark] class ConsumerThread(
manager: BlockManager,
- queue: ArrayBlockingQueue[(String, Seq[Int])]
+ queue: ArrayBlockingQueue[(BlockId, Seq[Int])]
) extends Thread {
var numBlockConsumed = 0
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 163a3746ea..b7c81d091c 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -86,7 +86,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++
Seq("GC Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
- {if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
+ {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
Seq("Errors")
val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
@@ -169,6 +169,8 @@ private[spark] class StagePage(parent: JobProgressUI) {
Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
}}
{if (shuffleWrite) {
+ <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
+ parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
<td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
}}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 43c1257677..b83cd54f3c 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.storage.{StorageStatus, StorageUtils}
+import org.apache.spark.storage.{BlockId, StorageStatus, StorageUtils}
import org.apache.spark.storage.BlockManagerMasterActor.BlockStatus
import org.apache.spark.ui.UIUtils._
import org.apache.spark.ui.Page._
@@ -33,21 +33,20 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
val sc = parent.sc
def render(request: HttpServletRequest): Seq[Node] = {
- val id = request.getParameter("id")
- val prefix = "rdd_" + id.toString
+ val id = request.getParameter("id").toInt
val storageStatusList = sc.getExecutorStorageStatus
- val filteredStorageStatusList = StorageUtils.
- filterStorageStatusByPrefix(storageStatusList, prefix)
+ val filteredStorageStatusList = StorageUtils.filterStorageStatusByRDD(storageStatusList, id)
val rddInfo = StorageUtils.rddInfoFromStorageStatus(filteredStorageStatusList, sc).head
val workerHeaders = Seq("Host", "Memory Usage", "Disk Usage")
- val workers = filteredStorageStatusList.map((prefix, _))
+ val workers = filteredStorageStatusList.map((id, _))
val workerTable = listingTable(workerHeaders, workerRow, workers)
val blockHeaders = Seq("Block Name", "Storage Level", "Size in Memory", "Size on Disk",
"Executors")
- val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.sortWith(_._1 < _._1)
+ val blockStatuses = filteredStorageStatusList.flatMap(_.blocks).toArray.
+ sortWith(_._1.name < _._1.name)
val blockLocations = StorageUtils.blockLocationsFromStorageStatus(filteredStorageStatusList)
val blocks = blockStatuses.map {
case(id, status) => (id, status, blockLocations.get(id).getOrElse(Seq("UNKNOWN")))
@@ -99,7 +98,7 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
headerSparkPage(content, parent.sc, "RDD Storage Info for " + rddInfo.name, Storage)
}
- def blockRow(row: (String, BlockStatus, Seq[String])): Seq[Node] = {
+ def blockRow(row: (BlockId, BlockStatus, Seq[String])): Seq[Node] = {
val (id, block, locations) = row
<tr>
<td>{id}</td>
@@ -118,15 +117,15 @@ private[spark] class RDDPage(parent: BlockManagerUI) {
</tr>
}
- def workerRow(worker: (String, StorageStatus)): Seq[Node] = {
- val (prefix, status) = worker
+ def workerRow(worker: (Int, StorageStatus)): Seq[Node] = {
+ val (rddId, status) = worker
<tr>
<td>{status.blockManagerId.host + ":" + status.blockManagerId.port}</td>
<td>
- {Utils.bytesToString(status.memUsed(prefix))}
+ {Utils.bytesToString(status.memUsedByRDD(rddId))}
({Utils.bytesToString(status.memRemaining)} Remaining)
</td>
- <td>{Utils.bytesToString(status.diskUsed(prefix))}</td>
+ <td>{Utils.bytesToString(status.diskUsedByRDD(rddId))}</td>
</tr>
}
}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 0ce1394c77..3f963727d9 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
- "ShuffleMapTask", "BlockManager", "BroadcastVars") {
+ "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
- val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
+ val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
+ SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f384875cc9..a3b3968c5e 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -447,14 +447,17 @@ private[spark] object Utils extends Logging {
hostPortParseResults.get(hostPort)
}
- private[spark] val daemonThreadFactory: ThreadFactory =
- new ThreadFactoryBuilder().setDaemon(true).build()
+ private val daemonThreadFactoryBuilder: ThreadFactoryBuilder =
+ new ThreadFactoryBuilder().setDaemon(true)
/**
- * Wrapper over newCachedThreadPool.
+ * Wrapper over newCachedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonCachedThreadPool(): ThreadPoolExecutor =
- Executors.newCachedThreadPool(daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonCachedThreadPool(prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
@@ -465,10 +468,13 @@ private[spark] object Utils extends Logging {
}
/**
- * Wrapper over newFixedThreadPool.
+ * Wrapper over newFixedThreadPool. Thread names are formatted as prefix-ID, where ID is a
+ * unique, sequentially assigned integer.
*/
- def newDaemonFixedThreadPool(nThreads: Int): ThreadPoolExecutor =
- Executors.newFixedThreadPool(nThreads, daemonThreadFactory).asInstanceOf[ThreadPoolExecutor]
+ def newDaemonFixedThreadPool(nThreads: Int, prefix: String): ThreadPoolExecutor = {
+ val threadFactory = daemonThreadFactoryBuilder.setNameFormat(prefix + "-%d").build()
+ Executors.newFixedThreadPool(nThreads, threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
private def listFilesSafely(file: File): Seq[File] = {
val files = file.listFiles()
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index b3a53d928b..e022accee6 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -20,8 +20,42 @@ package org.apache.spark
import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
-
- test("basic broadcast") {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.broadcast.factory")
+ }
+
+ test("Using HttpBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === Set((1, 10), (2, 10)))
+ }
+
+ test("Accessing HttpBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ sc = new SparkContext("local[10]", "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
+ }
+
+ test("Accessing HttpBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
+ test("Using TorrentBroadcast locally") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
@@ -29,11 +63,23 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}
- test("broadcast variables accessed in multiple threads") {
+ test("Accessing TorrentBroadcast variables from multiple threads") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}
+
+ test("Accessing TorrentBroadcast variables in a local cluster") {
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
+ val numSlaves = 4
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
+ val list = List(1, 2, 3, 4)
+ val listBroadcast = sc.broadcast(list)
+ val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
+ assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 3a7171c488..ea936e815b 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.mock.EasyMockSugar
import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockManager, StorageLevel}
+import org.apache.spark.storage.{BlockManager, RDDBlockId, StorageLevel}
// TODO: Test the CacheManager's thread-safety aspects
class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
@@ -52,13 +52,14 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached rdd") {
expecting {
- blockManager.get("rdd_0_0").andReturn(None)
- blockManager.put("rdd_0_0", ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY, true).
- andReturn(0)
+ blockManager.get(RDDBlockId(0, 0)).andReturn(None)
+ blockManager.put(RDDBlockId(0, 0), ArrayBuffer[Any](1, 2, 3, 4), StorageLevel.MEMORY_ONLY,
+ true).andReturn(0)
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -66,11 +67,12 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get cached rdd") {
expecting {
- blockManager.get("rdd_0_0").andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
+ blockManager.get(RDDBlockId(0, 0)).andReturn(Some(ArrayBuffer(5, 6, 7).iterator))
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = false, null)
+ val context = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -79,11 +81,12 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("get uncached local rdd") {
expecting {
// Local computation should not persist the resulting value, so don't expect a put().
- blockManager.get("rdd_0_0").andReturn(None)
+ blockManager.get(RDDBlockId(0, 0)).andReturn(None)
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = true, null)
+ val context = new TaskContext(0, 0, 0, runningLocally = true, interrupted = false,
+ taskMetrics = null)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d9103aebb7..f26c44d3e7 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.FunSuite
import java.io.File
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
-import storage.StorageLevel
+import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
@@ -62,8 +62,8 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithIndexRDD(r,
- (i: Int, iter: Iterator[Int]) => iter.map(_.toString), false ))
+ testCheckpointing(r => new MapPartitionsWithContextRDD(r,
+ (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
@@ -83,7 +83,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("BlockRDD") {
- val blockId = "id"
+ val blockId = TestBlockId("id")
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
val blockRDD = new BlockRDD[String](sc, Array(blockId))
@@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("CheckpointRDD with zero partitions") {
- val rdd = new BlockRDD[Int](sc, Array[String]())
+ val rdd = new BlockRDD[Int](sc, Array[BlockId]())
assert(rdd.partitions.size === 0)
assert(rdd.isCheckpointed === false)
rdd.checkpoint()
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index cd2bf9a8ff..480bac84f3 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -18,24 +18,14 @@
package org.apache.spark
import network.ConnectionManagerId
-import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Timeouts._
+import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.scalatest.prop.Checkers
import org.scalatest.time.{Span, Millis}
-import org.scalacheck.Arbitrary._
-import org.scalacheck.Gen
-import org.scalacheck.Prop._
-import org.eclipse.jetty.server.{Server, Request, Handler}
-
-import com.google.common.io.Files
-
-import scala.collection.mutable.ArrayBuffer
import SparkContext._
-import storage.{GetBlock, BlockManagerWorker, StorageLevel}
-import ui.JettyUtils
+import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel}
class NotSerializableClass
@@ -193,7 +183,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
// Get all the locations of the first partition and try to fetch the partitions
// from those locations.
- val blockIds = data.partitions.indices.map(index => "rdd_%d_%d".format(data.id, index)).toArray
+ val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray
val blockId = blockIds(0)
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 591c1d498d..7b0bb89ab2 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -495,7 +495,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0, false, null);
+ TaskContext context = new TaskContext(0, 0, 0, false, false, null);
Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
new file mode 100644
index 0000000000..d8a0e983b2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.scheduler.{SparkListenerTaskStart, SparkListener}
+
+
+/**
+ * Test suite for cancelling running jobs. We run the cancellation tasks for single job action
+ * (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers
+ * in both FIFO and fair scheduling modes.
+ */
+class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+ with LocalSparkContext {
+
+ override def afterEach() {
+ super.afterEach()
+ resetSparkContext()
+ System.clearProperty("spark.scheduler.mode")
+ }
+
+ test("local mode, FIFO scheduler") {
+ System.setProperty("spark.scheduler.mode", "FIFO")
+ sc = new SparkContext("local[2]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("local mode, fair scheduler") {
+ System.setProperty("spark.scheduler.mode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local[2]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("cluster mode, FIFO scheduler") {
+ System.setProperty("spark.scheduler.mode", "FIFO")
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("cluster mode, fair scheduler") {
+ System.setProperty("spark.scheduler.mode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local-cluster[2,1,512]", "test")
+ testCount()
+ testTake()
+ // Make sure we can still launch tasks.
+ assert(sc.parallelize(1 to 10, 2).count === 10)
+ }
+
+ test("job group") {
+ sc = new SparkContext("local[2]", "test")
+
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ // jobA is the one to be cancelled.
+ val jobA = future {
+ sc.setJobGroup("jobA", "this is a job to be cancelled")
+ sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+ }
+
+ sc.clearJobGroup()
+ val jobB = sc.parallelize(1 to 100, 2).countAsync()
+
+ // Block until both tasks of job A have started and cancel job A.
+ sem.acquire(2)
+ sc.cancelJobGroup("jobA")
+ val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
+ assert(e.getMessage contains "cancel")
+
+ // Once A is cancelled, job B should finish fairly quickly.
+ assert(jobB.get() === 100)
+ }
+
+ test("two jobs sharing the same stage") {
+ // sem1: make sure cancel is issued after some tasks are launched
+ // sem2: make sure the first stage is not finished until cancel is issued
+ val sem1 = new Semaphore(0)
+ val sem2 = new Semaphore(0)
+
+ sc = new SparkContext("local[2]", "test")
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem1.release()
+ }
+ })
+
+ // Create two actions that would share the some stages.
+ val rdd = sc.parallelize(1 to 10, 2).map { i =>
+ sem2.acquire()
+ (i, i)
+ }.reduceByKey(_+_)
+ val f1 = rdd.collectAsync()
+ val f2 = rdd.countAsync()
+
+ // Kill one of the action.
+ future {
+ sem1.acquire()
+ f1.cancel()
+ sem2.release(10)
+ }
+
+ // Expect both to fail now.
+ // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2.
+ intercept[SparkException] { f1.get() }
+ intercept[SparkException] { f2.get() }
+ }
+
+ def testCount() {
+ // Cancel before launching any tasks
+ {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+
+ // Cancel after some tasks have been launched
+ {
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
+ future {
+ // Wait until some tasks were launched before we cancel the job.
+ sem.acquire()
+ f.cancel()
+ }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ }
+
+ def testTake() {
+ // Cancel before launching any tasks
+ {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+
+ // Cancel after some tasks have been launched
+ {
+ // Add a listener to release the semaphore once any tasks are launched.
+ val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+ future {
+ sem.acquire()
+ f.cancel()
+ }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
new file mode 100644
index 0000000000..da032b17d9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.concurrent.Timeouts
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
+
+
+class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts {
+
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local[2]", "test")
+ }
+
+ override def afterAll() {
+ LocalSparkContext.stop(sc)
+ sc = null
+ }
+
+ lazy val zeroPartRdd = new EmptyRDD[Int](sc)
+
+ test("countAsync") {
+ assert(zeroPartRdd.countAsync().get() === 0)
+ assert(sc.parallelize(1 to 10000, 5).countAsync().get() === 10000)
+ }
+
+ test("collectAsync") {
+ assert(zeroPartRdd.collectAsync().get() === Seq.empty)
+
+ val collected = sc.parallelize(1 to 1000, 3).collectAsync().get()
+ assert(collected === (1 to 1000))
+ }
+
+ test("foreachAsync") {
+ zeroPartRdd.foreachAsync(i => Unit).get()
+
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 1000, 3).foreachAsync { i =>
+ accum += 1
+ }.get()
+ assert(accum.value === 1000)
+ }
+
+ test("foreachPartitionAsync") {
+ zeroPartRdd.foreachPartitionAsync(iter => Unit).get()
+
+ val accum = sc.accumulator(0)
+ sc.parallelize(1 to 1000, 9).foreachPartitionAsync { iter =>
+ accum += 1
+ }.get()
+ assert(accum.value === 9)
+ }
+
+ test("takeAsync") {
+ def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
+ val expected = input.take(num)
+ val saw = rdd.takeAsync(num).get()
+ assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)"
+ .format(rdd.partitions.size, expected, saw))
+ }
+ val input = Range(1, 1000)
+
+ var rdd = sc.parallelize(input, 1)
+ for (num <- Seq(0, 1, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 2)
+ for (num <- Seq(0, 1, 3, 500, 501, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 100)
+ for (num <- Seq(0, 1, 500, 501, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+
+ rdd = sc.parallelize(input, 1000)
+ for (num <- Seq(0, 1, 3, 999, 1000)) {
+ testTake(rdd, input, num)
+ }
+ }
+
+ /**
+ * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+ * of a successful job execution.
+ */
+ test("async success handling") {
+ val f = sc.parallelize(1 to 10, 2).countAsync()
+
+ // Use a semaphore to make sure onSuccess and onComplete's success path will be called.
+ // If not, the test will hang.
+ val sem = new Semaphore(0)
+
+ f.onComplete {
+ case scala.util.Success(res) =>
+ sem.release()
+ case scala.util.Failure(e) =>
+ info("Should not have reached this code path (onComplete matching Failure)")
+ throw new Exception("Task should succeed")
+ }
+ f.onSuccess { case a: Any =>
+ sem.release()
+ }
+ f.onFailure { case t =>
+ info("Should not have reached this code path (onFailure)")
+ throw new Exception("Task should succeed")
+ }
+ assert(f.get() === 10)
+
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
+ }
+
+ /**
+ * Make sure onComplete, onSuccess, and onFailure are invoked correctly in the case
+ * of a failed job execution.
+ */
+ test("async failure handling") {
+ val f = sc.parallelize(1 to 10, 2).map { i =>
+ throw new Exception("intentional"); i
+ }.countAsync()
+
+ // Use a semaphore to make sure onFailure and onComplete's failure path will be called.
+ // If not, the test will hang.
+ val sem = new Semaphore(0)
+
+ f.onComplete {
+ case scala.util.Success(res) =>
+ info("Should not have reached this code path (onComplete matching Success)")
+ throw new Exception("Task should fail")
+ case scala.util.Failure(e) =>
+ sem.release()
+ }
+ f.onSuccess { case a: Any =>
+ info("Should not have reached this code path (onSuccess)")
+ throw new Exception("Task should fail")
+ }
+ f.onFailure { case t =>
+ sem.release()
+ }
+ intercept[SparkException] {
+ f.get()
+ }
+
+ failAfter(10 seconds) {
+ sem.acquire(2)
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 31f97fc139..57d3382ed0 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -106,7 +106,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
}
}
visit(sums)
- assert(deps.size === 2) // ShuffledRDD, ParallelCollection
+ assert(deps.size === 2) // ShuffledRDD, ParallelCollection.
}
test("join") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 2f933246b0..2a2f828be6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -24,15 +24,14 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
import org.apache.spark.MapOutputTracker
-import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
import org.apache.spark.{Dependency, ShuffleDependency, OneToOneDependency}
import org.apache.spark.{FetchFailed, Success, TaskEndReason}
-import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
-
+import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
/**
* Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
@@ -60,7 +59,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
- override def setListener(listener: TaskSchedulerListener) = {}
+ override def cancelTasks(stageId: Int) {}
+ override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
}
@@ -75,15 +75,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
// stub out BlockManagerMaster.getLocations to use our cacheLocations
val blockManagerMaster = new BlockManagerMaster(null) {
- override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
- blockIds.map { name =>
- val pieces = name.split("_")
- if (pieces(0) == "rdd") {
- val key = pieces(1).toInt -> pieces(2).toInt
- cacheLocations.getOrElse(key, Seq())
- } else {
- Seq()
- }
+ override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
+ blockIds.map {
+ _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)).
+ getOrElse(Seq())
}.toSeq
}
override def removeExecutor(execId: String) {
@@ -186,7 +181,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
allowLocal: Boolean = false,
listener: JobListener = listener) {
- runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
+ val jobId = scheduler.nextJobId.getAndIncrement()
+ runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener))
}
/** Sends TaskSetFailed to the scheduler. */
@@ -220,7 +216,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def getPreferredLocations(split: Partition) = Nil
override def toString = "DAGSchedulerSuite Local RDD"
}
- runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+ val jobId = scheduler.nextJobId.getAndIncrement()
+ runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
}
@@ -247,7 +244,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
test("trivial job failure") {
submit(makeRdd(1, Nil), Array(0))
failed(taskSets(0), "some failure")
- assert(failure.getMessage === "Job failed: some failure")
+ assert(failure.getMessage === "Job aborted: some failure")
}
test("run trivial shuffle") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index 80d0c5a5e9..b97f2b19b5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -28,6 +28,30 @@ import org.apache.spark.executor.TaskMetrics
import java.nio.ByteBuffer
import org.apache.spark.util.{Utils, FakeClock}
+class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ taskScheduler.startedTasks += taskInfo.index
+ }
+
+ override def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: mutable.Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics) {
+ taskScheduler.endedTasks(taskInfo.index) = reason
+ }
+
+ override def executorGained(execId: String, host: String) {}
+
+ override def executorLost(execId: String) {}
+
+ override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ taskScheduler.taskSetsFailed += taskSet.id
+ }
+}
+
/**
* A mock ClusterScheduler implementation that just remembers information about tasks started and
* feedback received from the TaskSetManagers. Note that it's important to initialize this with
@@ -44,30 +68,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
val executors = new mutable.HashMap[String, String] ++ liveExecutors
- listener = new TaskSchedulerListener {
- def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- startedTasks += taskInfo.index
- }
-
- def taskEnded(
- task: Task[_],
- reason: TaskEndReason,
- result: Any,
- accumUpdates: mutable.Map[Long, Any],
- taskInfo: TaskInfo,
- taskMetrics: TaskMetrics)
- {
- endedTasks(taskInfo.index) = reason
- }
-
- def executorGained(execId: String, host: String) {}
-
- def executorLost(execId: String) {}
-
- def taskSetFailed(taskSet: TaskSet, reason: String) {
- taskSetsFailed += taskSet.id
- }
- }
+ dagScheduler = new FakeDAGScheduler(this)
def removeExecutor(execId: String): Unit = executors -= execId
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
index 2f12aaed18..0f01515179 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
@@ -17,10 +17,11 @@
package org.apache.spark.scheduler.cluster
+import org.apache.spark.TaskContext
import org.apache.spark.scheduler.{TaskLocation, Task}
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
- override def run(attemptId: Long): Int = 0
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+ override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
index 119ba30090..ee150a3107 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.storage.TaskResultBlockId
/**
* Removes the TaskResult from the BlockManager before delegating to a normal TaskResultGetter.
@@ -85,7 +86,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
- val RESULT_BLOCK_ID = "taskresult_0"
+ val RESULT_BLOCK_ID = TaskResultBlockId(0)
assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0,
"Expect result to be removed from the block manager.")
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
index af76c843e8..1e676c1719 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
@@ -17,17 +17,15 @@
package org.apache.spark.scheduler.local
-import org.scalatest.FunSuite
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark._
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.{ConcurrentMap, HashMap}
import java.util.concurrent.Semaphore
import java.util.concurrent.CountDownLatch
-import java.util.Properties
+
+import scala.collection.mutable.HashMap
+
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+import org.apache.spark._
+
class Lock() {
var finished = false
@@ -63,7 +61,12 @@ object TaskThreadInfo {
* 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
* thus it will be scheduled later when cluster has free cpu cores.
*/
-class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach {
+
+ override def afterEach() {
+ super.afterEach()
+ System.clearProperty("spark.scheduler.mode")
+ }
def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
@@ -148,12 +151,13 @@ class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
}
test("Local fair scheduler end-to-end test") {
- sc = new SparkContext("local[8]", "LocalSchedulerSuite")
- val sem = new Semaphore(0)
System.setProperty("spark.scheduler.mode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.scheduler.allocation.file", xmlPath)
+ sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+ val sem = new Semaphore(0)
+
createThread(10,"1",sc,sem)
TaskThreadInfo.threadToStarted(10).await()
createThread(20,"2",sc,sem)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
new file mode 100644
index 0000000000..cb76275e39
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+
+class BlockIdSuite extends FunSuite {
+ def assertSame(id1: BlockId, id2: BlockId) {
+ assert(id1.name === id2.name)
+ assert(id1.hashCode === id2.hashCode)
+ assert(id1 === id2)
+ }
+
+ def assertDifferent(id1: BlockId, id2: BlockId) {
+ assert(id1.name != id2.name)
+ assert(id1.hashCode != id2.hashCode)
+ assert(id1 != id2)
+ }
+
+ test("test-bad-deserialization") {
+ try {
+ // Try to deserialize an invalid block id.
+ BlockId("myblock")
+ fail()
+ } catch {
+ case e: IllegalStateException => // OK
+ case _ => fail()
+ }
+ }
+
+ test("rdd") {
+ val id = RDDBlockId(1, 2)
+ assertSame(id, RDDBlockId(1, 2))
+ assertDifferent(id, RDDBlockId(1, 1))
+ assert(id.name === "rdd_1_2")
+ assert(id.asRDDId.get.rddId === 1)
+ assert(id.asRDDId.get.splitIndex === 2)
+ assert(id.isRDD)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("shuffle") {
+ val id = ShuffleBlockId(1, 2, 3)
+ assertSame(id, ShuffleBlockId(1, 2, 3))
+ assertDifferent(id, ShuffleBlockId(3, 2, 3))
+ assert(id.name === "shuffle_1_2_3")
+ assert(id.asRDDId === None)
+ assert(id.shuffleId === 1)
+ assert(id.mapId === 2)
+ assert(id.reduceId === 3)
+ assert(id.isShuffle)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("broadcast") {
+ val id = BroadcastBlockId(42)
+ assertSame(id, BroadcastBlockId(42))
+ assertDifferent(id, BroadcastBlockId(123))
+ assert(id.name === "broadcast_42")
+ assert(id.asRDDId === None)
+ assert(id.broadcastId === 42)
+ assert(id.isBroadcast)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("taskresult") {
+ val id = TaskResultBlockId(60)
+ assertSame(id, TaskResultBlockId(60))
+ assertDifferent(id, TaskResultBlockId(61))
+ assert(id.name === "taskresult_60")
+ assert(id.asRDDId === None)
+ assert(id.taskId === 60)
+ assert(!id.isRDD)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("stream") {
+ val id = StreamBlockId(1, 100)
+ assertSame(id, StreamBlockId(1, 100))
+ assertDifferent(id, StreamBlockId(2, 101))
+ assert(id.name === "input-1-100")
+ assert(id.asRDDId === None)
+ assert(id.streamId === 1)
+ assert(id.uniqueId === 100)
+ assert(!id.isBroadcast)
+ assertSame(id, BlockId(id.toString))
+ }
+
+ test("test") {
+ val id = TestBlockId("abc")
+ assertSame(id, TestBlockId("abc"))
+ assertDifferent(id, TestBlockId("ab"))
+ assert(id.name === "test_abc")
+ assert(id.asRDDId === None)
+ assert(id.id === "abc")
+ assert(!id.isShuffle)
+ assertSame(id, BlockId(id.toString))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 038a9acb85..484a654108 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -32,7 +32,6 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-
class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
var store: BlockManager = null
var store2: BlockManager = null
@@ -46,6 +45,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.setProperty("spark.kryoserializer.buffer.mb", "1")
val serializer = new KryoSerializer
+ // Implicitly convert strings to BlockIds for test clarity.
+ implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
+ def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
+
before {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0)
this.actorSystem = actorSystem
@@ -229,31 +232,31 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
// Putting a1, a2 and a3 in memory.
- store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
store.putSingle("nonrddblock", a3, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = false)
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
- store.getSingle("rdd_0_0") should be (None)
- master.getLocations("rdd_0_0") should have size 0
+ store.getSingle(rdd(0, 0)) should be (None)
+ master.getLocations(rdd(0, 0)) should have size 0
}
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
- store.getSingle("rdd_0_1") should be (None)
- master.getLocations("rdd_0_1") should have size 0
+ store.getSingle(rdd(0, 1)) should be (None)
+ master.getLocations(rdd(0, 1)) should have size 0
}
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
store.getSingle("nonrddblock") should not be (None)
master.getLocations("nonrddblock") should have size (1)
}
- store.putSingle("rdd_0_0", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_1", a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 0), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a2, StorageLevel.MEMORY_ONLY)
master.removeRdd(0, blocking = true)
- store.getSingle("rdd_0_0") should be (None)
- master.getLocations("rdd_0_0") should have size 0
- store.getSingle("rdd_0_1") should be (None)
- master.getLocations("rdd_0_1") should have size 0
+ store.getSingle(rdd(0, 0)) should be (None)
+ master.getLocations(rdd(0, 0)) should have size 0
+ store.getSingle(rdd(0, 1)) should be (None)
+ master.getLocations(rdd(0, 1)) should have size 0
}
test("reregistration on heart beat") {
@@ -372,41 +375,41 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
val a1 = new Array[Byte](400)
val a2 = new Array[Byte](400)
val a3 = new Array[Byte](400)
- store.putSingle("rdd_0_1", a1, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_2", a2, StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_3", a3, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), a1, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 2), a2, StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 3), a3, StorageLevel.MEMORY_ONLY)
// Even though we accessed rdd_0_3 last, it should not have replaced partitions 1 and 2
// from the same RDD
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
- assert(store.getSingle("rdd_0_1") != None, "rdd_0_1 was not in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
+ assert(store.getSingle(rdd(0, 1)) != None, "rdd_0_1 was not in store")
// Check that rdd_0_3 doesn't replace them even after further accesses
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
- assert(store.getSingle("rdd_0_3") === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
+ assert(store.getSingle(rdd(0, 3)) === None, "rdd_0_3 was in store")
}
test("in-memory LRU for partitions of multiple RDDs") {
store = new BlockManager("<driver>", actorSystem, master, serializer, 1200)
- store.putSingle("rdd_0_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_2", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_1_1", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// At this point rdd_1_1 should've replaced rdd_0_1
- assert(store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was not in store")
- assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
- assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
+ assert(store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was not in store")
+ assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
+ assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
// Do a get() on rdd_0_2 so that it is the most recently used item
- assert(store.getSingle("rdd_0_2") != None, "rdd_0_2 was not in store")
+ assert(store.getSingle(rdd(0, 2)) != None, "rdd_0_2 was not in store")
// Put in more partitions from RDD 0; they should replace rdd_1_1
- store.putSingle("rdd_0_3", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
- store.putSingle("rdd_0_4", new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 3), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
+ store.putSingle(rdd(0, 4), new Array[Byte](400), StorageLevel.MEMORY_ONLY)
// Now rdd_1_1 should be dropped to add rdd_0_3, but then rdd_0_2 should *not* be dropped
// when we try to add rdd_0_4.
- assert(!store.memoryStore.contains("rdd_1_1"), "rdd_1_1 was in store")
- assert(!store.memoryStore.contains("rdd_0_1"), "rdd_0_1 was in store")
- assert(!store.memoryStore.contains("rdd_0_4"), "rdd_0_4 was in store")
- assert(store.memoryStore.contains("rdd_0_2"), "rdd_0_2 was not in store")
- assert(store.memoryStore.contains("rdd_0_3"), "rdd_0_3 was not in store")
+ assert(!store.memoryStore.contains(rdd(1, 1)), "rdd_1_1 was in store")
+ assert(!store.memoryStore.contains(rdd(0, 1)), "rdd_0_1 was in store")
+ assert(!store.memoryStore.contains(rdd(0, 4)), "rdd_0_4 was in store")
+ assert(store.memoryStore.contains(rdd(0, 2)), "rdd_0_2 was not in store")
+ assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store")
}
test("on-disk storage") {
@@ -590,43 +593,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
try {
System.setProperty("spark.shuffle.compress", "true")
store = new BlockManager("exec1", actorSystem, master, serializer, 2000)
- store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("shuffle_0_0_0") <= 100, "shuffle_0_0_0 was not compressed")
+ store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100,
+ "shuffle_0_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.shuffle.compress", "false")
store = new BlockManager("exec2", actorSystem, master, serializer, 2000)
- store.putSingle("shuffle_0_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("shuffle_0_0_0") >= 1000, "shuffle_0_0_0 was compressed")
+ store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000,
+ "shuffle_0_0_0 was compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "true")
store = new BlockManager("exec3", actorSystem, master, serializer, 2000)
- store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("broadcast_0") <= 100, "broadcast_0 was not compressed")
+ store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100,
+ "broadcast_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.broadcast.compress", "false")
store = new BlockManager("exec4", actorSystem, master, serializer, 2000)
- store.putSingle("broadcast_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("broadcast_0") >= 1000, "broadcast_0 was compressed")
+ store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "true")
store = new BlockManager("exec5", actorSystem, master, serializer, 2000)
- store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("rdd_0_0") <= 100, "rdd_0_0 was not compressed")
+ store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed")
store.stop()
store = null
System.setProperty("spark.rdd.compress", "false")
store = new BlockManager("exec6", actorSystem, master, serializer, 2000)
- store.putSingle("rdd_0_0", new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
- assert(store.memoryStore.getSize("rdd_0_0") >= 1000, "rdd_0_0 was compressed")
+ store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER)
+ assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed")
store.stop()
store = null
diff --git a/docs/configuration.md b/docs/configuration.md
index 7940d41a27..97183bafdb 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -149,7 +149,7 @@ Apart from these, the following properties are also available, and may be useful
<td>spark.io.compression.codec</td>
<td>org.apache.spark.io.<br />LZFCompressionCodec</td>
<td>
- The compression codec class to use for various compressions. By default, Spark provides two
+ The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, Spark provides two
codecs: <code>org.apache.spark.io.LZFCompressionCodec</code> and <code>org.apache.spark.io.SnappyCompressionCodec</code>.
</td>
</tr>
@@ -319,6 +319,14 @@ Apart from these, the following properties are also available, and may be useful
Should be greater than or equal to 1. Number of allowed retries = this value - 1.
</td>
</tr>
+<tr>
+ <td>spark.broadcast.blockSize</td>
+ <td>4096</td>
+ <td>
+ Size of each piece of a block in kilobytes for <code>TorrentBroadcastFactory</code>.
+ Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, <code>BlockManager</code> might take a performance hit.
+ </td>
+</tr>
</table>
diff --git a/examples/pom.xml b/examples/pom.xml
index b97e6af288..aee371fbc7 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -153,6 +153,14 @@
<groupId>org.apache.cassandra.deps</groupId>
<artifactId>avro</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.xerial.snappy</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
</exclusions>
</dependency>
</dependencies>
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index 868ff81f67..529709c2f9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -22,12 +22,19 @@ import org.apache.spark.SparkContext
object BroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
- System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
+ System.err.println("Usage: BroadcastTest <master> [slices] [numElem] [broadcastAlgo] [blockSize]")
System.exit(1)
}
- val sc = new SparkContext(args(0), "Broadcast Test",
+ val bcName = if (args.length > 3) args(3) else "Http"
+ val blockSize = if (args.length > 4) args(4) else "4096"
+
+ System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory")
+ System.setProperty("spark.broadcast.blockSize", blockSize)
+
+ val sc = new SparkContext(args(0), "Broadcast Test 2",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
+
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
@@ -36,13 +43,15 @@ object BroadcastTest {
arr1(i) = i
}
- for (i <- 0 until 2) {
+ for (i <- 0 until 3) {
println("Iteration " + i)
println("===========")
+ val startTime = System.nanoTime
val barr1 = sc.broadcast(arr1)
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size)
}
+ println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}
System.exit(0)
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
index 884d6d6f34..de70c50473 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/clickstream/PageViewGenerator.scala
@@ -17,17 +17,19 @@
package org.apache.spark.streaming.examples.clickstream
-import java.net.{InetAddress,ServerSocket,Socket,SocketException}
-import java.io.{InputStreamReader, BufferedReader, PrintWriter}
+import java.net.ServerSocket
+import java.io.PrintWriter
import util.Random
/** Represents a page view on a website with associated dimension data.*/
-class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) {
+class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int)
+ extends Serializable {
override def toString() : String = {
"%s\t%s\t%s\t%s\n".format(url, status, zipCode, userID)
}
}
-object PageView {
+
+object PageView extends Serializable {
def fromString(in : String) : PageView = {
val parts = in.split("\t")
new PageView(parts(0), parts(1).toInt, parts(2).toInt, parts(3).toInt)
@@ -39,6 +41,9 @@ object PageView {
* This should be used in tandem with PageViewStream.scala. Example:
* $ ./run-example spark.streaming.examples.clickstream.PageViewGenerator 44444 10
* $ ./run-example spark.streaming.examples.clickstream.PageViewStream errorRatePerZipCode localhost 44444
+ *
+ * When running this, you may want to set the root logging level to ERROR in
+ * conf/log4j.properties to reduce the verbosity of the output.
* */
object PageViewGenerator {
val pages = Map("http://foo.com/" -> .7,
diff --git a/pom.xml b/pom.xml
index 18df1bf826..54f100c37f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -21,7 +21,7 @@
<parent>
<groupId>org.apache</groupId>
<artifactId>apache</artifactId>
- <version>11</version>
+ <version>13</version>
</parent>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
@@ -61,6 +61,29 @@
<maven>3.0.0</maven>
</prerequisites>
+ <mailingLists>
+ <mailingList>
+ <name>Dev Mailing List</name>
+ <post>dev@spark.incubator.apache.org</post>
+ <subscribe>dev-subscribe@spark.incubator.apache.org</subscribe>
+ <unsubscribe>dev-unsubscribe@spark.incubator.apache.org</unsubscribe>
+ </mailingList>
+
+ <mailingList>
+ <name>User Mailing List</name>
+ <post>user@spark.incubator.apache.org</post>
+ <subscribe>user-subscribe@spark.incubator.apache.org</subscribe>
+ <unsubscribe>user-unsubscribe@spark.incubator.apache.org</unsubscribe>
+ </mailingList>
+
+ <mailingList>
+ <name>Commits Mailing List</name>
+ <post>commits@spark.incubator.apache.org</post>
+ <subscribe>commits-subscribe@spark.incubator.apache.org</subscribe>
+ <unsubscribe>commits-unsubscribe@spark.incubator.apache.org</unsubscribe>
+ </mailingList>
+ </mailingLists>
+
<modules>
<module>core</module>
<module>bagel</module>
@@ -227,16 +250,34 @@
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-remote</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-slf4j</artifactId>
<version>${akka.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
@@ -371,19 +412,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -407,19 +440,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -438,19 +463,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@@ -469,19 +486,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-core-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-mapper-asl</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-jaxrs</artifactId>
+ <artifactId>*</artifactId>
</exclusion>
<exclusion>
- <groupId>org.codehaus.jackson</groupId>
- <artifactId>jackson-xc</artifactId>
+ <groupId>org.sonatype.sisu.inject</groupId>
+ <artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index b14970942b..17f480e3f0 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -60,6 +60,8 @@ object SparkBuild extends Build {
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
.dependsOn(core, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*)
+ lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
+
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -74,8 +76,11 @@ object SparkBuild extends Build {
// Conditionally include the yarn sub-project
lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]()
lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]()
- lazy val allProjects = Seq[ProjectReference](
- core, repl, examples, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef
+
+ // Everything except assembly, tools and examples belong to packageProjects
+ lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib) ++ maybeYarnRef
+
+ lazy val allProjects = packageProjects ++ Seq[ProjectReference](examples, tools, assemblyProj)
def sharedSettings = Defaults.defaultSettings ++ Seq(
organization := "org.apache.spark",
@@ -306,7 +311,9 @@ object SparkBuild extends Build {
def assemblyProjSettings = sharedSettings ++ Seq(
name := "spark-assembly",
- jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
+ assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn,
+ jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
+ jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(
@@ -314,6 +321,7 @@ object SparkBuild extends Build {
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
+ case "log4j.properties" => MergeStrategy.discard
case "META-INF/services/org.apache.hadoop.fs.FileSystem" => MergeStrategy.concat
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index d367f91967..da3d96689a 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -42,6 +42,13 @@
>>> a.value
13
+>>> b = sc.accumulator(0)
+>>> def g(x):
+... b.add(x)
+>>> rdd.foreach(g)
+>>> b.value
+6
+
>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
... def zero(self, value):
@@ -139,9 +146,13 @@ class Accumulator(object):
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
+ def add(self, term):
+ """Adds a term to this accumulator's value"""
+ self._value = self.accum_param.addInPlace(self._value, term)
+
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
- self._value = self.accum_param.addInPlace(self._value, term)
+ self.add(term)
return self
def __str__(self):
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 36f54a22cf..48a8fa9328 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -845,7 +845,14 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
.getOrElse(new Array[String](0))
.map(new java.io.File(_).getAbsolutePath)
- sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
+ try {
+ sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
+ } catch {
+ case e: Exception =>
+ e.printStackTrace()
+ echo("Failed to create SparkContext, exiting...")
+ sys.exit(1)
+ }
sparkContext
}
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 14c043175d..339fcd2a39 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -84,12 +84,22 @@
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>org.xerial.snappy</groupId>
+ <artifactId>*</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.twitter4j</groupId>
<artifactId>twitter4j-stream</artifactId>
<version>3.0.3</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
@@ -99,6 +109,12 @@
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-zeromq</artifactId>
<version>2.0.3</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jboss.netty</groupId>
+ <artifactId>netty</artifactId>
+ </exclusion>
+ </exclusions>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 2d8f072624..bb9febad38 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.Logging
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.MetadataCleaner
private[streaming]
@@ -40,6 +41,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
val pendingTimes = ssc.scheduler.jobManager.getPendingTimes()
+ val delaySeconds = MetadataCleaner.getDelaySeconds
def validate() {
assert(master != null, "Checkpoint.master is null")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
index aae79a4e6f..b97fb7e6e3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/NetworkInputTracker.scala
@@ -30,10 +30,11 @@ import akka.actor._
import akka.pattern.ask
import akka.util.duration._
import akka.dispatch._
+import org.apache.spark.storage.BlockId
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage
-private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[String], metadata: Any) extends NetworkInputTrackerMessage
+private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) extends NetworkInputTrackerMessage
private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage
/**
@@ -48,7 +49,7 @@ class NetworkInputTracker(
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
val receiverInfo = new HashMap[Int, ActorRef]
- val receivedBlockIds = new HashMap[Int, Queue[String]]
+ val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
val timeout = 5000.milliseconds
var currentTime: Time = null
@@ -67,9 +68,9 @@ class NetworkInputTracker(
}
/** Return all the blocks received from a receiver. */
- def getBlockIds(receiverId: Int, time: Time): Array[String] = synchronized {
+ def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized {
val queue = receivedBlockIds.synchronized {
- receivedBlockIds.getOrElse(receiverId, new Queue[String]())
+ receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]())
}
val result = queue.synchronized {
queue.dequeueAll(x => true)
@@ -92,7 +93,7 @@ class NetworkInputTracker(
case AddBlocks(streamId, blockIds, metadata) => {
val tmp = receivedBlockIds.synchronized {
if (!receivedBlockIds.contains(streamId)) {
- receivedBlockIds += ((streamId, new Queue[String]))
+ receivedBlockIds += ((streamId, new Queue[BlockId]))
}
receivedBlockIds(streamId)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index dc60046805..ee265ab4e9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -100,6 +100,10 @@ class StreamingContext private (
"both SparkContext and checkpoint as null")
}
+ if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds < 0) {
+ MetadataCleaner.setDelaySeconds(cp_.delaySeconds)
+ }
+
if (MetadataCleaner.getDelaySeconds < 0) {
throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; "
+ "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 31f9891560..8d3ac0fc65 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -31,7 +31,7 @@ import org.apache.spark.streaming.util.{RecurringTimer, SystemClock}
import org.apache.spark.streaming._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
/**
* Abstract class for defining any InputDStream that has to start a receiver on worker
@@ -69,7 +69,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
- Some(new BlockRDD[T](ssc.sc, Array[String]()))
+ Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
}
}
}
@@ -77,7 +77,7 @@ abstract class NetworkInputDStream[T: ClassManifest](@transient ssc_ : Streaming
private[streaming] sealed trait NetworkReceiverMessage
private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
-private[streaming] case class ReportBlock(blockId: String, metadata: Any) extends NetworkReceiverMessage
+private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) extends NetworkReceiverMessage
private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
/**
@@ -158,7 +158,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/**
* Pushes a block (as an ArrayBuffer filled with data) into the block manager.
*/
- def pushBlock(blockId: String, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
+ def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
actor ! ReportBlock(blockId, metadata)
}
@@ -166,7 +166,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
/**
* Pushes a block (as bytes) into the block manager.
*/
- def pushBlock(blockId: String, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
+ def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
actor ! ReportBlock(blockId, metadata)
}
@@ -209,7 +209,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
class BlockGenerator(storageLevel: StorageLevel)
extends Serializable with Logging {
- case class Block(id: String, buffer: ArrayBuffer[T], metadata: Any = null)
+ case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null)
val clock = new SystemClock()
val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
@@ -241,7 +241,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
val newBlockBuffer = currentBuffer
currentBuffer = new ArrayBuffer[T]
if (newBlockBuffer.size > 0) {
- val blockId = "input-" + NetworkReceiver.this.streamId + "-" + (time - blockInterval)
+ val blockId = StreamBlockId(NetworkReceiver.this.streamId, time - blockInterval)
val newBlock = new Block(blockId, newBlockBuffer)
blocksForPushing.add(newBlock)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
index c91f12ecd7..10ed4ef78d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala
@@ -18,7 +18,7 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.StreamingContext
import java.net.InetSocketAddress
@@ -71,7 +71,7 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel)
var nextBlockNumber = 0
while (true) {
val buffer = queue.take()
- val blockId = "input-" + streamId + "-" + nextBlockNumber
+ val blockId = StreamBlockId(streamId, nextBlockNumber)
nextBlockNumber += 1
pushBlock(blockId, buffer, null, storageLevel)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
index 4b5d8c467e..ef0f85a717 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
@@ -21,7 +21,7 @@ import akka.actor.{ Actor, PoisonPill, Props, SupervisorStrategy }
import akka.actor.{ actorRef2Scala, ActorRef }
import akka.actor.{ PossiblyHarmful, OneForOneStrategy }
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.dstream.NetworkReceiver
import java.util.concurrent.atomic.AtomicInteger
@@ -159,7 +159,7 @@ private[streaming] class ActorReceiver[T: ClassManifest](
protected def pushBlock(iter: Iterator[T]) {
val buffer = new ArrayBuffer[T]
buffer ++= iter
- pushBlock("block-" + streamId + "-" + System.nanoTime(), buffer, null, storageLevel)
+ pushBlock(StreamBlockId(streamId, System.nanoTime()), buffer, null, storageLevel)
}
protected def onStart() = {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 6d6ef149cc..25da9aa917 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -22,7 +22,7 @@ import org.apache.spark.util.Utils
import org.apache.spark.scheduler.SplitInfo
import scala.collection
import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
@@ -211,7 +211,7 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM
val workerId = workerIdCounter.incrementAndGet().toString
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
- StandaloneSchedulerBackend.ACTOR_NAME)
+ CoarseGrainedSchedulerBackend.ACTOR_NAME)
logInfo("launching container on " + containerId + " host " + workerHostname)
// just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..